max044 committed · verified
Commit 2fd5fdc · Parent(s): 97ed96b

Upload folder using huggingface_hub

README.md ADDED
---
language:
- en
- fr
license: apache-2.0
library_name: transformers
tags:
- video-search
- v-jepa
- multi-modal
- temporal-grounding
- action-retrieval
datasets:
- max044/Charades_v1_480
metrics:
- loss
---

# VL-JEPA Custom (V-JEPA 2 + Qwen 2.5 + MiniLM)

## English Description

This model is a custom implementation of **VL-JEPA** (Video-Language Joint
Embedding Predictive Architecture), inspired by Meta AI's research. It is
designed for **Temporal Moment Retrieval** (finding specific actions in videos).

### Architecture

- **X-Encoder (Video)**: Frozen
  [V-JEPA 2 (ViT-L)](https://huggingface.co/facebook/vjepa2-vitl-fpc64-256).
- **Predictor (Refinement)**:
  [Qwen 2.5 0.5B](https://huggingface.co/Qwen/Qwen2.5-0.5B) fine-tuned using
  **LoRA** (Low-Rank Adaptation).
- **Y-Encoder (Text Target)**: Frozen
  [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2).

### Training Details

- **Dataset**:
  [Charades-STA](https://huggingface.co/datasets/max044/Charades_v1_480)
  (academic dataset for video action localization).
- **Optimization**: LoRA with $r=64$ and $\alpha=128$, targeting `q_proj` and
  `v_proj` in Qwen.
- **Learning Rate**: 3e-4 with cosine warmup.
- **Outcome**: Only 0.2% of parameters are trainable, making the model extremely
  lightweight to train and run.
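As a hedged illustration of the learning-rate schedule: the repo's `train.py` is not included in this commit, so the optimizer wiring below is an assumption; only `lr=3e-4`, `weight_decay=0.01`, and `warmup_steps=200` come from `vljepa/config.py` (`total_steps` is a placeholder).

```python
import math
import torch

# Illustration of "3e-4 with cosine warmup": linear warmup for
# `warmup_steps`, then cosine decay toward zero. Returns a multiplier
# applied to the base learning rate.
def warmup_cosine(step: int, warmup_steps: int = 200, total_steps: int = 10_000) -> float:
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=3e-4, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine)
```

The multiplier rises linearly to 1.0 at step 200, then decays along a half cosine.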
---

## Description (French)

This model is a custom implementation of **VL-JEPA**, inspired by Meta AI's
work. It is optimized for finding temporal actions in videos
(**Temporal Moment Retrieval**).

### Architecture

- **Video Encoder (X)**:
  [V-JEPA 2 (ViT-L)](https://huggingface.co/facebook/vjepa2-vitl-fpc64-256),
  frozen.
- **Predictor**: [Qwen 2.5 0.5B](https://huggingface.co/Qwen/Qwen2.5-0.5B)
  adapted with **LoRA**.
- **Text Encoder (Y)**:
  [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2),
  frozen.

### Training Details

- **Dataset**:
  [Charades-STA](https://huggingface.co/datasets/max044/Charades_v1_480).
- **Method**: Training via LoRA ($r=64$, $\alpha=128$).
- **Cost**: Very economical approach, trained for about $5 on Vast.ai.

## Usage / Utilisation

```python
import torch
from vljepa.config import Config
from vljepa.models import VLJepa

# Load model
config = Config()
model = VLJepa(config)
checkpoint = torch.load("best.pth", map_location="cpu", weights_only=True)
model.predictor.load_state_dict(checkpoint["predictor_state_dict"])
model.y_encoder.projection.load_state_dict(checkpoint["y_projection_state_dict"])
model.eval()

# Localizing an action
# (requires preprocessing frames and tokenizing the query)
```

Refer to the source code for the full inference pipeline with sliding window and
NMS.
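The full pipeline lives in `infer.py` and `vljepa/utils.py`; as a rough, self-contained sketch of the 1-D NMS step over candidate segments (the IoU threshold 0.5 and `top_k=5` match the defaults in `vljepa/config.py`; segment scores would come from cosine similarity between predicted and query embeddings, faked here):

```python
import numpy as np

# 1-D IoU between two (start, end) segments in seconds.
def iou_1d(a, b):
    inter = max(0.0, min(a[1], b[1]) - max(a[0], b[0]))
    union = max(a[1], b[1]) - min(a[0], b[0])
    return inter / union if union > 0 else 0.0

# Greedy non-maximum suppression: keep the highest-scoring segment,
# drop any remaining segment that overlaps a kept one too much.
def nms_1d(segments, scores, iou_threshold=0.5, top_k=5):
    order = np.argsort(scores)[::-1]
    keep = []
    for i in order:
        if all(iou_1d(segments[i], segments[j]) < iou_threshold for j in keep):
            keep.append(i)
        if len(keep) == top_k:
            break
    return [(segments[i], scores[i]) for i in keep]

segments = [(0.0, 4.0), (1.0, 5.0), (10.0, 14.0)]
scores = np.array([0.9, 0.8, 0.7])
print(nms_1d(segments, scores))  # (1.0, 5.0) is suppressed: IoU 0.6 with (0.0, 4.0)
```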
best.pth ADDED

version https://git-lfs.github.com/spec/v1
oid sha256:6393f56b7528ad91a3281ebcd0bb368b44dc041a5b50bc7569d466e91e992750
size 2045205003
original_readme.md ADDED
# VL-JEPA: Simplified Video-Language Alignment

A simplified implementation of the Video-Language Joint Embedding Predictive
Architecture (VL-JEPA) for **Temporal Moment Retrieval** (temporal grounding).

This project uses **V-JEPA 2** for video understanding and **Qwen 2.5 0.5B** as
a predictor to align video features with language queries in a high-dimensional
embedding space.

## 🚀 Architecture

The model follows the JEPA framework by aligning video features (X) and text
descriptions (Y) through a predictor (P):

- **X-Encoder (Video)**: Frozen **V-JEPA 2** (ViT-L). High-fidelity hierarchical
  video features.
- **Y-Encoder (Text)**: Frozen **MiniLM** (all-MiniLM-L6-v2). Compact and
  efficient semantic text embeddings.
- **Predictor (Alignment)**: **Qwen 2.5 0.5B** with **LoRA** (Low-Rank
  Adaptation). Learns to predict the target text embedding from the joint
  video+query representation.
+
23
+ ## 🛠️ Installation
24
+
25
+ This project uses `uv` for lightning-fast dependency management.
26
+
27
+ ```bash
28
+ # Clone the repository
29
+ git clone https://github.com/max044/vl-jepa.git
30
+ cd vl-jepa
31
+
32
+ # Create environment and install dependencies
33
+ uv sync
34
+ ```
35
+
36
+ ## 📊 Data Preparation
37
+
38
+ The model is trained on the **Charades-STA** dataset for temporal grounding.
39
+
40
+ 1. **Videos**: Download
41
+ [Charades v1](https://ai2-public-datasets.s3-us-west-2.amazonaws.com/charades/Charades_v1_480.zip)
42
+ and place them in `data/Charades_v1_480`.
43
+ 2. **Annotations**: Use `download_annotations.py` to download the annotations.
44
+
45
+ Structure:
46
+
47
+ ```text
48
+ data/
49
+ ├── Charades_v1_480/ # Video files (.mp4)
50
+ ├── charades_sta_train.txt
51
+ └── charades_sta_test.txt
52
+ ```
53
+
54
+ ## 🏋️ Training
55
+
56
+ Start training with default hyperparameters:
57
+
58
+ ```bash
59
+ # Regular training (local, MPS/CPU)
60
+ uv run train.py
61
+
62
+ # Debug mode (small subset, only 2 epochs)
63
+ uv run train.py --debug --device mps
64
+ ```
65
+
66
+ ### Key Training Features:
67
+
68
+ - **Bidirectional InfoNCE Loss**: Maximizes mutual information between predicted
69
+ and target embeddings.
70
+ - **LoRA Tuning**: Only 0.2% of the predictor parameters (Qwen) are trained,
71
+ making it extremely memory-efficient.
72
+ - **MPS Support**: Optimized for Mac M1/M2/M3 chips.
73
+ - **W&B Integration**: Full experiment tracking with model versioning.
## ☁️ Cloud GPU Training

Train on GPU with [Vast.ai](https://vast.ai/) (~$0.50–2/h for an A100/H100).

### Quick Start

```bash
# 1. On the cloud instance — bootstrap
curl -sSL https://raw.githubusercontent.com/max044/vl-jepa/main/scripts/bootstrap.sh | bash

# 2. Configure W&B
cd ~/vl-jepa
cp .env.example .env
nano .env  # Set WANDB_API_KEY (get it at https://wandb.ai/authorize)

# 3. Download videos
wget -P data/ https://ai2-public-datasets.s3-us-west-2.amazonaws.com/charades/Charades_v1_480.zip
unzip data/Charades_v1_480.zip -d data/

# or, alternatively:
uv run hf download max044/Charades_v1_480 --local-dir data/Charades_v1_480 --repo-type dataset

# 4. Launch training
bash scripts/train_cloud.sh
```

### W&B Experiment Tracking

All training runs are tracked on [Weights & Biases](https://wandb.ai/):

- **Metrics**: loss, InfoNCE, learning rate (per step + per epoch)
- **System**: GPU utilization, memory usage (automatic)
- **Model versioning**: checkpoints uploaded as W&B Artifacts (`vl-jepa-best`,
  `vl-jepa-last`) — every version is preserved and downloadable

```bash
# Train with W&B (default)
uv run train.py --device cuda --wandb-project vl-jepa

# Train without W&B
uv run train.py --device cuda --no-wandb

# Custom W&B run name
uv run train.py --device cuda --wandb-run-name "exp-lr3e4-bs16"
```

### Environment Variables

| Variable        | Description                                          | Required     |
| --------------- | ---------------------------------------------------- | ------------ |
| `WANDB_API_KEY` | W&B API key ([get here](https://wandb.ai/authorize)) | For tracking |
| `WANDB_PROJECT` | W&B project name (default: `vl-jepa`)                | No           |
| `WANDB_ENTITY`  | W&B team/organization                                | No           |
| `EPOCHS`        | Override epoch count                                 | No           |
| `BATCH_SIZE`    | Override batch size                                  | No           |

## 🔍 Inference (Moment Retrieval)

Once trained, you can use the model to find specific moments in a video from a
text query. The script uses a sliding-window approach with NMS to find the best
matching segments.

```bash
# Example: local inference
uv run infer.py \
    --video data/Charades_v1_480/3MSZA.mp4 \
    --query "person turns on the light" \
    --checkpoint checkpoints/best.pth \
    --device mps
```

## 🔍 Implementation Details

Unlike standard VLMs (vision-language models) that use generative heads, this
VL-JEPA implementation focuses on **embedding alignment**. This makes it an
order of magnitude faster for retrieval tasks (search), since embeddings can be
pre-computed and indexed using vector databases (Faiss, Milvus, Chroma).
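As a sketch of that retrieval pattern: pre-compute segment embeddings once, then answer each query with a similarity search. A real deployment would hand the bank to Faiss/Milvus/Chroma; plain NumPy is shown here, with random vectors standing in for model embeddings (the 384 dimension matches `embed_dim` in `vljepa/config.py`).

```python
import numpy as np

rng = np.random.default_rng(0)
bank = rng.normal(size=(1000, 384))                  # pre-computed segment embeddings
bank /= np.linalg.norm(bank, axis=1, keepdims=True)  # L2-normalize once, up front

# Cosine-similarity search against the pre-built bank.
def search(query_embedding: np.ndarray, top_k: int = 5) -> np.ndarray:
    q = query_embedding / np.linalg.norm(query_embedding)
    scores = bank @ q                         # cosine similarity to every segment
    return np.argsort(scores)[::-1][:top_k]   # indices of the best segments

hits = search(rng.normal(size=384))
print(hits.shape)  # (5,)
```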
## 📚 References

This implementation is based on the official VL-JEPA paper:

```bibtex
@misc{chen2026vljepajointembeddingpredictive,
      title={VL-JEPA: Joint Embedding Predictive Architecture for Vision-language},
      author={Delong Chen and Mustafa Shukor and Theo Moutakanni and Willy Chung and Jade Yu and Tejaswi Kasarla and Yejin Bang and Allen Bolourchi and Yann LeCun and Pascale Fung},
      year={2026},
      eprint={2512.10942},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2512.10942},
}
```

## 📄 License

MIT
sample_inference.py ADDED

import torch
import cv2
import numpy as np
from vljepa.config import Config
from vljepa.models import VLJepa


def load_model(checkpoint_path, device="cpu"):
    config = Config()
    config.device = device
    model = VLJepa(config)

    print(f"Loading weights from {checkpoint_path}...")
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.predictor.load_state_dict(checkpoint["predictor_state_dict"])
    model.y_encoder.projection.load_state_dict(checkpoint["y_projection_state_dict"])

    model.eval()
    return model, config


def extract_frames(video_path, num_frames=16):
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total_frames <= 0:
        return []

    indices = np.linspace(0, total_frames - 1, num_frames).astype(int)
    frames = []
    for idx in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, frame = cap.read()
        if ret:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(frame)
    cap.release()
    return frames


def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint_path = "best.pth"
    video_path = "sample_video.mp4"  # Replace with a real video path
    query = "a person is opening a door"

    model, config = load_model(checkpoint_path, device)

    # This is a simplified inference demonstration.
    # In a real scenario, you would use a sliding-window approach as in infer.py.
    print(f"Ready for inference on {device}.")
    print(f"Model architecture: {config.clip_model} + {config.predictor_model} (LoRA) + {config.text_model}")

    # Example tokenization
    query_tokens = model.query_encoder.tokenize([query], device=device)

    # Example text encoding
    with torch.no_grad():
        text_embedding = model.encode_text([query], device=device)

    print(f"Query: '{query}'")
    print(f"Text embedding shape: {text_embedding.shape}")
    print("\nTo perform full temporal localization, use the infer.py script, which implements sliding window and NMS.")


if __name__ == "__main__":
    main()
vljepa/__init__.py ADDED

"""VL-JEPA: Simplified Video-Language Joint Embedding Predictive Architecture."""

from vljepa.config import Config
from vljepa.models import VLJepa
from vljepa.losses import vl_jepa_loss

__all__ = ["Config", "VLJepa", "vl_jepa_loss"]
vljepa/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (435 Bytes).
vljepa/__pycache__/config.cpython-313.pyc ADDED
Binary file (3.85 kB).
vljepa/__pycache__/dataset.cpython-313.pyc ADDED
Binary file (6.86 kB).
vljepa/__pycache__/losses.cpython-313.pyc ADDED
Binary file (3.71 kB).
vljepa/__pycache__/models.cpython-313.pyc ADDED
Binary file (14.7 kB).
vljepa/__pycache__/utils.cpython-313.pyc ADDED
Binary file (5.8 kB).
vljepa/config.py ADDED

"""Configuration for VL-JEPA training and inference."""

from dataclasses import dataclass, field
from pathlib import Path

import torch


@dataclass
class Config:
    """All hyperparameters and paths for VL-JEPA."""

    # ── Device ───────────────────────────────────────────────────────────
    device: str = ""  # auto-detected if empty

    # ── Model ────────────────────────────────────────────────────────────
    # X-Encoder: V-JEPA 2 ViT-L (frozen, ~300M)
    clip_model: str = "facebook/vjepa2-vitl-fpc64-256"

    # Predictor: Qwen 2.5 0.5B (LoRA)
    predictor_model: str = "Qwen/Qwen2.5-0.5B"
    use_lora: bool = True
    lora_r: int = 64
    lora_alpha: int = 128
    lora_dropout: float = 0.05
    lora_target_modules: list[str] = field(default_factory=lambda: ["q_proj", "v_proj"])

    # Y-Encoder: MiniLM (frozen, ~22M)
    text_model: str = "sentence-transformers/all-MiniLM-L6-v2"

    # Embedding and model dimensions
    x_dim: int = 1024         # V-JEPA ViT-L output dim
    predictor_dim: int = 896  # Qwen 2.5 0.5B hidden dim
    text_dim: int = 384       # MiniLM-L6-v2 output dim
    embed_dim: int = 384      # Shared projection target

    # ── Video ────────────────────────────────────────────────────────────
    num_frames: int = 16
    frame_size: int = 224  # V-JEPA input resolution

    # ── Training ─────────────────────────────────────────────────────────
    batch_size: int = 4  # Start small (increase if GPU RAM allows)
    lr: float = 3e-4
    weight_decay: float = 0.01
    epochs: int = 20
    warmup_steps: int = 200
    grad_clip: float = 1.0

    # Loss
    temperature: float = 0.07
    sigreg_weight: float = 0.1

    # ── Data ─────────────────────────────────────────────────────────────
    data_dir: str = "./data"
    videos_dir: str = "./data/Charades_v1_480"
    anno_train: str = "./data/charades_sta_train.txt"
    anno_test: str = "./data/charades_sta_test.txt"
    hf_dataset_id: str = "max044/Charades_v1_480"

    # ── Checkpoints ──────────────────────────────────────────────────────
    checkpoint_dir: str = "./checkpoints"
    save_every: int = 2     # save checkpoint every N epochs
    val_every: int = 2      # run validation every N epochs
    val_samples: int = 500  # limit validation samples for speed

    # ── Inference ────────────────────────────────────────────────────────
    window_sizes: list[float] = field(default_factory=lambda: [2.0, 4.0, 8.0, 16.0])
    window_stride: float = 1.0
    nms_threshold: float = 0.5
    top_k: int = 5

    # ── Debug ────────────────────────────────────────────────────────────
    debug: bool = False
    debug_samples: int = 100
    num_workers: int = 0  # 0 for MPS compatibility

    def __post_init__(self):
        if not self.device:
            if torch.cuda.is_available():
                self.device = "cuda"
            elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"

        # Ensure directories exist
        Path(self.checkpoint_dir).mkdir(parents=True, exist_ok=True)
        Path(self.data_dir).mkdir(parents=True, exist_ok=True)
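The device auto-detection in `__post_init__` can be exercised on its own; a minimal standalone extract of that logic:

```python
import torch

# Mirrors Config.__post_init__: prefer CUDA, then Apple MPS, else CPU.
def autodetect_device() -> str:
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    return "cpu"

print(autodetect_device())  # depends on the machine
```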
vljepa/dataset.py ADDED

"""Charades-STA dataset for VL-JEPA training."""

import os

from torch.utils.data import Dataset

from vljepa.config import Config
from vljepa.utils import load_video_frames

try:
    from huggingface_hub import hf_hub_download
    HAS_HF_HUB = True
except ImportError:
    HAS_HF_HUB = False


class CharadesSTADataset(Dataset):
    """Dataset for Charades-STA temporal grounding.

    Annotation format: video_id start end##sentence
    Example: 3MSZA 24.3 30.4##person turn a light on

    For training, the query is a neutral prompt ("What is happening in this video?")
    and the target is the ground-truth caption.
    """

    NEUTRAL_QUERIES = [
        "What is happening in this video?",
        "Describe this video clip.",
        "What action is being performed?",
    ]

    def __init__(
        self,
        anno_file: str,
        videos_dir: str,
        config: Config,
        split: str = "train",
    ):
        self.videos_dir = videos_dir
        self.config = config
        self.split = split
        self.samples = []

        self._load_annotations(anno_file)

        if config.debug:
            self.samples = self.samples[: config.debug_samples]

        print(f"[{split}] Loaded {len(self.samples)} samples")

    def _load_annotations(self, anno_file: str):
        """Parse a Charades-STA annotation file."""
        if not os.path.exists(anno_file):
            # Try loading from HuggingFace datasets
            self._load_from_hf()
            return

        with open(anno_file, "r") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue

                # Format: video_id start end##sentence
                parts = line.split("##")
                if len(parts) < 2:
                    continue

                meta = parts[0].strip().split()
                sentence = parts[1].strip()

                if len(meta) < 3:
                    continue

                video_id = meta[0]
                start = float(meta[1])
                end = float(meta[2])

                video_path = os.path.join(self.videos_dir, f"{video_id}.mp4")

                # If streaming/lazy loading is enabled, add the sample even
                # when the video is not available locally
                if os.path.exists(video_path) or self.config.hf_dataset_id:
                    self.samples.append({
                        "video_path": video_path,
                        "video_id": video_id,
                        "start": start,
                        "end": end,
                        "caption": sentence,
                    })

    def _load_from_hf(self):
        """Fallback: load annotations from HuggingFace datasets."""
        try:
            from datasets import load_dataset

            print("Loading annotations from HuggingFace (lmms-lab/charades_sta)...")
            ds = load_dataset("lmms-lab/charades_sta", split="test")

            for item in ds:
                video_id = item.get("video_id") or item.get("video", "")
                start = float(item.get("start", 0))
                end = float(item.get("end", 10))
                caption = item.get("query", "") or item.get("description", "")

                video_path = os.path.join(self.videos_dir, f"{video_id}.mp4")
                if os.path.exists(video_path) and caption:
                    self.samples.append({
                        "video_path": video_path,
                        "video_id": video_id,
                        "start": start,
                        "end": end,
                        "caption": caption,
                    })

        except Exception as e:
            print(f"Failed to load from HuggingFace: {e}")
            print("Please download annotations manually. See download_annotations.py")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx: int) -> dict | None:
        sample = self.samples[idx]
        video_path = sample["video_path"]

        # ── Lazy loading from HF ─────────────────────────────────────────
        if not os.path.exists(video_path) and self.config.hf_dataset_id:
            if HAS_HF_HUB:
                try:
                    # Download only the specific file needed
                    video_path = hf_hub_download(
                        repo_id=self.config.hf_dataset_id,
                        filename=f"{sample['video_id']}.mp4",
                        repo_type="dataset",
                        local_dir=self.videos_dir,  # cache it in the videos dir
                    )
                except Exception as e:
                    print(f"Error downloading {sample['video_id']}: {e}")
                    return None
            else:
                print("Error: huggingface_hub not installed, cannot lazy load.")
                return None

        # Load frames from the annotated temporal segment
        frames = load_video_frames(
            video_path,
            start_sec=sample["start"],
            end_sec=sample["end"],
            num_frames=self.config.num_frames,
        )

        if frames is None or len(frames) == 0:
            return None

        # Use a neutral query for training
        # (VL-JEPA learns to predict the target caption embedding from video + query)
        query_idx = idx % len(self.NEUTRAL_QUERIES)
        query = self.NEUTRAL_QUERIES[query_idx]

        return {
            "frames": frames,               # list of numpy arrays (H, W, 3)
            "query": query,                 # neutral text query
            "caption": sample["caption"],   # target caption
            "video_id": sample["video_id"],
            "start": sample["start"],
            "end": sample["end"],
        }


def collate_fn(batch: list[dict | None]) -> dict | None:
    """Custom collate that filters out None samples."""
    batch = [b for b in batch if b is not None]
    if len(batch) == 0:
        return None

    return {
        "frames": [b["frames"] for b in batch],
        "queries": [b["query"] for b in batch],
        "captions": [b["caption"] for b in batch],
        "video_ids": [b["video_id"] for b in batch],
        "starts": [b["start"] for b in batch],
        "ends": [b["end"] for b in batch],
    }
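The None-filtering collate is plain Python, so its behavior can be checked in isolation (a trimmed copy with toy samples, not the full field set):

```python
# Trimmed copy of the collate: failed samples (None from __getitem__)
# are dropped instead of crashing the batch; an all-None batch yields None.
def collate_fn(batch):
    batch = [b for b in batch if b is not None]
    if len(batch) == 0:
        return None
    return {
        "queries": [b["query"] for b in batch],
        "captions": [b["caption"] for b in batch],
    }

good = {"query": "What is happening in this video?", "caption": "person opens a door"}
print(collate_fn([good, None, good]))  # batch of 2; the None is dropped
print(collate_fn([None, None]))        # None — the training loop should skip it
```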
vljepa/losses.py ADDED

"""Loss functions for VL-JEPA: bidirectional InfoNCE + SIGReg regularization."""

import torch
import torch.nn.functional as F


def infonce_bidirectional(
    pred: torch.Tensor,
    target: torch.Tensor,
    temperature: float = 0.07,
) -> torch.Tensor:
    """Symmetric InfoNCE loss between predicted and target embeddings.

    Args:
        pred: predicted embeddings (B, D), L2-normalized inside.
        target: target embeddings (B, D), L2-normalized inside.
        temperature: scaling factor for logits.

    Returns:
        Scalar loss (average of forward + backward directions).
    """
    pred = F.normalize(pred, dim=-1)
    target = F.normalize(target, dim=-1)

    # Cosine similarity matrix (B, B)
    logits = pred @ target.T / temperature

    labels = torch.arange(pred.size(0), device=pred.device)
    loss_fwd = F.cross_entropy(logits, labels)
    loss_bwd = F.cross_entropy(logits.T, labels)

    return (loss_fwd + loss_bwd) / 2


def sigreg_loss(
    embeddings: torch.Tensor,
    lambda_reg: float = 0.1,
) -> torch.Tensor:
    """Regularize embeddings towards a unit-variance isotropic distribution.

    Simplified SIGReg: penalizes deviation of the covariance from identity.
    """
    if embeddings.size(0) < 2:
        return torch.tensor(0.0, device=embeddings.device)

    # Center
    embeddings = embeddings - embeddings.mean(dim=0, keepdim=True)

    # Covariance (D, D)
    B, D = embeddings.shape
    cov = (embeddings.T @ embeddings) / (B - 1)

    # Variance: encourage diagonal to be 1
    var_loss = F.relu(1.0 - cov.diagonal()).mean()

    # Covariance: decorrelate off-diagonal
    off_diag = cov - torch.diag(cov.diagonal())
    cov_loss = (off_diag ** 2).mean()

    return lambda_reg * (var_loss + cov_loss)


def vl_jepa_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    temperature: float = 0.07,
    sigreg_weight: float = 0.1,
) -> tuple[torch.Tensor, dict[str, float]]:
    """Combined VL-JEPA training loss.

    Returns:
        total_loss: scalar tensor for backprop.
        metrics: dict with breakdown of loss components.
    """
    align = infonce_bidirectional(pred, target, temperature)
    reg_pred = sigreg_loss(pred, sigreg_weight)
    reg_target = sigreg_loss(target, sigreg_weight)

    total = align + reg_pred + reg_target

    metrics = {
        "loss/total": total.item(),
        "loss/infonce": align.item(),
        "loss/sigreg_pred": reg_pred.item(),
        "loss/sigreg_target": reg_target.item(),
    }

    return total, metrics
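A quick sanity check of the symmetric InfoNCE (re-implemented inline so the snippet is self-contained): matched pred/target pairs should incur a much lower loss than deliberately mismatched ones.

```python
import torch
import torch.nn.functional as F

# Inline copy of the bidirectional InfoNCE for a standalone check.
def infonce_bidirectional(pred, target, temperature=0.07):
    pred = F.normalize(pred, dim=-1)
    target = F.normalize(target, dim=-1)
    logits = pred @ target.T / temperature
    labels = torch.arange(pred.size(0), device=pred.device)
    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2

torch.manual_seed(0)
emb = torch.randn(8, 384)
matched = infonce_bidirectional(emb, emb)                  # pred == target
mismatched = infonce_bidirectional(emb, emb[[1, 2, 3, 4, 5, 6, 7, 0]])  # derangement

print(float(matched) < float(mismatched))  # True: aligned pairs score lower loss
```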
vljepa/models.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """VL-JEPA model components: V-JEPA 2 (X-Encoder), Qwen 2.5 (Predictor), MiniLM (Y-Encoder)."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from transformers import AutoModel, AutoTokenizer
7
+ from peft import get_peft_model, LoraConfig, TaskType
8
+ from sentence_transformers import SentenceTransformer
9
+ import numpy as np
10
+
11
+ from vljepa.config import Config
12
+
13
+
14
+ class XEncoder(nn.Module):
15
+ """Frozen V-JEPA 2 Video Encoder.
16
+
17
+ Extracts hierarchical video features.
18
+ """
19
+
20
+ def __init__(self, config: Config):
21
+ super().__init__()
22
+ # Load V-JEPA 2 model
23
+ try:
24
+ self.model = AutoModel.from_pretrained(config.clip_model, trust_remote_code=True)
25
+ except Exception:
26
+ print(f"Warning: Failed to load {config.clip_model}. Trying fallback 'facebook/vjepa-vit-h-14-224'.")
27
+ self.model = AutoModel.from_pretrained("facebook/vjepa-vit-h-14-224", trust_remote_code=True)
28
+ config.x_dim = self.model.config.hidden_size
29
+
30
+ # Freeze
31
+ for p in self.model.parameters():
32
+ p.requires_grad = False
33
+ self.model.eval()
34
+
35
+ # Move to device if needed
36
+ self.model.to(config.device)
37
+
38
+ self.hidden_size = config.x_dim
39
+
40
+ @torch.no_grad()
41
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
42
+ """Encode video frames.
43
+
44
+ Args:
45
+ pixel_values: (B, C, T, H, W) preprocessed frames (0-1 float, normalized)
46
+ """
47
+ if pixel_values.shape[1] == 3 and pixel_values.shape[2] > 3:
48
+ # (B, C, T, H, W) -> (B, T, C, H, W)
49
+ pixel_values = pixel_values.permute(0, 2, 1, 3, 4)
50
+
51
+ try:
52
+ outputs = self.model(pixel_values_videos=pixel_values)
53
+ except TypeError:
54
+ # Fallback
55
+ outputs = self.model(pixel_values=pixel_values)
56
+
57
+ last_hidden = outputs.last_hidden_state # (B, seq_len, hidden)
58
+ sv = last_hidden.mean(dim=1) # (B, hidden)
59
+ return sv
60
+
61
+ def preprocess_frames(self, frames_batch: list[list], device: str = "cpu") -> torch.Tensor:
62
+ """Preprocess frames."""
63
+ mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1, 1)
64
+ std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1, 1)
65
+
66
+ padded = []
67
+ for frames in frames_batch:
68
+ if len(frames) == 0:
69
+ t = torch.zeros((16, 3, 224, 224), device=device)
70
+ padded.append(t)
71
+ continue
72
+
73
+ # Stack to (T, H, W, 3)
74
+ t = torch.tensor(np.stack(frames), dtype=torch.float32, device=device)
75
+
76
+ # Permute to (T, 3, H, W)
77
+ t = t.permute(0, 3, 1, 2) / 255.0
78
+
79
+ # Resize
80
+ t = F.interpolate(t, size=(224, 224), mode='bilinear', align_corners=False)
81
+
82
+ padded.append(t)
83
+
84
+ max_t = max((t.size(0) for t in padded), default=16)
85
+ final_padded = []
86
+ for t in padded:
87
+ if t.size(0) < max_t:
88
+ pad = t[-1:].expand(max_t - t.size(0), -1, -1, -1)
89
+ t = torch.cat([t, pad], dim=0)
90
+ final_padded.append(t)
91
+
92
+ # Stack -> (B, T, 3, H, W)
93
+ pixel_values = torch.stack(final_padded, dim=0)
94
+
95
+ # Input to V-JEPA 2 (via HF) usually expects (B, T, C, H, W)
96
+
97
+ # Normalize (broadcasting T)
98
+ # mean/std are (1, 3, 1, 1, 1). We need to align with (B, T, 3, H, W)
99
+ # Permute to (B, 3, T, H, W) for normalization
100
+ pixel_values = pixel_values.permute(0, 2, 1, 3, 4)
101
+ pixel_values = (pixel_values - mean) / std
102
+
103
+ # Permute back to (B, T, 3, H, W)
104
+ pixel_values = pixel_values.permute(0, 2, 1, 3, 4)
105
+
106
+ return pixel_values
107
+
108
+
109
+ class QueryEncoder(nn.Module):
110
+ """Tokenizer for Qwen."""
111
+
112
+ def __init__(self, config: Config):
113
+ super().__init__()
114
+ self.tokenizer = AutoTokenizer.from_pretrained(config.predictor_model, trust_remote_code=True)
115
+ if self.tokenizer.pad_token is None:
116
+ self.tokenizer.pad_token = self.tokenizer.eos_token
117
+
118
+ def tokenize(self, texts: list[str], device: str = "cpu") -> dict:
119
+ return self.tokenizer(
120
+ texts, return_tensors="pt", padding=True, truncation=True, max_length=64
121
+ ).to(device)
122
+
123
+
124
+ class Predictor(nn.Module):
+     """Qwen 2.5 0.5B Predictor with LoRA."""
+
+     def __init__(self, config: Config):
+         super().__init__()
+         self.model = AutoModel.from_pretrained(
+             config.predictor_model,
+             torch_dtype=torch.float16 if config.device == "cuda" else torch.float32,
+             trust_remote_code=True
+         )
+         if config.use_lora:
+             peft_config = LoraConfig(
+                 task_type=TaskType.FEATURE_EXTRACTION,
+                 inference_mode=False,
+                 r=config.lora_r,
+                 lora_alpha=config.lora_alpha,
+                 lora_dropout=config.lora_dropout,
+                 target_modules=config.lora_target_modules
+             )
+             self.model = get_peft_model(self.model, peft_config)
+             self.model.print_trainable_parameters()
+
+         self.visual_proj = nn.Linear(config.x_dim, config.predictor_dim)
+         self.output_proj = nn.Linear(config.predictor_dim, config.embed_dim)
+
+         # Move to device
+         self.to(config.device)
+
+     def forward(self, sv: torch.Tensor, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+         B = sv.size(0)
+         sv_embeds = self.visual_proj(sv).unsqueeze(1)  # (B, 1, predictor_dim)
+
+         if hasattr(self.model, "base_model"):
+             base = self.model.base_model.model
+         else:
+             base = self.model
+
+         # Qwen2 exposes its embedding table at model.embed_tokens;
+         # try direct access first, then the generic accessor.
+         if hasattr(base, "model"):
+             embed_layer = base.model.embed_tokens
+         elif hasattr(base, "embed_tokens"):
+             embed_layer = base.embed_tokens
+         else:
+             # General fallback for AutoModel
+             embed_layer = base.get_input_embeddings()
+
+         inputs_embeds = embed_layer(input_ids)
+         combined_embeds = torch.cat([sv_embeds, inputs_embeds], dim=1)
+
+         ones = torch.ones((B, 1), device=sv.device, dtype=attention_mask.dtype)
+         combined_mask = torch.cat([ones, attention_mask], dim=1)
+
+         outputs = self.model(inputs_embeds=combined_embeds, attention_mask=combined_mask)
+         # With right padding, position -1 can be a pad token for shorter
+         # sequences; index the last *valid* hidden state per sequence instead.
+         lengths = combined_mask.sum(dim=1).long() - 1
+         last_hidden = outputs.last_hidden_state[torch.arange(B, device=sv.device), lengths]
+
+         return self.output_proj(last_hidden)
+
+
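The prefix-token trick in `Predictor.forward` — projecting the pooled video embedding and prepending it as a single pseudo-token ahead of the text embeddings, with a matching `1` added to the attention mask — can be sketched shape-wise with plain numpy (the hidden size 896 is an assumption for Qwen 2.5 0.5B; real inputs come from `visual_proj` and the embedding table):

```python
import numpy as np

B, T, D = 2, 7, 896  # batch, query tokens, predictor hidden size (assumed for Qwen 2.5 0.5B)

sv_embeds = np.zeros((B, 1, D))      # projected video embedding: one prefix token per sample
inputs_embeds = np.zeros((B, T, D))  # embedded text query tokens

# Mirrors torch.cat([sv_embeds, inputs_embeds], dim=1) in Predictor.forward
combined = np.concatenate([sv_embeds, inputs_embeds], axis=1)

# The attention mask grows by one position for the visual prefix token
combined_mask = np.concatenate([np.ones((B, 1)), np.ones((B, T))], axis=1)

print(combined.shape, combined_mask.shape)  # (2, 8, 896) (2, 8)
```

The LLM then attends over text and video jointly, and the last valid hidden state is projected into the shared embedding space.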
+ class YEncoder(nn.Module):
+     """Frozen MiniLM Y-Encoder."""
+
+     def __init__(self, config: Config):
+         super().__init__()
+         self.model = SentenceTransformer(config.text_model)
+         self.projection = nn.Linear(config.text_dim, config.embed_dim)
+
+         for p in self.model.parameters():
+             p.requires_grad = False
+         self.model.eval()
+
+     def forward(self, texts: list[str], device: str = "cpu") -> torch.Tensor:
+         with torch.no_grad():
+             embeddings = self.model.encode(texts, convert_to_tensor=True, device=device)
+         # Clone to avoid "Inference tensors cannot be saved for backward" error;
+         # the projection runs outside no_grad so it stays trainable.
+         return self.projection(embeddings.clone())
+
+
+ class VLJepa(nn.Module):
+     """V-JEPA 2 + Qwen 2.5 + MiniLM."""
+
+     def __init__(self, config: Config):
+         super().__init__()
+         self.config = config
+         self.x_encoder = XEncoder(config)
+         self.query_encoder = QueryEncoder(config)
+         self.predictor = Predictor(config)
+         self.y_encoder = YEncoder(config)
+
+     def forward(self, pixel_values, query_ids, query_mask, target_texts):
+         sv = self.x_encoder(pixel_values)
+         sy_hat = self.predictor(sv, query_ids, query_mask)
+         sy = self.y_encoder(target_texts, device=str(pixel_values.device))
+         return sy_hat, sy
+
+     def encode_video_query(self, pixel_values, query_ids, query_mask):
+         sv = self.x_encoder(pixel_values)
+         sy_hat = self.predictor(sv, query_ids, query_mask)
+         return sy_hat
+
+     def encode_text(self, texts, device="cpu"):
+         return self.y_encoder(texts, device=device)
+
+     def trainable_parameters(self):
+         return list(self.predictor.parameters()) + list(self.y_encoder.projection.parameters())
+
+     def count_parameters(self):
+         def _count(m):
+             return {
+                 "total": sum(p.numel() for p in m.parameters()),
+                 "trainable": sum(p.numel() for p in m.parameters() if p.requires_grad)
+             }
+         return {
+             "x_encoder": _count(self.x_encoder),
+             "predictor": _count(self.predictor),
+             "y_encoder": _count(self.y_encoder)
+         }
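At inference time, retrieval reduces to comparing `encode_video_query` outputs against `encode_text` candidates by cosine similarity. A minimal sketch of that scoring step with stand-in numpy vectors (the 384-dim size matches all-MiniLM-L6-v2; a real run would substitute actual model outputs):

```python
import numpy as np

rng = np.random.default_rng(0)
embed_dim = 384  # all-MiniLM-L6-v2 output size / shared embed_dim

# Stand-ins for sy_hat (predicted embedding) and candidate text embeddings
sy_hat = rng.standard_normal(embed_dim)
candidates = rng.standard_normal((3, embed_dim))
candidates[1] = sy_hat  # plant a perfect match at index 1

def cosine(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

scores = [cosine(sy_hat, c) for c in candidates]
best = int(np.argmax(scores))
print(best)  # 1
```

The highest-scoring candidate is the retrieved description; ranking over proposal windows instead of texts gives temporal localization.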
vljepa/utils.py ADDED
@@ -0,0 +1,158 @@
+ """Utility functions: video I/O, temporal IoU, NMS, sliding windows."""
+
+ import cv2
+ import numpy as np
+ import torch
+
+
+ def load_video_frames(
+     video_path: str,
+     start_sec: float = 0.0,
+     end_sec: float | None = None,
+     num_frames: int = 16,
+ ) -> list[np.ndarray] | None:
+     """Load uniformly sampled RGB frames from a video segment.
+
+     Args:
+         video_path: path to .mp4 file
+         start_sec: start of segment in seconds
+         end_sec: end of segment in seconds (None = end of video)
+         num_frames: number of frames to sample
+
+     Returns:
+         List of RGB numpy arrays (H, W, 3), or None on failure.
+     """
+     cap = cv2.VideoCapture(video_path)
+     if not cap.isOpened():
+         return None
+
+     fps = cap.get(cv2.CAP_PROP_FPS)
+     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+     if fps <= 0 or total_frames <= 0:
+         cap.release()
+         return None
+
+     duration = total_frames / fps
+     if end_sec is None:
+         end_sec = duration
+
+     start_frame = max(0, int(start_sec * fps))
+     end_frame = min(total_frames - 1, int(end_sec * fps))
+
+     if end_frame <= start_frame:
+         cap.release()
+         return None
+
+     n_available = end_frame - start_frame + 1
+     n_sample = min(num_frames, n_available)
+     indices = np.linspace(start_frame, end_frame, n_sample, dtype=int)
+
+     frames = []
+     for idx in indices:
+         cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
+         ret, frame = cap.read()
+         if ret:
+             frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+
+     cap.release()
+
+     if len(frames) == 0:
+         return None
+
+     return frames
+
+
+ def get_video_duration(video_path: str) -> float:
+     """Get video duration in seconds."""
+     cap = cv2.VideoCapture(video_path)
+     if not cap.isOpened():
+         return 0.0
+     fps = cap.get(cv2.CAP_PROP_FPS)
+     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+     cap.release()
+     if fps <= 0:
+         return 0.0
+     return total_frames / fps
+
+
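The uniform-sampling step in `load_video_frames` boils down to `np.linspace` over the clamped frame range; for instance, sampling 16 frames from a segment spanning frames 30–209:

```python
import numpy as np

# 16 evenly spaced frame indices between frames 30 and 209 (inclusive),
# as computed inside load_video_frames
indices = np.linspace(30, 209, 16, dtype=int)
print(len(indices), indices[0], indices[-1])  # 16 30 209
```

Both endpoints are always included, so the segment boundaries are represented regardless of its length.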
+ def temporal_iou(
+     pred_start: float,
+     pred_end: float,
+     gt_start: float,
+     gt_end: float,
+ ) -> float:
+     """Compute temporal Intersection over Union between two segments."""
+     inter_start = max(pred_start, gt_start)
+     inter_end = min(pred_end, gt_end)
+     inter = max(0.0, inter_end - inter_start)
+     union = (pred_end - pred_start) + (gt_end - gt_start) - inter
+     if union <= 0:
+         return 0.0
+     return inter / union
+
+
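A quick worked example of the temporal IoU (the function is inlined here so the snippet runs standalone): a prediction of 2–6 s against ground truth 4–8 s overlaps for 2 s over a 6 s union, giving IoU = 1/3:

```python
def temporal_iou(pred_start, pred_end, gt_start, gt_end):
    # Same logic as temporal_iou in vljepa/utils.py
    inter = max(0.0, min(pred_end, gt_end) - max(pred_start, gt_start))
    union = (pred_end - pred_start) + (gt_end - gt_start) - inter
    return inter / union if union > 0 else 0.0

print(temporal_iou(2.0, 6.0, 4.0, 8.0))  # 0.3333... (2 s intersection / 6 s union)
print(temporal_iou(0.0, 1.0, 2.0, 3.0))  # 0.0 (disjoint segments)
```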
+ def nms(
+     proposals: list[tuple[float, float]],
+     scores: list[float],
+     iou_threshold: float = 0.5,
+ ) -> list[int]:
+     """Non-maximum suppression for temporal proposals.
+
+     Args:
+         proposals: list of (start, end) tuples
+         scores: corresponding scores
+         iou_threshold: suppress proposals with IoU above this
+
+     Returns:
+         List of kept indices (sorted by score descending).
+     """
+     if len(proposals) == 0:
+         return []
+
+     sorted_idx = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
+     kept = []
+
+     for i in sorted_idx:
+         should_keep = True
+         for j in kept:
+             iou = temporal_iou(
+                 proposals[i][0], proposals[i][1],
+                 proposals[j][0], proposals[j][1],
+             )
+             if iou > iou_threshold:
+                 should_keep = False
+                 break
+         if should_keep:
+             kept.append(i)
+
+     return kept
+
+
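To see the suppression in action, here is a standalone rerun of the same greedy logic on three proposals — the second overlaps the top-scoring one (IoU 0.6 > 0.5) and is dropped, while the disjoint third survives:

```python
def temporal_iou(ps, pe, gs, ge):
    # Inlined copy of temporal_iou so the snippet runs standalone
    inter = max(0.0, min(pe, ge) - max(ps, gs))
    union = (pe - ps) + (ge - gs) - inter
    return inter / union if union > 0 else 0.0

proposals = [(0.0, 4.0), (1.0, 5.0), (10.0, 14.0)]
scores = [0.9, 0.8, 0.7]

# Greedy NMS, mirroring nms() above: visit by descending score,
# keep a proposal only if it overlaps no already-kept proposal
kept = []
for i in sorted(range(len(scores)), key=lambda i: scores[i], reverse=True):
    if all(temporal_iou(*proposals[i], *proposals[j]) <= 0.5 for j in kept):
        kept.append(i)

print(kept)  # [0, 2]: (1.0, 5.0) has IoU 0.6 with (0.0, 4.0) and is suppressed
```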
+ def sliding_window_proposals(
+     duration: float,
+     window_sizes: list[float],
+     stride: float = 1.0,
+ ) -> list[tuple[float, float]]:
+     """Generate candidate temporal proposals using sliding windows.
+
+     Args:
+         duration: total video duration in seconds
+         window_sizes: list of window durations to use
+         stride: step size in seconds
+
+     Returns:
+         List of (start, end) proposals.
+     """
+     proposals = []
+     for ws in window_sizes:
+         if ws > duration:
+             # Single proposal covering the whole video (added once, even if
+             # several window sizes exceed the duration)
+             if (0.0, duration) not in proposals:
+                 proposals.append((0.0, duration))
+             continue
+         start = 0.0
+         while start + ws <= duration + 0.01:  # small epsilon for float
+             end = min(start + ws, duration)
+             proposals.append((start, end))
+             start += stride
+     return proposals
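For example, with a 10 s video, a single 4 s window size, and a 2 s stride, the generator (re-implemented inline so this snippet runs standalone) yields four overlapping proposals:

```python
def sliding_window_proposals(duration, window_sizes, stride=1.0):
    # Same sliding-window logic as vljepa/utils.py
    proposals = []
    for ws in window_sizes:
        if ws > duration:
            proposals.append((0.0, duration))
            continue
        start = 0.0
        while start + ws <= duration + 0.01:  # small epsilon for float
            proposals.append((start, min(start + ws, duration)))
            start += stride
    return proposals

props = sliding_window_proposals(10.0, [4.0], stride=2.0)
print(props)  # [(0.0, 4.0), (2.0, 6.0), (4.0, 8.0), (6.0, 10.0)]
```

Each proposal is then scored against the query embedding and deduplicated with `nms` to produce the final localized moments.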