Upload folder using huggingface_hub
Browse files- README.md +94 -0
- best.pth +3 -0
- original_readme.md +172 -0
- sample_inference.py +64 -0
- vljepa/__init__.py +7 -0
- vljepa/__pycache__/__init__.cpython-313.pyc +0 -0
- vljepa/__pycache__/config.cpython-313.pyc +0 -0
- vljepa/__pycache__/dataset.cpython-313.pyc +0 -0
- vljepa/__pycache__/losses.cpython-313.pyc +0 -0
- vljepa/__pycache__/models.cpython-313.pyc +0 -0
- vljepa/__pycache__/utils.cpython-313.pyc +0 -0
- vljepa/config.py +87 -0
- vljepa/dataset.py +185 -0
- vljepa/losses.py +88 -0
- vljepa/models.py +240 -0
- vljepa/utils.py +158 -0
README.md
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
language:
|
| 3 |
+
- en
|
| 4 |
+
- fr
|
| 5 |
+
license: apache-2.0
|
| 6 |
+
library_name: transformers
|
| 7 |
+
tags:
|
| 8 |
+
- video-search
|
| 9 |
+
- v-jepa
|
| 10 |
+
- multi-modal
|
| 11 |
+
- temporal-grounding
|
| 12 |
+
- action-retrieval
|
| 13 |
+
datasets:
|
| 14 |
+
- max044/Charades_v1_480
|
| 15 |
+
metrics:
|
| 16 |
+
- loss
|
| 17 |
+
---
|
| 18 |
+
|
| 19 |
+
# VL-JEPA Custom (V-JEPA 2 + Qwen 2.5 + MiniLM)
|
| 20 |
+
|
| 21 |
+
## English Description
|
| 22 |
+
|
| 23 |
+
This model is a custom implementation of the **VL-JEPA** (Video-Language Joint
|
| 24 |
+
Embedding Predictive Architecture) inspired by Meta AI's research. It is
|
| 25 |
+
designed for **Temporal Moment Retrieval** (finding specific actions in videos).
|
| 26 |
+
|
| 27 |
+
### Architecture
|
| 28 |
+
|
| 29 |
+
- **X-Encoder (Video)**: Frozen
|
| 30 |
+
[V-JEPA 2 (ViT-L)](https://huggingface.co/facebook/vjepa2-vitl-fpc64-256).
|
| 31 |
+
- **Predictor (Refinement)**:
|
| 32 |
+
[Qwen 2.5 0.5B](https://huggingface.co/Qwen/Qwen2.5-0.5B) fine-tuned using
|
| 33 |
+
**LoRA** (Low-Rank Adaptation).
|
| 34 |
+
- **Y-Encoder (Text Target)**: Frozen
|
| 35 |
+
[all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2).
|
| 36 |
+
|
| 37 |
+
### Training Details
|
| 38 |
+
|
| 39 |
+
- **Dataset**:
|
| 40 |
+
[Charades-STA](https://huggingface.co/datasets/max044/Charades_v1_480)
|
| 41 |
+
(Academic dataset for video action localization).
|
| 42 |
+
- **Optimization**: LoRA with $r=64$ and $\alpha=128$, targeting `q_proj` and
|
| 43 |
+
`v_proj` in Qwen.
|
| 44 |
+
- **Learning Rate**: 3e-4 with Cosine Warmup.
|
| 45 |
+
- **Outcome**: Only 0.2% of parameters are trainable, making it extremely
|
| 46 |
+
lightweight to train and run.
|
| 47 |
+
|
| 48 |
+
---
|
| 49 |
+
|
| 50 |
+
## Description en Français
|
| 51 |
+
|
| 52 |
+
Ce modèle est une implémentation personnalisée de **VL-JEPA**, inspirée des
|
| 53 |
+
travaux de Meta AI. Il est optimisé pour la recherche d'actions temporelles dans
|
| 54 |
+
les vidéos (**Temporal Moment Retrieval**).
|
| 55 |
+
|
| 56 |
+
### Architecture
|
| 57 |
+
|
| 58 |
+
- **Encodeur Vidéo (X)** :
|
| 59 |
+
[V-JEPA 2 (ViT-L)](https://huggingface.co/facebook/vjepa2-vitl-fpc64-256)
|
| 60 |
+
gelé.
|
| 61 |
+
- **Prédicteur** : [Qwen 2.5 0.5B](https://huggingface.co/Qwen/Qwen2.5-0.5B)
|
| 62 |
+
adapté avec **LoRA**.
|
| 63 |
+
- **Encodeur Texte (Y)** :
|
| 64 |
+
[all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
|
| 65 |
+
gelé.
|
| 66 |
+
|
| 67 |
+
### Détails d'Entraînement
|
| 68 |
+
|
| 69 |
+
- **Dataset** :
|
| 70 |
+
[Charades-STA](https://huggingface.co/datasets/max044/Charades_v1_480).
|
| 71 |
+
- **Méthode** : Entraînement via LoRA ($r=64$, $\alpha=128$).
|
| 72 |
+
- **Coût** : Approche très économique, entraînée pour environ 5$ sur Vast.ai.
|
| 73 |
+
|
| 74 |
+
## Usage / Utilisation
|
| 75 |
+
|
| 76 |
+
```python
|
| 77 |
+
import torch
|
| 78 |
+
from vljepa.config import Config
|
| 79 |
+
from vljepa.models import VLJepa
|
| 80 |
+
|
| 81 |
+
# Load model
|
| 82 |
+
config = Config()
|
| 83 |
+
model = VLJepa(config)
|
| 84 |
+
checkpoint = torch.load("best.pth", map_location="cpu")
|
| 85 |
+
model.predictor.load_state_dict(checkpoint["predictor_state_dict"])
|
| 86 |
+
model.y_encoder.projection.load_state_dict(checkpoint["y_projection_state_dict"])
|
| 87 |
+
model.eval()
|
| 88 |
+
|
| 89 |
+
# Localizing an action
|
| 90 |
+
# (Requires preprocessing frames and tokenizing query)
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
Refer to the source code for full inference pipeline with sliding window and
|
| 94 |
+
NMS.
|
best.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6393f56b7528ad91a3281ebcd0bb368b44dc041a5b50bc7569d466e91e992750
|
| 3 |
+
size 2045205003
|
original_readme.md
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# VL-JEPA: Simplified Video-Language Alignment
|
| 2 |
+
|
| 3 |
+
A simplified implementation of the Video-Language Joint Embedding Predictive
|
| 4 |
+
Architecture (VL-JEPA) for **Temporal Moment Retrieval** (Temporal Grounding).
|
| 5 |
+
|
| 6 |
+
This project uses **V-JEPA 2** for video understanding and **Qwen 2.5 0.5B** as
|
| 7 |
+
a predictor to align video features with language queries in a high-dimensional
|
| 8 |
+
embedding space.
|
| 9 |
+
|
| 10 |
+
## 🚀 Architecture
|
| 11 |
+
|
| 12 |
+
The model follows the JEPA framework by aligning video features (X) and text
|
| 13 |
+
descriptions (Y) through a predictor (P):
|
| 14 |
+
|
| 15 |
+
- **X-Encoder (Video)**: Frozen **V-JEPA 2** (ViT-L). High-fidelity hierarchical
|
| 16 |
+
video features.
|
| 17 |
+
- **Y-Encoder (Text)**: Frozen **MiniLM** (all-MiniLM-L6-v2). Compact and
|
| 18 |
+
efficient semantic text embeddings.
|
| 19 |
+
- **Predictor (Alignment)**: **Qwen 2.5 0.5B** with **LoRA** (Low-Rank
|
| 20 |
+
Adaptation). Learns to predict the target text embedding from the joint
|
| 21 |
+
video+query representation.
|
| 22 |
+
|
| 23 |
+
## 🛠️ Installation
|
| 24 |
+
|
| 25 |
+
This project uses `uv` for lightning-fast dependency management.
|
| 26 |
+
|
| 27 |
+
```bash
|
| 28 |
+
# Clone the repository
|
| 29 |
+
git clone https://github.com/max044/vl-jepa.git
|
| 30 |
+
cd vl-jepa
|
| 31 |
+
|
| 32 |
+
# Create environment and install dependencies
|
| 33 |
+
uv sync
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
## 📊 Data Preparation
|
| 37 |
+
|
| 38 |
+
The model is trained on the **Charades-STA** dataset for temporal grounding.
|
| 39 |
+
|
| 40 |
+
1. **Videos**: Download
|
| 41 |
+
[Charades v1](https://ai2-public-datasets.s3-us-west-2.amazonaws.com/charades/Charades_v1_480.zip)
|
| 42 |
+
and place them in `data/Charades_v1_480`.
|
| 43 |
+
2. **Annotations**: Use `download_annotations.py` to download the annotations.
|
| 44 |
+
|
| 45 |
+
Structure:
|
| 46 |
+
|
| 47 |
+
```text
|
| 48 |
+
data/
|
| 49 |
+
├── Charades_v1_480/ # Video files (.mp4)
|
| 50 |
+
├── charades_sta_train.txt
|
| 51 |
+
└── charades_sta_test.txt
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
## 🏋️ Training
|
| 55 |
+
|
| 56 |
+
Start training with default hyperparameters:
|
| 57 |
+
|
| 58 |
+
```bash
|
| 59 |
+
# Regular training (local, MPS/CPU)
|
| 60 |
+
uv run train.py
|
| 61 |
+
|
| 62 |
+
# Debug mode (small subset, only 2 epochs)
|
| 63 |
+
uv run train.py --debug --device mps
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
### Key Training Features:
|
| 67 |
+
|
| 68 |
+
- **Bidirectional InfoNCE Loss**: Maximizes mutual information between predicted
|
| 69 |
+
and target embeddings.
|
| 70 |
+
- **LoRA Tuning**: Only 0.2% of the predictor parameters (Qwen) are trained,
|
| 71 |
+
making it extremely memory-efficient.
|
| 72 |
+
- **MPS Support**: Optimized for Mac M1/M2/M3 chips.
|
| 73 |
+
- **W&B Integration**: Full experiment tracking with model versioning.
|
| 74 |
+
|
| 75 |
+
## ☁️ Cloud GPU Training
|
| 76 |
+
|
| 77 |
+
Train on GPU with [Vast.ai](https://vast.ai/) (~$0.50–2/h for A100/H100).
|
| 78 |
+
|
| 79 |
+
### Quick Start
|
| 80 |
+
|
| 81 |
+
```bash
|
| 82 |
+
# 1. On the cloud instance — bootstrap
|
| 83 |
+
curl -sSL https://raw.githubusercontent.com/max044/vl-jepa/main/scripts/bootstrap.sh | bash
|
| 84 |
+
|
| 85 |
+
# 2. Configure W&B
|
| 86 |
+
cd ~/vl-jepa
|
| 87 |
+
cp .env.example .env
|
| 88 |
+
nano .env # Set WANDB_API_KEY (get it at https://wandb.ai/authorize)
|
| 89 |
+
|
| 90 |
+
# 3. Download videos
|
| 91 |
+
wget -P data/ https://ai2-public-datasets.s3-us-west-2.amazonaws.com/charades/Charades_v1_480.zip
|
| 92 |
+
unzip data/Charades_v1_480.zip -d data/
|
| 93 |
+
|
| 94 |
+
# or, alternatively, download from the Hugging Face mirror:
|
| 95 |
+
|
| 96 |
+
uv run hf download max044/Charades_v1_480 --local-dir data/Charades_v1_480 --repo-type dataset
|
| 97 |
+
|
| 98 |
+
# 4. Launch training
|
| 99 |
+
bash scripts/train_cloud.sh
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
### W&B Experiment Tracking
|
| 103 |
+
|
| 104 |
+
All training runs are tracked on [Weights & Biases](https://wandb.ai/):
|
| 105 |
+
|
| 106 |
+
- **Metrics**: loss, InfoNCE, learning rate (per step + per epoch)
|
| 107 |
+
- **System**: GPU utilization, memory usage (automatic)
|
| 108 |
+
- **Model versioning**: checkpoints uploaded as W&B Artifacts (`vl-jepa-best`,
|
| 109 |
+
`vl-jepa-last`) — every version is preserved and downloadable
|
| 110 |
+
|
| 111 |
+
```bash
|
| 112 |
+
# Train with W&B (default)
|
| 113 |
+
uv run train.py --device cuda --wandb-project vl-jepa
|
| 114 |
+
|
| 115 |
+
# Train without W&B
|
| 116 |
+
uv run train.py --device cuda --no-wandb
|
| 117 |
+
|
| 118 |
+
# Custom W&B run name
|
| 119 |
+
uv run train.py --device cuda --wandb-run-name "exp-lr3e4-bs16"
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
### Environment Variables
|
| 123 |
+
|
| 124 |
+
| Variable | Description | Required |
|
| 125 |
+
| --------------- | ---------------------------------------------------- | ------------ |
|
| 126 |
+
| `WANDB_API_KEY` | W&B API key ([get here](https://wandb.ai/authorize)) | For tracking |
|
| 127 |
+
| `WANDB_PROJECT` | W&B project name (default: `vl-jepa`) | No |
|
| 128 |
+
| `WANDB_ENTITY` | W&B team/organization | No |
|
| 129 |
+
| `EPOCHS` | Override epoch count | No |
|
| 130 |
+
| `BATCH_SIZE` | Override batch size | No |
|
| 131 |
+
|
| 132 |
+
## 🔍 Inference (Moment Retrieval)
|
| 133 |
+
|
| 134 |
+
Once trained, you can use the model to find specific moments in a video based on
|
| 135 |
+
a text query. The script uses a sliding window approach with NMS to find the
|
| 136 |
+
best matching segments.
|
| 137 |
+
|
| 138 |
+
```bash
|
| 139 |
+
# Example: Local inference
|
| 140 |
+
uv run infer.py \
|
| 141 |
+
--video data/Charades_v1_480/3MSZA.mp4 \
|
| 142 |
+
--query "person turns on the light" \
|
| 143 |
+
--checkpoint checkpoints/best.pth \
|
| 144 |
+
--device mps
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
## 🔍 Implementation Details
|
| 148 |
+
|
| 149 |
+
Unlike standard VLM (Visual-Language Models) that use generative heads, this
|
| 150 |
+
VL-JEPA implementation focuses on **embedding alignment**. This makes it an
|
| 151 |
+
order of magnitude faster for retrieval tasks (search) as embeddings can be
|
| 152 |
+
pre-computed and indexed using vector databases (Faiss, Milvus, Chroma).
|
| 153 |
+
|
| 154 |
+
## 📚 References
|
| 155 |
+
|
| 156 |
+
This implementation is based on the official VL-JEPA paper:
|
| 157 |
+
|
| 158 |
+
```bibtex
|
| 159 |
+
@misc{chen2026vljepajointembeddingpredictive,
|
| 160 |
+
title={VL-JEPA: Joint Embedding Predictive Architecture for Vision-language},
|
| 161 |
+
author={Delong Chen and Mustafa Shukor and Theo Moutakanni and Willy Chung and Jade Yu and Tejaswi Kasarla and Yejin Bang and Allen Bolourchi and Yann LeCun and Pascale Fung},
|
| 162 |
+
year={2026},
|
| 163 |
+
eprint={2512.10942},
|
| 164 |
+
archivePrefix={arXiv},
|
| 165 |
+
primaryClass={cs.CV},
|
| 166 |
+
url={https://arxiv.org/abs/2512.10942},
|
| 167 |
+
}
|
| 168 |
+
```
|
| 169 |
+
|
| 170 |
+
## 📄 License
|
| 171 |
+
|
| 172 |
+
MIT
|
sample_inference.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from vljepa.config import Config
|
| 6 |
+
from vljepa.models import VLJepa
|
| 7 |
+
from vljepa.utils import nms
|
| 8 |
+
|
| 9 |
+
def load_model(checkpoint_path, device="cpu"):
    """Build a VLJepa model and restore its trained weights.

    Only the LoRA predictor and the text-projection head were trained, so
    just those two state dicts are restored; the frozen encoders keep their
    pretrained weights.

    Args:
        checkpoint_path: path to the .pth checkpoint file.
        device: torch device string ("cpu", "cuda", "mps").

    Returns:
        Tuple of (model in eval mode, the Config used to build it).
    """
    cfg = Config()
    cfg.device = device
    net = VLJepa(cfg)

    print(f"Loading weights from {checkpoint_path}...")
    state = torch.load(checkpoint_path, map_location=device, weights_only=True)
    net.predictor.load_state_dict(state["predictor_state_dict"])
    net.y_encoder.projection.load_state_dict(state["y_projection_state_dict"])

    net.eval()
    return net, cfg
| 21 |
+
|
| 22 |
+
def extract_frames(video_path, num_frames=16):
    """Sample ``num_frames`` evenly spaced RGB frames from a video.

    Args:
        video_path: path to a video file readable by OpenCV.
        num_frames: number of frames to sample uniformly over the clip.

    Returns:
        List of RGB numpy arrays of shape (H, W, 3); empty if the video
        cannot be opened or reports no frames.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            # Fix: the original returned here without releasing the capture
            # handle, leaking the underlying file/codec resources.
            return []

        indices = np.linspace(0, total_frames - 1, num_frames).astype(int)
        frames = []
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if ret:
                # OpenCV decodes to BGR; the model pipeline expects RGB.
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        return frames
    finally:
        # Always release the capture, on both success and early-return paths.
        cap.release()
|
| 38 |
+
|
| 39 |
+
def main():
    """Smoke-test entry point: load the checkpoint and encode one query.

    This only demonstrates tokenization and text encoding; full temporal
    localization (sliding window + NMS) lives in infer.py.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint_path = "best.pth"
    video_path = "sample_video.mp4"  # Replace with a real video path
    query = "a person is opening a door"

    model, config = load_model(checkpoint_path, device)

    # Simplified demonstration only; a real run slides a window over the
    # video as done in infer.py.
    print(f"Ready for inference on {device}.")
    print(f"Model architecture: {config.clip_model} + {config.predictor_model} (LoRA) + {config.text_model}")

    # Tokenize the query for the predictor.
    tokens = model.query_encoder.tokenize([query], device=device)

    # Encode the query with the frozen Y-encoder (no gradients needed).
    with torch.no_grad():
        embedding = model.encode_text([query], device=device)

    print(f"Query: '{query}'")
    print(f"Text embedding shape: {embedding.shape}")
    print("\nTo perform full temporal localization, use the infer.py script which implements sliding window and NMS.")


if __name__ == "__main__":
    main()
|
vljepa/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""VL-JEPA: Simplified Video-Language Joint Embedding Predictive Architecture."""
|
| 2 |
+
|
| 3 |
+
from vljepa.config import Config
|
| 4 |
+
from vljepa.models import VLJepa
|
| 5 |
+
from vljepa.losses import vl_jepa_loss
|
| 6 |
+
|
| 7 |
+
__all__ = ["Config", "VLJepa", "vl_jepa_loss"]
|
vljepa/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (435 Bytes). View file
|
|
|
vljepa/__pycache__/config.cpython-313.pyc
ADDED
|
Binary file (3.85 kB). View file
|
|
|
vljepa/__pycache__/dataset.cpython-313.pyc
ADDED
|
Binary file (6.86 kB). View file
|
|
|
vljepa/__pycache__/losses.cpython-313.pyc
ADDED
|
Binary file (3.71 kB). View file
|
|
|
vljepa/__pycache__/models.cpython-313.pyc
ADDED
|
Binary file (14.7 kB). View file
|
|
|
vljepa/__pycache__/utils.cpython-313.pyc
ADDED
|
Binary file (5.8 kB). View file
|
|
|
vljepa/config.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Configuration for VL-JEPA training and inference."""
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclass
class Config:
    """Central hyperparameter and path registry for VL-JEPA.

    Instantiating a Config auto-detects the compute device (when left
    empty) and creates the checkpoint/data directories.
    """

    # Compute device; resolved in __post_init__ when left blank.
    device: str = ""

    # ── Models ───────────────────────────────────────────────────────
    # X-encoder: frozen V-JEPA 2 ViT-L (~300M params).
    clip_model: str = "facebook/vjepa2-vitl-fpc64-256"

    # Predictor: Qwen 2.5 0.5B adapted with LoRA.
    predictor_model: str = "Qwen/Qwen2.5-0.5B"
    use_lora: bool = True
    lora_r: int = 64
    lora_alpha: int = 128
    lora_dropout: float = 0.05
    lora_target_modules: list[str] = field(default_factory=lambda: ["q_proj", "v_proj"])

    # Y-encoder: frozen MiniLM sentence embedder (~22M params).
    text_model: str = "sentence-transformers/all-MiniLM-L6-v2"

    # Embedding dimensionality of each component.
    x_dim: int = 1024         # V-JEPA ViT-L output dim
    predictor_dim: int = 896  # Qwen 2.5 0.5B hidden dim
    text_dim: int = 384       # MiniLM-L6-v2 output dim
    embed_dim: int = 384      # Shared projection target

    # ── Video sampling ───────────────────────────────────────────────
    num_frames: int = 16
    frame_size: int = 224  # V-JEPA input resolution

    # ── Optimization ─────────────────────────────────────────────────
    batch_size: int = 4  # Start small (increase if GPU RAM allows)
    lr: float = 3e-4
    weight_decay: float = 0.01
    epochs: int = 20
    warmup_steps: int = 200
    grad_clip: float = 1.0

    # Loss hyperparameters.
    temperature: float = 0.07
    sigreg_weight: float = 0.1

    # ── Data locations ───────────────────────────────────────────────
    data_dir: str = "./data"
    videos_dir: str = "./data/Charades_v1_480"
    anno_train: str = "./data/charades_sta_train.txt"
    anno_test: str = "./data/charades_sta_test.txt"
    hf_dataset_id: str = "max044/Charades_v1_480"

    # ── Checkpointing / validation cadence ───────────────────────────
    checkpoint_dir: str = "./checkpoints"
    save_every: int = 2   # save checkpoint every N epochs
    val_every: int = 2    # run validation every N epochs
    val_samples: int = 500  # cap validation samples for speed

    # ── Inference (sliding-window moment retrieval) ──────────────────
    window_sizes: list[float] = field(default_factory=lambda: [2.0, 4.0, 8.0, 16.0])
    window_stride: float = 1.0
    nms_threshold: float = 0.5
    top_k: int = 5

    # ── Debugging ────────────────────────────────────────────────────
    debug: bool = False
    debug_samples: int = 100
    num_workers: int = 0  # 0 keeps DataLoader MPS-compatible

    def __post_init__(self):
        # Auto-select the best available backend when none was supplied.
        if not self.device:
            if torch.cuda.is_available():
                self.device = "cuda"
            elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"

        # Make sure the output/input directories exist up front.
        for directory in (self.checkpoint_dir, self.data_dir):
            Path(directory).mkdir(parents=True, exist_ok=True)
|
vljepa/dataset.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Charades-STA dataset for VL-JEPA training."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from torch.utils.data import Dataset
|
| 7 |
+
|
| 8 |
+
from vljepa.config import Config
|
| 9 |
+
from vljepa.utils import load_video_frames
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
from huggingface_hub import hf_hub_download
|
| 13 |
+
HAS_HF_HUB = True
|
| 14 |
+
except ImportError:
|
| 15 |
+
HAS_HF_HUB = False
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class CharadesSTADataset(Dataset):
    """Dataset for Charades-STA temporal grounding.

    Annotation format: video_id start end##sentence
    Example: 3MSZA 24.3 30.4##person turn a light on

    For training, the query is a neutral prompt ("What is happening in this video?")
    and the target is the ground-truth caption.
    """

    # Content-free prompts cycled per sample (by index) so the predictor
    # sees varied queries while the caption remains the learning target.
    NEUTRAL_QUERIES = [
        "What is happening in this video?",
        "Describe this video clip.",
        "What action is being performed?",
    ]

    def __init__(
        self,
        anno_file: str,
        videos_dir: str,
        config: Config,
        split: str = "train",
    ):
        """Load annotations and optionally truncate for debug runs.

        Args:
            anno_file: path to a Charades-STA annotation text file; if it
                does not exist, falls back to the HuggingFace dataset.
            videos_dir: directory containing (or caching) the .mp4 files.
            config: global Config (debug flags, frame count, HF repo id).
            split: label used only for the load-report print.
        """
        self.videos_dir = videos_dir
        self.config = config
        self.split = split
        self.samples = []

        self._load_annotations(anno_file)

        # Debug mode trims the sample list for fast iteration.
        if config.debug:
            self.samples = self.samples[: config.debug_samples]

        print(f"[{split}] Loaded {len(self.samples)} samples")

    def _load_annotations(self, anno_file: str):
        """Parse Charades-STA annotation file.

        Malformed lines (missing '##' separator or fewer than three
        metadata fields) are silently skipped.
        """
        if not os.path.exists(anno_file):
            # Try loading from HuggingFace datasets
            self._load_from_hf()
            return

        with open(anno_file, "r") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue

                # Format: video_id start end##sentence
                parts = line.split("##")
                if len(parts) < 2:
                    continue

                meta = parts[0].strip().split()
                sentence = parts[1].strip()

                if len(meta) < 3:
                    continue

                video_id = meta[0]
                start = float(meta[1])
                end = float(meta[2])

                video_path = os.path.join(self.videos_dir, f"{video_id}.mp4")

                # If streaming/lazy loading is enabled, we add even if not local
                # (the file is fetched on demand in __getitem__).
                if os.path.exists(video_path) or self.config.hf_dataset_id:
                    self.samples.append({
                        "video_path": video_path,
                        "video_id": video_id,
                        "start": start,
                        "end": end,
                        "caption": sentence,
                    })

    def _load_from_hf(self):
        """Fallback: load annotations from HuggingFace datasets.

        NOTE(review): only the "test" split is fetched here — confirm this
        is intended when the local train annotation file is missing.
        Failures are reported but leave self.samples empty rather than
        raising.
        """
        try:
            from datasets import load_dataset

            print("Loading annotations from HuggingFace (lmms-lab/charades_sta)...")
            ds = load_dataset("lmms-lab/charades_sta", split="test")

            for item in ds:
                # Field names differ between dataset versions; try both.
                video_id = item.get("video_id") or item.get("video", "")
                start = float(item.get("start", 0))
                end = float(item.get("end", 10))
                caption = item.get("query", "") or item.get("description", "")

                video_path = os.path.join(self.videos_dir, f"{video_id}.mp4")
                if os.path.exists(video_path) and caption:
                    self.samples.append({
                        "video_path": video_path,
                        "video_id": video_id,
                        "start": start,
                        "end": end,
                        "caption": caption,
                    })

        except Exception as e:
            print(f"Failed to load from HuggingFace: {e}")
            print("Please download annotations manually. See download_annotations.py")

    def __len__(self):
        """Return the number of (video segment, caption) samples."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> dict | None:
        """Load one sample; returns None on any failure so the collate
        function can drop it instead of crashing the DataLoader worker.
        """
        sample = self.samples[idx]
        video_path = sample["video_path"]

        # ── Lazy Loading from HF ────────────────────────────
        if not os.path.exists(video_path) and self.config.hf_dataset_id:
            if HAS_HF_HUB:
                try:
                    # Download only the specific file needed
                    video_path = hf_hub_download(
                        repo_id=self.config.hf_dataset_id,
                        filename=f"{sample['video_id']}.mp4",
                        repo_type="dataset",
                        local_dir=self.videos_dir,  # Cache it in the videos dir
                    )
                except Exception as e:
                    print(f"Error downloading {sample['video_id']}: {e}")
                    return None
            else:
                print("Error: huggingface_hub not installed, cannot lazy load.")
                return None

        # Load frames from the annotated temporal segment
        frames = load_video_frames(
            video_path,
            start_sec=sample["start"],
            end_sec=sample["end"],
            num_frames=self.config.num_frames,
        )

        if frames is None or len(frames) == 0:
            return None

        # Use a neutral query for training
        # (VL-JEPA learns to predict the target caption embedding from video + query)
        query_idx = idx % len(self.NEUTRAL_QUERIES)
        query = self.NEUTRAL_QUERIES[query_idx]

        return {
            "frames": frames,  # list of numpy arrays (H, W, 3)
            "query": query,  # neutral text query
            "caption": sample["caption"],  # target caption
            "video_id": sample["video_id"],
            "start": sample["start"],
            "end": sample["end"],
        }
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def collate_fn(batch: list[dict | None]) -> dict | None:
|
| 173 |
+
"""Custom collate that filters out None samples."""
|
| 174 |
+
batch = [b for b in batch if b is not None]
|
| 175 |
+
if len(batch) == 0:
|
| 176 |
+
return None
|
| 177 |
+
|
| 178 |
+
return {
|
| 179 |
+
"frames": [b["frames"] for b in batch],
|
| 180 |
+
"queries": [b["query"] for b in batch],
|
| 181 |
+
"captions": [b["caption"] for b in batch],
|
| 182 |
+
"video_ids": [b["video_id"] for b in batch],
|
| 183 |
+
"starts": [b["start"] for b in batch],
|
| 184 |
+
"ends": [b["end"] for b in batch],
|
| 185 |
+
}
|
vljepa/losses.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Loss functions for VL-JEPA: bidirectional InfoNCE + SIGReg regularization."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def infonce_bidirectional(
    pred: torch.Tensor,
    target: torch.Tensor,
    temperature: float = 0.07,
) -> torch.Tensor:
    """Symmetric InfoNCE between predicted and target embeddings.

    Both inputs are L2-normalized here, so callers may pass raw embeddings.
    The positive pair for row i is column i; the loss averages the
    pred->target and target->pred directions.

    Args:
        pred: (B, D) predicted embeddings.
        target: (B, D) target embeddings.
        temperature: logit scaling factor (lower = sharper distribution).

    Returns:
        Scalar loss tensor.
    """
    p = F.normalize(pred, dim=-1)
    t = F.normalize(target, dim=-1)

    # (B, B) cosine-similarity logits; diagonal entries are the positives.
    sims = torch.matmul(p, t.T) / temperature
    positives = torch.arange(p.size(0), device=p.device)

    forward = F.cross_entropy(sims, positives)
    backward = F.cross_entropy(sims.T, positives)
    return (forward + backward) / 2
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def sigreg_loss(
    embeddings: torch.Tensor,
    lambda_reg: float = 0.1,
) -> torch.Tensor:
    """Simplified SIGReg: push the batch covariance toward the identity.

    Penalizes (a) per-dimension variance falling below 1 and (b) nonzero
    off-diagonal covariance, discouraging embedding collapse.

    Args:
        embeddings: (B, D) batch of embeddings.
        lambda_reg: weight applied to the combined penalty.

    Returns:
        Scalar loss; zero when the batch has fewer than 2 samples
        (covariance is undefined for a single sample).
    """
    batch = embeddings.size(0)
    if batch < 2:
        return torch.tensor(0.0, device=embeddings.device)

    centered = embeddings - embeddings.mean(dim=0, keepdim=True)

    # Unbiased sample covariance, shape (D, D).
    cov = centered.T @ centered / (batch - 1)
    diag = cov.diagonal()

    # Hinge on variance: only under-dispersed dimensions are penalized.
    variance_term = F.relu(1.0 - diag).mean()

    # Squared off-diagonal entries encourage decorrelated dimensions.
    covariance_term = ((cov - torch.diag(diag)) ** 2).mean()

    return lambda_reg * (variance_term + covariance_term)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def vl_jepa_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    temperature: float = 0.07,
    sigreg_weight: float = 0.1,
) -> tuple[torch.Tensor, dict[str, float]]:
    """Combined VL-JEPA training loss.

    Sums a bidirectional InfoNCE alignment term with SIGReg regularization
    applied independently to the predicted and target embedding batches.

    Args:
        pred: predicted embeddings, shape (B, D).
        target: target embeddings, shape (B, D).
        temperature: InfoNCE logit temperature.
        sigreg_weight: weight passed to each SIGReg term.

    Returns:
        total_loss: scalar tensor for backprop.
        metrics: dict with breakdown of loss components.
    """
    # Contrastive alignment between predicted and target embeddings.
    alignment = infonce_bidirectional(pred, target, temperature)

    # Regularize both embedding sets independently.
    reg_p = sigreg_loss(pred, sigreg_weight)
    reg_t = sigreg_loss(target, sigreg_weight)

    total_loss = alignment + reg_p + reg_t
    breakdown = {
        "loss/total": total_loss.item(),
        "loss/infonce": alignment.item(),
        "loss/sigreg_pred": reg_p.item(),
        "loss/sigreg_target": reg_t.item(),
    }
    return total_loss, breakdown
|
vljepa/models.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""VL-JEPA model components: V-JEPA 2 (X-Encoder), Qwen 2.5 (Predictor), MiniLM (Y-Encoder)."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from transformers import AutoModel, AutoTokenizer
|
| 7 |
+
from peft import get_peft_model, LoraConfig, TaskType
|
| 8 |
+
from sentence_transformers import SentenceTransformer
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
from vljepa.config import Config
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class XEncoder(nn.Module):
    """Frozen V-JEPA 2 Video Encoder.

    Extracts hierarchical video features. The backbone is loaded from
    HuggingFace, frozen, and used purely as a feature extractor;
    ``config.x_dim`` is overwritten with the backbone's actual hidden size.
    """

    def __init__(self, config: Config):
        super().__init__()
        # Load V-JEPA 2 model
        try:
            self.model = AutoModel.from_pretrained(config.clip_model, trust_remote_code=True)
        except Exception:
            # Best-effort fallback to a known public checkpoint.
            print(f"Warning: Failed to load {config.clip_model}. Trying fallback 'facebook/vjepa-vit-h-14-224'.")
            self.model = AutoModel.from_pretrained("facebook/vjepa-vit-h-14-224", trust_remote_code=True)
        # NOTE: mutates the shared config so downstream modules (e.g. the
        # Predictor's visual projection) size themselves to the real width.
        config.x_dim = self.model.config.hidden_size

        # Freeze
        for p in self.model.parameters():
            p.requires_grad = False
        self.model.eval()

        # Move to device if needed
        self.model.to(config.device)

        self.hidden_size = config.x_dim

    @torch.no_grad()
    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode video frames.

        Args:
            pixel_values: (B, C, T, H, W) preprocessed frames (0-1 float, normalized)

        Returns:
            Mean-pooled clip features of shape (B, hidden_size).
        """
        # Heuristic layout fix: treat (B, 3, T, H, W) input as channel-first
        # and convert to (B, T, 3, H, W).
        # NOTE(review): ambiguous for 3-frame clips (T == 3) — confirm callers
        # never pass exactly 3 frames in channel-last layout.
        if pixel_values.shape[1] == 3 and pixel_values.shape[2] > 3:
            # (B, C, T, H, W) -> (B, T, C, H, W)
            pixel_values = pixel_values.permute(0, 2, 1, 3, 4)

        try:
            outputs = self.model(pixel_values_videos=pixel_values)
        except TypeError:
            # Fallback for backbones that take the generic argument name.
            outputs = self.model(pixel_values=pixel_values)

        last_hidden = outputs.last_hidden_state  # (B, seq_len, hidden)
        # Mean-pool over the token dimension: one vector per clip.
        sv = last_hidden.mean(dim=1)  # (B, hidden)
        return sv

    def preprocess_frames(self, frames_batch: list[list], device: str = "cpu") -> torch.Tensor:
        """Preprocess frames.

        Converts a batch of per-video frame lists (H, W, 3 arrays, 0-255) into
        a normalized (B, T, 3, 224, 224) float tensor. Shorter clips are padded
        by repeating their last frame; empty clips become 16 black frames.
        """
        # ImageNet normalization constants, shaped to broadcast over
        # (B, 3, T, H, W).
        mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1, 1)

        padded = []
        for frames in frames_batch:
            if len(frames) == 0:
                # Empty clip: substitute a 16-frame black placeholder.
                t = torch.zeros((16, 3, 224, 224), device=device)
                padded.append(t)
                continue

            # Stack to (T, H, W, 3)
            t = torch.tensor(np.stack(frames), dtype=torch.float32, device=device)

            # Permute to (T, 3, H, W)
            t = t.permute(0, 3, 1, 2) / 255.0

            # Resize
            t = F.interpolate(t, size=(224, 224), mode='bilinear', align_corners=False)

            padded.append(t)

        # Pad every clip to the longest one by repeating its final frame.
        max_t = max((t.size(0) for t in padded), default=16)
        final_padded = []
        for t in padded:
            if t.size(0) < max_t:
                pad = t[-1:].expand(max_t - t.size(0), -1, -1, -1)
                t = torch.cat([t, pad], dim=0)
            final_padded.append(t)

        # Stack -> (B, T, 3, H, W)
        pixel_values = torch.stack(final_padded, dim=0)

        # Input to V-JEPA 2 (via HF) usually expects (B, T, C, H, W)

        # Normalize (broadcasting T)
        # mean/std are (1, 3, 1, 1, 1). We need to align with (B, T, 3, H, W)
        # Permute to (B, 3, T, H, W) for normalization
        pixel_values = pixel_values.permute(0, 2, 1, 3, 4)
        pixel_values = (pixel_values - mean) / std

        # Permute back to (B, T, 3, H, W)
        pixel_values = pixel_values.permute(0, 2, 1, 3, 4)

        return pixel_values
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class QueryEncoder(nn.Module):
    """Wraps the Qwen tokenizer used to encode text queries for the predictor."""

    def __init__(self, config: Config):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(config.predictor_model, trust_remote_code=True)
        # Some causal LMs ship without a pad token; reuse EOS so batch
        # padding works.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def tokenize(self, texts: list[str], device: str = "cpu") -> dict:
        """Tokenize a batch of query strings and move the tensors to *device*."""
        encoded = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=64,
        )
        return encoded.to(device)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class Predictor(nn.Module):
    """Qwen 2.5 0.5B Predictor with LoRA.

    Consumes a pooled video embedding plus a tokenized text query and predicts
    the target text embedding in the shared (Y-Encoder) space. The video
    feature is projected into the LM's embedding space and prepended as a
    single pseudo-token before the query tokens.
    """

    def __init__(self, config: Config):
        super().__init__()
        self.model = AutoModel.from_pretrained(
            config.predictor_model,
            # fp16 only on GPU; CPU fp16 kernels are slow or unsupported.
            torch_dtype=torch.float16 if config.device == "cuda" else torch.float32,
            trust_remote_code=True
        )
        if config.use_lora:
            peft_config = LoraConfig(
                task_type=TaskType.FEATURE_EXTRACTION,
                inference_mode=False,
                r=config.lora_r,
                lora_alpha=config.lora_alpha,
                lora_dropout=config.lora_dropout,
                target_modules=config.lora_target_modules
            )
            self.model = get_peft_model(self.model, peft_config)
            self.model.print_trainable_parameters()

        # Project video features into the LM's embedding space, and the LM's
        # final hidden state into the shared text-embedding space.
        self.visual_proj = nn.Linear(config.x_dim, config.predictor_dim)
        self.output_proj = nn.Linear(config.predictor_dim, config.embed_dim)

        # Move to device
        self.to(config.device)

    def forward(self, sv: torch.Tensor, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Predict a target embedding from video features and a query.

        Args:
            sv: pooled video features, shape (B, x_dim).
            input_ids: tokenized query, shape (B, L).
            attention_mask: query mask, shape (B, L); 1 = real token, 0 = pad.
                Assumes right padding (the HF default for this tokenizer) —
                TODO confirm against QueryEncoder's tokenizer settings.

        Returns:
            Predicted embeddings, shape (B, embed_dim).
        """
        B = sv.size(0)
        sv_embeds = self.visual_proj(sv).unsqueeze(1)  # (B, 1, predictor_dim)

        # Unwrap PEFT to reach the underlying transformer.
        if hasattr(self.model, "base_model"):
            base = self.model.base_model.model
        else:
            base = self.model

        # Qwen2 uses model.embed_tokens
        # We try to access it via property or direct module
        if hasattr(base, "model"):
            embed_layer = base.model.embed_tokens
        elif hasattr(base, "embed_tokens"):
            embed_layer = base.embed_tokens
        else:
            # General fallback for AutoModel
            embed_layer = base.get_input_embeddings()

        inputs_embeds = embed_layer(input_ids)
        # Prepend the video pseudo-token to the text sequence.
        combined_embeds = torch.cat([sv_embeds, inputs_embeds], dim=1)

        # The video token is always attended to.
        ones = torch.ones((B, 1), device=sv.device, dtype=attention_mask.dtype)
        combined_mask = torch.cat([ones, attention_mask], dim=1)

        outputs = self.model(inputs_embeds=combined_embeds, attention_mask=combined_mask)
        hidden = outputs.last_hidden_state  # (B, 1 + L, predictor_dim)

        # BUGFIX: position -1 is a PAD token for every sequence shorter than
        # the batch max when padding is on the right. Gather the last *valid*
        # position per sequence instead (identical to [:, -1, :] when the
        # batch is unpadded).
        last_index = combined_mask.long().sum(dim=1) - 1  # (B,) index of last real token
        batch_index = torch.arange(B, device=hidden.device)
        last_hidden = hidden[batch_index, last_index]

        return self.output_proj(last_hidden)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class YEncoder(nn.Module):
    """Frozen MiniLM Y-Encoder.

    Encodes target texts with a frozen SentenceTransformer, then maps them
    into the shared embedding space with a trainable linear projection — the
    projection is the only part of this module that receives gradients.
    """

    def __init__(self, config: Config):
        super().__init__()
        self.model = SentenceTransformer(config.text_model)
        # Trainable projection from the text encoder's width to embed_dim.
        self.projection = nn.Linear(config.text_dim, config.embed_dim)

        # Freeze the sentence encoder; only `projection` trains.
        for p in self.model.parameters():
            p.requires_grad = False
        self.model.eval()

    def forward(self, texts: list[str], device: str = "cpu") -> torch.Tensor:
        """Encode target texts into the shared embedding space.

        Args:
            texts: batch of target strings.
            device: device for the encoded tensor.

        Returns:
            Projected embeddings, shape (B, embed_dim); gradients flow only
            through the projection layer.
        """
        with torch.no_grad():
            embeddings = self.model.encode(texts, convert_to_tensor=True, device=device)
        # Clone to avoid "Inference tensors cannot be saved for backward" error
        return self.projection(embeddings.clone())
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class VLJepa(nn.Module):
    """Full VL-JEPA pipeline: V-JEPA 2 (X) + Qwen 2.5 (Predictor) + MiniLM (Y)."""

    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        self.x_encoder = XEncoder(config)
        self.query_encoder = QueryEncoder(config)
        self.predictor = Predictor(config)
        self.y_encoder = YEncoder(config)

    def forward(self, pixel_values, query_ids, query_mask, target_texts):
        """Return (predicted, target) embedding pairs for a training batch."""
        predicted = self.encode_video_query(pixel_values, query_ids, query_mask)
        targets = self.y_encoder(target_texts, device=str(pixel_values.device))
        return predicted, targets

    def encode_video_query(self, pixel_values, query_ids, query_mask):
        """Encode a (video, query) pair into the shared embedding space."""
        video_features = self.x_encoder(pixel_values)
        return self.predictor(video_features, query_ids, query_mask)

    def encode_text(self, texts, device="cpu"):
        """Encode target texts into the shared embedding space."""
        return self.y_encoder(texts, device=device)

    def trainable_parameters(self):
        """Parameters that receive gradients: the predictor + Y projection."""
        params = list(self.predictor.parameters())
        params.extend(self.y_encoder.projection.parameters())
        return params

    def count_parameters(self):
        """Per-component parameter counts (total and trainable)."""
        def _tally(module):
            total = sum(p.numel() for p in module.parameters())
            trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
            return {"total": total, "trainable": trainable}

        return {
            "x_encoder": _tally(self.x_encoder),
            "predictor": _tally(self.predictor),
            "y_encoder": _tally(self.y_encoder),
        }
|
vljepa/utils.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utility functions: video I/O, temporal IoU, NMS, sliding windows."""
|
| 2 |
+
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def load_video_frames(
    video_path: str,
    start_sec: float = 0.0,
    end_sec: float | None = None,
    num_frames: int = 16,
) -> list[np.ndarray] | None:
    """Load uniformly sampled RGB frames from a video segment.

    Args:
        video_path: path to .mp4 file
        start_sec: start of segment in seconds
        end_sec: end of segment in seconds (None = end of video)
        num_frames: number of frames to sample

    Returns:
        List of RGB numpy arrays (H, W, 3), or None on failure.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None

    fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Corrupt or unsupported container: metadata is unusable.
    if fps <= 0 or total_frames <= 0:
        cap.release()
        return None

    duration = total_frames / fps
    if end_sec is None:
        end_sec = duration

    # Clamp the requested segment to valid frame indices.
    start_frame = max(0, int(start_sec * fps))
    end_frame = min(total_frames - 1, int(end_sec * fps))

    # Empty or inverted segment after clamping.
    if end_frame <= start_frame:
        cap.release()
        return None

    # Sample at most `num_frames` evenly spaced indices within the segment.
    n_available = end_frame - start_frame + 1
    n_sample = min(num_frames, n_available)
    indices = np.linspace(start_frame, end_frame, n_sample, dtype=int)

    frames = []
    for idx in indices:
        # Random-access seek per frame: exact but slower than sequential decode.
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
        ret, frame = cap.read()
        if ret:
            # OpenCV decodes BGR; convert to RGB for downstream models.
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

    cap.release()

    # Every read failed (e.g. truncated file).
    if len(frames) == 0:
        return None

    return frames
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def get_video_duration(video_path: str) -> float:
    """Return the duration of a video in seconds, or 0.0 if unreadable."""
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return 0.0
    frame_rate = cap.get(cv2.CAP_PROP_FPS)
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()
    # Guard against containers that report a zero/negative frame rate.
    return frame_count / frame_rate if frame_rate > 0 else 0.0
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def temporal_iou(
    pred_start: float,
    pred_end: float,
    gt_start: float,
    gt_end: float,
) -> float:
    """Compute temporal Intersection over Union between two segments."""
    overlap = min(pred_end, gt_end) - max(pred_start, gt_start)
    if overlap < 0.0:
        overlap = 0.0
    total = (pred_end - pred_start) + (gt_end - gt_start) - overlap
    # Degenerate segments (zero or negative union) count as no overlap.
    return overlap / total if total > 0 else 0.0
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def nms(
    proposals: list[tuple[float, float]],
    scores: list[float],
    iou_threshold: float = 0.5,
) -> list[int]:
    """Non-maximum suppression for temporal proposals.

    Greedily keeps the highest-scoring proposal, discarding any later
    candidate whose temporal IoU with an already-kept proposal exceeds
    the threshold.

    Args:
        proposals: list of (start, end) tuples
        scores: corresponding scores
        iou_threshold: suppress proposals with IoU above this

    Returns:
        List of kept indices (sorted by score descending).
    """
    if not proposals:
        return []

    # Visit candidates from highest to lowest score.
    order = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
    keep: list[int] = []

    for cand in order:
        c_start, c_end = proposals[cand]
        suppressed = False
        for kept_idx in keep:
            k_start, k_end = proposals[kept_idx]
            # Temporal IoU between candidate and kept proposal (inlined).
            inter = max(0.0, min(c_end, k_end) - max(c_start, k_start))
            union = (c_end - c_start) + (k_end - k_start) - inter
            overlap = inter / union if union > 0 else 0.0
            if overlap > iou_threshold:
                suppressed = True
                break
        if not suppressed:
            keep.append(cand)

    return keep
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def sliding_window_proposals(
    duration: float,
    window_sizes: list[float],
    stride: float = 1.0,
) -> list[tuple[float, float]]:
    """Generate candidate temporal proposals using sliding windows.

    Args:
        duration: total video duration in seconds
        window_sizes: list of window durations to use
        stride: step size in seconds

    Returns:
        List of (start, end) proposals.
    """
    EPSILON = 0.01  # tolerance for floating-point drift while stepping
    candidates: list[tuple[float, float]] = []
    for size in window_sizes:
        if size > duration:
            # Window longer than the clip: one proposal covering everything.
            candidates.append((0.0, duration))
        else:
            begin = 0.0
            while begin + size <= duration + EPSILON:
                candidates.append((begin, min(begin + size, duration)))
                begin += stride
    return candidates
|