Mirror agentic-intelligence-lab/elephant-embeddings-v1-multimodal-large from ModelScope
Browse filesMirrored from https://modelscope.cn/models/agentic-intelligence-lab/elephant-embeddings-v1-multimodal-large/summary
- .msc +0 -0
- .mv +1 -0
- README.md +191 -0
- config.json +44 -0
- configuration.json +0 -0
- model.pt +3 -0
- src/hf_st_mm/__init__.py +1 -0
- src/hf_st_mm/__pycache__/__init__.cpython-312.pyc +0 -0
- src/hf_st_mm/__pycache__/data.cpython-312.pyc +0 -0
- src/hf_st_mm/__pycache__/model.cpython-312.pyc +0 -0
- src/hf_st_mm/data.py +863 -0
- src/hf_st_mm/model.py +191 -0
.msc
ADDED
|
Binary file (821 Bytes). View file
|
|
|
.mv
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Revision:master,CreatedAt:1778776244
|
README.md
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
library_name: pytorch
|
| 4 |
+
pipeline_tag: sentence-similarity
|
| 5 |
+
language:
|
| 6 |
+
- multilingual
|
| 7 |
+
tags:
|
| 8 |
+
- agentic-intelligence-lab
|
| 9 |
+
- elephant
|
| 10 |
+
- embeddings
|
| 11 |
+
- multimodal
|
| 12 |
+
- retrieval
|
| 13 |
+
- rag
|
| 14 |
+
- agents
|
| 15 |
+
- routing
|
| 16 |
+
- image-text
|
| 17 |
+
- audio-text
|
| 18 |
+
- tri-encoder
|
| 19 |
+
- pytorch
|
| 20 |
+
model-index:
|
| 21 |
+
- name: elephant-embeddings-v1-multimodal-large
|
| 22 |
+
results:
|
| 23 |
+
- task:
|
| 24 |
+
type: sentence-similarity
|
| 25 |
+
dataset:
|
| 26 |
+
name: Internal cached validation set
|
| 27 |
+
type: cached_retrieval_validation
|
| 28 |
+
metrics:
|
| 29 |
+
- name: Eval loss
|
| 30 |
+
type: eval_loss
|
| 31 |
+
value: 0.389702
|
| 32 |
+
- name: Eval top1
|
| 33 |
+
type: eval_top1
|
| 34 |
+
value: 0.861707
|
| 35 |
+
---
|
| 36 |
+
|
| 37 |
+
# Elephant Embeddings V1 Multimodal Large
|
| 38 |
+
|
| 39 |
+
`elephant-embeddings-v1-multimodal-large` is the large multimodal embedding model in the **Agentic Intelligence Lab Elephant Embeddings V1** family.
|
| 40 |
+
|
| 41 |
+
This ModelScope release is maintained by `agentic-intelligence-lab` to make Elephant embedding models easier to download and deploy in mainland China. It mirrors and renames the upstream HuggingFace model `llm-semantic-router/multi-modal-embed-large` under a consistent Elephant model namespace.
|
| 42 |
+
|
| 43 |
+
## Positioning
|
| 44 |
+
|
| 45 |
+
This model is a production-oriented multimodal embedding model for semantic routing, retrieval, and cross-modal matching across text, image, and audio.
|
| 46 |
+
|
| 47 |
+
It is not a generative chat or captioning model. Instead, it maps different modalities into one shared embedding space so agent systems can compare requests, screenshots, documents, and audio records with the same retrieval interface.
|
| 48 |
+
|
| 49 |
+
## Model at a glance
|
| 50 |
+
|
| 51 |
+
| Item | Value |
|
| 52 |
+
| --- | --- |
|
| 53 |
+
| Family | Elephant Embeddings V1 |
|
| 54 |
+
| Maintainer | Agentic Intelligence Lab |
|
| 55 |
+
| Model type | Multimodal embedding model |
|
| 56 |
+
| Modalities | Text, image, audio |
|
| 57 |
+
| Architecture | Custom PyTorch tri-encoder |
|
| 58 |
+
| Text encoder | `llm-semantic-router/mmbert-embed-32k-2d-matryoshka` |
|
| 59 |
+
| Image encoder | `google/siglip2-so400m-patch14-384` |
|
| 60 |
+
| Audio encoder | `openai/whisper-medium` |
|
| 61 |
+
| Embedding dimension | 768 |
|
| 62 |
+
| Max text length | 32,768 tokens |
|
| 63 |
+
| Objective | Cached multiple negatives ranking loss |
|
| 64 |
+
| Upstream source | `llm-semantic-router/multi-modal-embed-large` |
|
| 65 |
+
| License | Apache 2.0 |
|
| 66 |
+
|
| 67 |
+
## Why it fits agentic workloads
|
| 68 |
+
|
| 69 |
+
Agentic products increasingly need to retrieve and route over mixed inputs: user text, screenshots, UI states, documents, voice notes, support calls, and multimodal memory. This model is designed for that operating pattern.
|
| 70 |
+
|
| 71 |
+
Key advantages:
|
| 72 |
+
|
| 73 |
+
- **Shared semantic space**: text, images, and audio can be compared with cosine similarity.
|
| 74 |
+
- **Routing-grade representation**: optimized for retrieval, matching, and routing rather than generation.
|
| 75 |
+
- **Strong modality towers**: uses dedicated text, image, and audio encoders instead of forcing all modalities through a single monolithic checkpoint.
|
| 76 |
+
- **Long-context text path**: supports long tool descriptions, traces, and knowledge chunks through the text encoder.
|
| 77 |
+
- **Production packaging**: includes the custom source package needed to construct and run the tri-encoder.
|
| 78 |
+
|
| 79 |
+
## Recommended use cases
|
| 80 |
+
|
| 81 |
+
| Scenario | Example |
|
| 82 |
+
| --- | --- |
|
| 83 |
+
| Multimodal RAG | Retrieve text notes using an image or audio query |
|
| 84 |
+
| Agent routing | Route screenshots, user text, or voice requests to the right tool or workflow |
|
| 85 |
+
| Memory search | Search mixed text/image/audio memory stores in one vector space |
|
| 86 |
+
| Support and operations | Match tickets, screenshots, logs, and recorded calls semantically |
|
| 87 |
+
| Offline indexing | Build high-quality 768d multimodal indexes |
|
| 88 |
+
|
| 89 |
+
## Quick start on ModelScope
|
| 90 |
+
|
| 91 |
+
```bash
|
| 92 |
+
pip install modelscope torch sentence-transformers transformers accelerate safetensors pillow librosa soundfile
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
```python
|
| 96 |
+
import json
|
| 97 |
+
import os
|
| 98 |
+
import sys
|
| 99 |
+
|
| 100 |
+
import torch
|
| 101 |
+
import torch.nn.functional as F
|
| 102 |
+
from modelscope import snapshot_download
|
| 103 |
+
|
| 104 |
+
repo_id = "agentic-intelligence-lab/elephant-embeddings-v1-multimodal-large"
|
| 105 |
+
local_dir = snapshot_download(repo_id)
|
| 106 |
+
|
| 107 |
+
sys.path.insert(0, os.path.join(local_dir, "src"))
|
| 108 |
+
|
| 109 |
+
from hf_st_mm.data import PairItem
|
| 110 |
+
from hf_st_mm.model import MultiModalSentenceEmbedder
|
| 111 |
+
|
| 112 |
+
with open(os.path.join(local_dir, "config.json"), "r", encoding="utf-8") as handle:
|
| 113 |
+
cfg = json.load(handle)
|
| 114 |
+
|
| 115 |
+
model = MultiModalSentenceEmbedder(
|
| 116 |
+
text_encoder_name=cfg["model"]["text_encoder_name"],
|
| 117 |
+
image_encoder_name=cfg["model"]["image_encoder_name"],
|
| 118 |
+
audio_encoder_name=cfg["model"]["audio_encoder_name"],
|
| 119 |
+
embedding_dim=int(cfg["model"]["embedding_dim"]),
|
| 120 |
+
max_text_length=int(cfg["model"]["max_text_length"]),
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
state_dict = torch.load(os.path.join(local_dir, "model.pt"), map_location="cpu")
|
| 124 |
+
model.load_state_dict(state_dict)
|
| 125 |
+
model.eval()
|
| 126 |
+
|
| 127 |
+
items = [
|
| 128 |
+
PairItem(modality="text", value="route this request to the billing workflow"),
|
| 129 |
+
PairItem(modality="image", value="/path/to/screenshot.png"),
|
| 130 |
+
PairItem(modality="audio", value="/path/to/call.wav"),
|
| 131 |
+
]
|
| 132 |
+
|
| 133 |
+
with torch.no_grad():
|
| 134 |
+
embeddings = model.encode_items(items)
|
| 135 |
+
|
| 136 |
+
print(embeddings.shape) # [3, 768]
|
| 137 |
+
|
| 138 |
+
query = PairItem(modality="text", value="refund request for a wrong charge")
|
| 139 |
+
candidate = PairItem(modality="audio", value="/path/to/refund_call.wav")
|
| 140 |
+
|
| 141 |
+
with torch.no_grad():
|
| 142 |
+
embs = model.encode_items([query, candidate])
|
| 143 |
+
|
| 144 |
+
similarity = F.cosine_similarity(embs[0:1], embs[1:2]).item()
|
| 145 |
+
print(f"similarity={similarity:.4f}")
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
## Evaluation snapshot
|
| 149 |
+
|
| 150 |
+
| Metric | Value |
|
| 151 |
+
| --- | ---: |
|
| 152 |
+
| Eval loss | 0.389702 |
|
| 153 |
+
| Eval top1 | 0.861707 |
|
| 154 |
+
|
| 155 |
+
The validation metrics come from the tri-encoder cached retrieval validation path used during export. They are intended as a release sanity snapshot rather than a public leaderboard claim.
|
| 156 |
+
|
| 157 |
+
## Files
|
| 158 |
+
|
| 159 |
+
| File | Description |
|
| 160 |
+
| --- | --- |
|
| 161 |
+
| `model.pt` | Exported PyTorch weights |
|
| 162 |
+
| `config.json` | Tri-encoder and training/export configuration |
|
| 163 |
+
| `src/hf_st_mm/` | Python package used to construct and run the model |
|
| 164 |
+
| `README.md` | This model card |
|
| 165 |
+
|
| 166 |
+
## Lineage
|
| 167 |
+
|
| 168 |
+
This ModelScope package is published by `agentic-intelligence-lab` as part of the Elephant model release line. It mirrors the upstream HuggingFace model `llm-semantic-router/multi-modal-embed-large` and keeps the model artifacts unchanged except for the repository naming and model card presentation.
|
| 169 |
+
|
| 170 |
+
## Limitations
|
| 171 |
+
|
| 172 |
+
- This is a custom PyTorch tri-encoder export, not a standard Transformers auto-class checkpoint.
|
| 173 |
+
- Inference relies on the packaged `hf_st_mm` source code.
|
| 174 |
+
- Image and audio inputs are expected as local file paths in the simple inference path.
|
| 175 |
+
- The model is optimized for retrieval, routing, and similarity, not generation or captioning.
|
| 176 |
+
- Reported validation metrics come from an internal cached retrieval validation set.
|
| 177 |
+
|
| 178 |
+
## Citation
|
| 179 |
+
|
| 180 |
+
```bibtex
|
| 181 |
+
@misc{elephant-embeddings-v1-multimodal-large,
|
| 182 |
+
title={Elephant Embeddings V1 Multimodal Large},
|
| 183 |
+
author={Agentic Intelligence Lab},
|
| 184 |
+
year={2026},
|
| 185 |
+
url={https://modelscope.cn/models/agentic-intelligence-lab/elephant-embeddings-v1-multimodal-large}
|
| 186 |
+
}
|
| 187 |
+
```
|
| 188 |
+
|
| 189 |
+
## License
|
| 190 |
+
|
| 191 |
+
Apache 2.0
|
config.json
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"seed": 42,
|
| 3 |
+
"output_dir": "/scratch/hf_st_mm_outputs/server_datacenter_8gpu_tri_encoder",
|
| 4 |
+
"model": {
|
| 5 |
+
"text_encoder_name": "llm-semantic-router/mmbert-embed-32k-2d-matryoshka",
|
| 6 |
+
"image_encoder_name": "google/siglip2-so400m-patch14-384",
|
| 7 |
+
"audio_encoder_name": "openai/whisper-medium",
|
| 8 |
+
"embedding_dim": 768,
|
| 9 |
+
"max_text_length": 32768
|
| 10 |
+
},
|
| 11 |
+
"training": {
|
| 12 |
+
"epochs": 10,
|
| 13 |
+
"batch_size": 12,
|
| 14 |
+
"grad_accum_steps": 8,
|
| 15 |
+
"num_workers": 4,
|
| 16 |
+
"prefetch_factor": 4,
|
| 17 |
+
"shard_prefetch": 2,
|
| 18 |
+
"shard_cache_limit": 4,
|
| 19 |
+
"sequential_shard_loading": true,
|
| 20 |
+
"shuffle": false,
|
| 21 |
+
"modality_homogeneous_batches": false,
|
| 22 |
+
"learning_rate": 1e-05,
|
| 23 |
+
"weight_decay": 0.01,
|
| 24 |
+
"warmup_ratio": 0.1,
|
| 25 |
+
"max_grad_norm": 1.0,
|
| 26 |
+
"mixed_precision": "bf16",
|
| 27 |
+
"log_every": 10,
|
| 28 |
+
"save_every": 2000,
|
| 29 |
+
"hard_negative_ratio": 0.5
|
| 30 |
+
},
|
| 31 |
+
"loss": {
|
| 32 |
+
"type": "cached_mnrl",
|
| 33 |
+
"scale": 20.0
|
| 34 |
+
},
|
| 35 |
+
"data": {
|
| 36 |
+
"cache_dir": "/scratch/2dmse-data/server_full_datacenter_cache/train"
|
| 37 |
+
},
|
| 38 |
+
"validation": {
|
| 39 |
+
"cache_dir": "/scratch/2dmse-data/server_full_datacenter_cache/val",
|
| 40 |
+
"num_workers": 2,
|
| 41 |
+
"shard_prefetch": 1,
|
| 42 |
+
"shard_cache_limit": 2
|
| 43 |
+
}
|
| 44 |
+
}
|
configuration.json
ADDED
|
File without changes
|
model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f5fe61d4864fffb703f53860234a657a2f51f71e393e2dc1b7f635b284cb48c4
|
| 3 |
+
size 6393990436
|
src/hf_st_mm/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Standalone HF Sentence-Transformers multimodal training package."""
|
src/hf_st_mm/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (200 Bytes). View file
|
|
|
src/hf_st_mm/__pycache__/data.cpython-312.pyc
ADDED
|
Binary file (45.7 kB). View file
|
|
|
src/hf_st_mm/__pycache__/model.cpython-312.pyc
ADDED
|
Binary file (14.7 kB). View file
|
|
|
src/hf_st_mm/data.py
ADDED
|
@@ -0,0 +1,863 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import math
|
| 3 |
+
import os
|
| 4 |
+
import queue
|
| 5 |
+
import random
|
| 6 |
+
import threading
|
| 7 |
+
from bisect import bisect_right
|
| 8 |
+
from collections import OrderedDict
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from typing import Any, Dict, Iterable, List, Optional
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from datasets import Dataset, Features, IterableDataset, Value
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
SUPPORTED_MODALITIES = {"text", "image", "audio"}
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
|
| 20 |
+
class PairItem:
|
| 21 |
+
modality: str
|
| 22 |
+
value: Any
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
|
| 26 |
+
class TrainRecord:
|
| 27 |
+
query: PairItem
|
| 28 |
+
positive: PairItem
|
| 29 |
+
negative: Optional[PairItem] = None
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _parse_item(obj: Any, prefix: str) -> PairItem:
|
| 33 |
+
if isinstance(obj, dict):
|
| 34 |
+
modality = obj.get("type")
|
| 35 |
+
value = obj.get("value")
|
| 36 |
+
else:
|
| 37 |
+
modality = None
|
| 38 |
+
value = None
|
| 39 |
+
|
| 40 |
+
if not modality or not value:
|
| 41 |
+
raise ValueError(f"{prefix} must include type/value")
|
| 42 |
+
if modality not in SUPPORTED_MODALITIES:
|
| 43 |
+
raise ValueError(f"Unsupported modality '{modality}' in {prefix}")
|
| 44 |
+
return PairItem(modality=modality, value=value)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def parse_record(raw: Dict[str, Any]) -> TrainRecord:
|
| 48 |
+
if "query" in raw and "positive" in raw:
|
| 49 |
+
query = _parse_item(raw["query"], "query")
|
| 50 |
+
positive = _parse_item(raw["positive"], "positive")
|
| 51 |
+
negative = _parse_item(raw["negative"], "negative") if raw.get("negative") else None
|
| 52 |
+
return TrainRecord(query=query, positive=positive, negative=negative)
|
| 53 |
+
|
| 54 |
+
# Compatibility with common pair formats in existing repos
|
| 55 |
+
if "texts_a" in raw and "texts_b" in raw:
|
| 56 |
+
query = PairItem("text", raw["texts_a"])
|
| 57 |
+
positive = PairItem("text", raw["texts_b"])
|
| 58 |
+
return TrainRecord(query=query, positive=positive)
|
| 59 |
+
|
| 60 |
+
if "image_path" in raw and "caption" in raw:
|
| 61 |
+
query = PairItem("image", raw["image_path"])
|
| 62 |
+
positive = PairItem("text", raw["caption"])
|
| 63 |
+
return TrainRecord(query=query, positive=positive)
|
| 64 |
+
|
| 65 |
+
if "audio_path" in raw and "caption" in raw:
|
| 66 |
+
query = PairItem("audio", raw["audio_path"])
|
| 67 |
+
positive = PairItem("text", raw["caption"])
|
| 68 |
+
return TrainRecord(query=query, positive=positive)
|
| 69 |
+
|
| 70 |
+
raise ValueError("Record does not match supported schemas")
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class JsonlManifestDataset:
|
| 74 |
+
def __init__(
|
| 75 |
+
self,
|
| 76 |
+
manifest_path: str,
|
| 77 |
+
image_root: Optional[str] = None,
|
| 78 |
+
audio_root: Optional[str] = None,
|
| 79 |
+
allow_missing_negative: bool = True,
|
| 80 |
+
) -> None:
|
| 81 |
+
self.manifest_path = manifest_path
|
| 82 |
+
self.image_root = image_root
|
| 83 |
+
self.audio_root = audio_root
|
| 84 |
+
self.allow_missing_negative = allow_missing_negative
|
| 85 |
+
self.records = list(
|
| 86 |
+
iter_manifest_records(
|
| 87 |
+
manifest_path=self.manifest_path,
|
| 88 |
+
image_root=self.image_root,
|
| 89 |
+
audio_root=self.audio_root,
|
| 90 |
+
allow_missing_negative=self.allow_missing_negative,
|
| 91 |
+
)
|
| 92 |
+
)
|
| 93 |
+
if not self.records:
|
| 94 |
+
raise ValueError(f"No records loaded from {self.manifest_path}")
|
| 95 |
+
|
| 96 |
+
def __len__(self) -> int:
|
| 97 |
+
return len(self.records)
|
| 98 |
+
|
| 99 |
+
def __getitem__(self, idx: int) -> TrainRecord:
|
| 100 |
+
return self.records[idx]
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class CachedShardDataset:
|
| 104 |
+
def __init__(self, cache_dir: str, shard_cache_limit: int = 2, prefetch_shards: int = 0) -> None:
|
| 105 |
+
self.cache_dir = cache_dir
|
| 106 |
+
self.shard_cache_limit = max(int(shard_cache_limit), 1)
|
| 107 |
+
self.prefetch_shards = max(int(prefetch_shards), 0)
|
| 108 |
+
self.metadata = self._load_metadata()
|
| 109 |
+
self.shard_files = self._discover_shards()
|
| 110 |
+
self.shard_sizes = self._resolve_shard_sizes()
|
| 111 |
+
self.shard_offsets = self._build_offsets(self.shard_sizes)
|
| 112 |
+
self.total_rows = sum(self.shard_sizes)
|
| 113 |
+
self._shard_cache: OrderedDict[int, List[Dict[str, Any]]] = OrderedDict()
|
| 114 |
+
self._init_runtime_state()
|
| 115 |
+
|
| 116 |
+
def _init_runtime_state(self) -> None:
|
| 117 |
+
self._cache_lock = threading.Lock()
|
| 118 |
+
self._prefetch_queue = None
|
| 119 |
+
self._prefetch_thread = None
|
| 120 |
+
self._prefetch_stop = threading.Event()
|
| 121 |
+
self._prefetch_requested: set[int] = set()
|
| 122 |
+
self._prefetch_hits = 0
|
| 123 |
+
self._prefetch_misses = 0
|
| 124 |
+
|
| 125 |
+
def __getstate__(self):
|
| 126 |
+
state = self.__dict__.copy()
|
| 127 |
+
state["_shard_cache"] = OrderedDict(state.get("_shard_cache", OrderedDict()))
|
| 128 |
+
state["_cache_lock"] = None
|
| 129 |
+
state["_prefetch_queue"] = None
|
| 130 |
+
state["_prefetch_thread"] = None
|
| 131 |
+
state["_prefetch_stop"] = None
|
| 132 |
+
state["_prefetch_requested"] = set()
|
| 133 |
+
return state
|
| 134 |
+
|
| 135 |
+
def __setstate__(self, state):
|
| 136 |
+
self.__dict__.update(state)
|
| 137 |
+
self._shard_cache = OrderedDict(self._shard_cache)
|
| 138 |
+
self._init_runtime_state()
|
| 139 |
+
|
| 140 |
+
def _load_metadata(self) -> Dict[str, Any]:
|
| 141 |
+
metadata_path = os.path.join(self.cache_dir, "metadata.json")
|
| 142 |
+
if not os.path.exists(metadata_path):
|
| 143 |
+
return {}
|
| 144 |
+
with open(metadata_path, "r", encoding="utf-8") as handle:
|
| 145 |
+
return json.load(handle)
|
| 146 |
+
|
| 147 |
+
def _discover_shards(self) -> List[str]:
|
| 148 |
+
if not os.path.isdir(self.cache_dir):
|
| 149 |
+
raise FileNotFoundError(f"Cache directory not found: {self.cache_dir}")
|
| 150 |
+
shards: List[str] = []
|
| 151 |
+
for name in sorted(os.listdir(self.cache_dir)):
|
| 152 |
+
if not (name.startswith("shard_") and name.endswith(".pt")):
|
| 153 |
+
continue
|
| 154 |
+
shard_path = os.path.join(self.cache_dir, name)
|
| 155 |
+
shards.append(shard_path)
|
| 156 |
+
if not shards:
|
| 157 |
+
raise ValueError(f"No cache shards found under {self.cache_dir}")
|
| 158 |
+
return shards
|
| 159 |
+
|
| 160 |
+
@staticmethod
|
| 161 |
+
def _build_offsets(shard_sizes: List[int]) -> List[int]:
|
| 162 |
+
offsets: List[int] = []
|
| 163 |
+
running_total = 0
|
| 164 |
+
for shard_size in shard_sizes:
|
| 165 |
+
running_total += shard_size
|
| 166 |
+
offsets.append(running_total)
|
| 167 |
+
return offsets
|
| 168 |
+
|
| 169 |
+
def _resolve_shard_sizes(self) -> List[int]:
|
| 170 |
+
num_shards = len(self.shard_files)
|
| 171 |
+
metadata_num_shards = self.metadata.get("num_shards")
|
| 172 |
+
metadata_num_records = self.metadata.get("num_records")
|
| 173 |
+
shard_size = self.metadata.get("shard_size")
|
| 174 |
+
|
| 175 |
+
if (
|
| 176 |
+
isinstance(metadata_num_shards, int)
|
| 177 |
+
and isinstance(metadata_num_records, int)
|
| 178 |
+
and isinstance(shard_size, int)
|
| 179 |
+
and metadata_num_shards == num_shards
|
| 180 |
+
and metadata_num_records > 0
|
| 181 |
+
and shard_size > 0
|
| 182 |
+
):
|
| 183 |
+
shard_sizes = [shard_size] * num_shards
|
| 184 |
+
full_rows_before_last = shard_size * max(num_shards - 1, 0)
|
| 185 |
+
shard_sizes[-1] = metadata_num_records - full_rows_before_last
|
| 186 |
+
if shard_sizes[-1] <= 0:
|
| 187 |
+
raise ValueError(f"Invalid metadata in {self.cache_dir}: last shard size computed as {shard_sizes[-1]}")
|
| 188 |
+
return shard_sizes
|
| 189 |
+
|
| 190 |
+
shard_sizes: List[int] = []
|
| 191 |
+
for shard_path in self.shard_files:
|
| 192 |
+
payload = torch.load(shard_path, map_location="cpu", weights_only=False)
|
| 193 |
+
records = payload.get("records")
|
| 194 |
+
if not isinstance(records, list):
|
| 195 |
+
raise ValueError(f"Invalid shard format in {shard_path}")
|
| 196 |
+
shard_sizes.append(len(records))
|
| 197 |
+
return shard_sizes
|
| 198 |
+
|
| 199 |
+
def _store_shard(self, shard_idx: int, records: List[Dict[str, Any]]) -> None:
|
| 200 |
+
with self._cache_lock:
|
| 201 |
+
self._shard_cache[shard_idx] = records
|
| 202 |
+
self._shard_cache.move_to_end(shard_idx)
|
| 203 |
+
while len(self._shard_cache) > self.shard_cache_limit:
|
| 204 |
+
self._shard_cache.popitem(last=False)
|
| 205 |
+
|
| 206 |
+
def _ensure_prefetch_thread(self) -> None:
|
| 207 |
+
if self.prefetch_shards <= 0:
|
| 208 |
+
return
|
| 209 |
+
if self._prefetch_thread is not None and self._prefetch_thread.is_alive():
|
| 210 |
+
return
|
| 211 |
+
|
| 212 |
+
self._prefetch_stop.clear()
|
| 213 |
+
self._prefetch_queue = queue.Queue(maxsize=max(self.prefetch_shards * 2, 1))
|
| 214 |
+
self._prefetch_thread = threading.Thread(
|
| 215 |
+
target=self._prefetch_worker,
|
| 216 |
+
daemon=True,
|
| 217 |
+
name=f"cached-shard-prefetch-{os.getpid()}",
|
| 218 |
+
)
|
| 219 |
+
self._prefetch_thread.start()
|
| 220 |
+
|
| 221 |
+
def _prefetch_worker(self) -> None:
|
| 222 |
+
while not self._prefetch_stop.is_set():
|
| 223 |
+
try:
|
| 224 |
+
shard_idx = self._prefetch_queue.get(timeout=0.1)
|
| 225 |
+
except queue.Empty:
|
| 226 |
+
continue
|
| 227 |
+
|
| 228 |
+
if shard_idx is None:
|
| 229 |
+
continue
|
| 230 |
+
|
| 231 |
+
try:
|
| 232 |
+
with self._cache_lock:
|
| 233 |
+
if shard_idx in self._shard_cache:
|
| 234 |
+
self._prefetch_hits += 1
|
| 235 |
+
continue
|
| 236 |
+
payload = torch.load(self.shard_files[shard_idx], map_location="cpu", weights_only=False)
|
| 237 |
+
records = payload["records"]
|
| 238 |
+
self._store_shard(shard_idx, records)
|
| 239 |
+
self._prefetch_hits += 1
|
| 240 |
+
finally:
|
| 241 |
+
with self._cache_lock:
|
| 242 |
+
self._prefetch_requested.discard(shard_idx)
|
| 243 |
+
|
| 244 |
+
def _schedule_prefetch(self, shard_idx: int) -> None:
|
| 245 |
+
if self.prefetch_shards <= 0:
|
| 246 |
+
return
|
| 247 |
+
|
| 248 |
+
self._ensure_prefetch_thread()
|
| 249 |
+
if self._prefetch_queue is None:
|
| 250 |
+
return
|
| 251 |
+
|
| 252 |
+
for next_idx in range(shard_idx + 1, min(len(self.shard_files), shard_idx + 1 + self.prefetch_shards)):
|
| 253 |
+
with self._cache_lock:
|
| 254 |
+
if next_idx in self._shard_cache or next_idx in self._prefetch_requested:
|
| 255 |
+
continue
|
| 256 |
+
self._prefetch_requested.add(next_idx)
|
| 257 |
+
try:
|
| 258 |
+
self._prefetch_queue.put_nowait(next_idx)
|
| 259 |
+
except queue.Full:
|
| 260 |
+
with self._cache_lock:
|
| 261 |
+
self._prefetch_requested.discard(next_idx)
|
| 262 |
+
break
|
| 263 |
+
|
| 264 |
+
def _load_shard(self, shard_idx: int) -> List[Dict[str, Any]]:
|
| 265 |
+
cached = None
|
| 266 |
+
with self._cache_lock:
|
| 267 |
+
cached = self._shard_cache.get(shard_idx)
|
| 268 |
+
if cached is not None:
|
| 269 |
+
self._shard_cache.move_to_end(shard_idx)
|
| 270 |
+
if cached is not None:
|
| 271 |
+
self._schedule_prefetch(shard_idx)
|
| 272 |
+
return cached
|
| 273 |
+
|
| 274 |
+
self._prefetch_misses += 1
|
| 275 |
+
payload = torch.load(self.shard_files[shard_idx], map_location="cpu", weights_only=False)
|
| 276 |
+
records = payload["records"]
|
| 277 |
+
self._store_shard(shard_idx, records)
|
| 278 |
+
with self._cache_lock:
|
| 279 |
+
self._prefetch_requested.discard(shard_idx)
|
| 280 |
+
self._schedule_prefetch(shard_idx)
|
| 281 |
+
return records
|
| 282 |
+
|
| 283 |
+
@staticmethod
|
| 284 |
+
def _deserialize_item(raw: Optional[Dict[str, Any]]) -> Optional[PairItem]:
|
| 285 |
+
if raw is None:
|
| 286 |
+
return None
|
| 287 |
+
modality = raw["type"]
|
| 288 |
+
if modality == "text" and "tokens" in raw:
|
| 289 |
+
value = raw["tokens"]
|
| 290 |
+
elif modality == "text":
|
| 291 |
+
value = raw["value"]
|
| 292 |
+
elif "tensor" in raw:
|
| 293 |
+
value = raw["tensor"]
|
| 294 |
+
else:
|
| 295 |
+
value = raw.get("value")
|
| 296 |
+
return PairItem(modality=modality, value=value)
|
| 297 |
+
|
| 298 |
+
def __len__(self) -> int:
|
| 299 |
+
return self.total_rows
|
| 300 |
+
|
| 301 |
+
def __getitem__(self, idx: int) -> TrainRecord:
|
| 302 |
+
if idx < 0 or idx >= self.total_rows:
|
| 303 |
+
raise IndexError(idx)
|
| 304 |
+
shard_idx = bisect_right(self.shard_offsets, idx)
|
| 305 |
+
shard_start = 0 if shard_idx == 0 else self.shard_offsets[shard_idx - 1]
|
| 306 |
+
local_idx = idx - shard_start
|
| 307 |
+
raw = self._load_shard(shard_idx)[local_idx]
|
| 308 |
+
return TrainRecord(
|
| 309 |
+
query=self._deserialize_item(raw["query"]),
|
| 310 |
+
positive=self._deserialize_item(raw["positive"]),
|
| 311 |
+
negative=self._deserialize_item(raw.get("negative")),
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
def get_prefetch_stats(self) -> Dict[str, int]:
|
| 315 |
+
with self._cache_lock:
|
| 316 |
+
return {
|
| 317 |
+
"cache_size": len(self._shard_cache),
|
| 318 |
+
"cache_limit": self.shard_cache_limit,
|
| 319 |
+
"prefetch_shards": self.prefetch_shards,
|
| 320 |
+
"prefetch_hits": self._prefetch_hits,
|
| 321 |
+
"prefetch_misses": self._prefetch_misses,
|
| 322 |
+
"prefetch_pending": len(self._prefetch_requested),
|
| 323 |
+
}
|
| 324 |
+
|
| 325 |
+
def close(self) -> None:
|
| 326 |
+
self._prefetch_stop.set()
|
| 327 |
+
if self._prefetch_thread is not None and self._prefetch_thread.is_alive():
|
| 328 |
+
self._prefetch_thread.join(timeout=1.0)
|
| 329 |
+
self._prefetch_thread = None
|
| 330 |
+
self._prefetch_queue = None
|
| 331 |
+
|
| 332 |
+
def __del__(self):
|
| 333 |
+
self.close()
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
class SequentialShardDataset:
|
| 337 |
+
def __init__(
|
| 338 |
+
self,
|
| 339 |
+
cache_dir: str,
|
| 340 |
+
shuffle: bool = True,
|
| 341 |
+
rank: int = 0,
|
| 342 |
+
world_size: int = 1,
|
| 343 |
+
prefetch_shards: int = 2,
|
| 344 |
+
shard_cache_limit: int = 4,
|
| 345 |
+
) -> None:
|
| 346 |
+
self.cache_dir = cache_dir
|
| 347 |
+
self.shuffle = shuffle
|
| 348 |
+
self.rank = rank
|
| 349 |
+
self.world_size = max(world_size, 1)
|
| 350 |
+
self.prefetch_shards = max(int(prefetch_shards), 0)
|
| 351 |
+
self.shard_cache_limit = max(int(shard_cache_limit), 1)
|
| 352 |
+
|
| 353 |
+
self.metadata = self._load_metadata()
|
| 354 |
+
self.shard_files = self._discover_shards()
|
| 355 |
+
self.shard_sizes = self._resolve_shard_sizes()
|
| 356 |
+
self.total_rows = sum(self.shard_sizes)
|
| 357 |
+
self.target_shard_size = int(self.metadata.get("shard_size") or max(self.shard_sizes))
|
| 358 |
+
|
| 359 |
+
self._shard_cache: OrderedDict[int, List[Dict[str, Any]]] = OrderedDict()
|
| 360 |
+
self._cache_lock = threading.Lock()
|
| 361 |
+
self._prefetch_queue = None
|
| 362 |
+
self._prefetch_thread = None
|
| 363 |
+
self._prefetch_stop = threading.Event()
|
| 364 |
+
self._prefetch_requested: set[int] = set()
|
| 365 |
+
self._prefetch_hits = 0
|
| 366 |
+
self._prefetch_misses = 0
|
| 367 |
+
|
| 368 |
+
self._all_shard_indices = list(range(len(self.shard_files)))
|
| 369 |
+
self._local_shard_indices: List[int] = []
|
| 370 |
+
self.current_local_shard_pos = -1
|
| 371 |
+
self.current_records: Optional[List[Dict[str, Any]]] = None
|
| 372 |
+
|
| 373 |
+
def _load_metadata(self) -> Dict[str, Any]:
|
| 374 |
+
metadata_path = os.path.join(self.cache_dir, "metadata.json")
|
| 375 |
+
if not os.path.exists(metadata_path):
|
| 376 |
+
return {}
|
| 377 |
+
with open(metadata_path, "r", encoding="utf-8") as handle:
|
| 378 |
+
return json.load(handle)
|
| 379 |
+
|
| 380 |
+
def _discover_shards(self) -> List[str]:
|
| 381 |
+
if not os.path.isdir(self.cache_dir):
|
| 382 |
+
raise FileNotFoundError(f"Cache directory not found: {self.cache_dir}")
|
| 383 |
+
shards: List[str] = []
|
| 384 |
+
for name in sorted(os.listdir(self.cache_dir)):
|
| 385 |
+
if not (name.startswith("shard_") and name.endswith(".pt")):
|
| 386 |
+
continue
|
| 387 |
+
shards.append(os.path.join(self.cache_dir, name))
|
| 388 |
+
if not shards:
|
| 389 |
+
raise ValueError(f"No cache shards found under {self.cache_dir}")
|
| 390 |
+
return shards
|
| 391 |
+
|
| 392 |
+
def _resolve_shard_sizes(self) -> List[int]:
|
| 393 |
+
num_shards = len(self.shard_files)
|
| 394 |
+
metadata_num_shards = self.metadata.get("num_shards")
|
| 395 |
+
metadata_num_records = self.metadata.get("num_records")
|
| 396 |
+
shard_size = self.metadata.get("shard_size")
|
| 397 |
+
|
| 398 |
+
if (
|
| 399 |
+
isinstance(metadata_num_shards, int)
|
| 400 |
+
and isinstance(metadata_num_records, int)
|
| 401 |
+
and isinstance(shard_size, int)
|
| 402 |
+
and metadata_num_shards == num_shards
|
| 403 |
+
and metadata_num_records > 0
|
| 404 |
+
and shard_size > 0
|
| 405 |
+
):
|
| 406 |
+
shard_sizes = [shard_size] * num_shards
|
| 407 |
+
shard_sizes[-1] = metadata_num_records - shard_size * max(num_shards - 1, 0)
|
| 408 |
+
return shard_sizes
|
| 409 |
+
|
| 410 |
+
shard_sizes: List[int] = []
|
| 411 |
+
for shard_path in self.shard_files:
|
| 412 |
+
payload = torch.load(shard_path, map_location="cpu", weights_only=False)
|
| 413 |
+
records = payload.get("records")
|
| 414 |
+
if not isinstance(records, list):
|
| 415 |
+
raise ValueError(f"Invalid shard format in {shard_path}")
|
| 416 |
+
shard_sizes.append(len(records))
|
| 417 |
+
return shard_sizes
|
| 418 |
+
|
| 419 |
+
@staticmethod
|
| 420 |
+
def _deserialize_item(raw: Optional[Dict[str, Any]]) -> Optional[PairItem]:
|
| 421 |
+
if raw is None:
|
| 422 |
+
return None
|
| 423 |
+
modality = raw["type"]
|
| 424 |
+
if modality == "text" and "tokens" in raw:
|
| 425 |
+
value = raw["tokens"]
|
| 426 |
+
elif modality == "text":
|
| 427 |
+
value = raw["value"]
|
| 428 |
+
elif "tensor" in raw:
|
| 429 |
+
value = raw["tensor"]
|
| 430 |
+
else:
|
| 431 |
+
value = raw.get("value")
|
| 432 |
+
return PairItem(modality=modality, value=value)
|
| 433 |
+
|
| 434 |
+
def _store_shard(self, shard_idx: int, records: List[Dict[str, Any]]) -> None:
|
| 435 |
+
with self._cache_lock:
|
| 436 |
+
self._shard_cache[shard_idx] = records
|
| 437 |
+
self._shard_cache.move_to_end(shard_idx)
|
| 438 |
+
while len(self._shard_cache) > self.shard_cache_limit:
|
| 439 |
+
self._shard_cache.popitem(last=False)
|
| 440 |
+
|
| 441 |
+
def _ensure_prefetch_thread(self) -> None:
|
| 442 |
+
if self.prefetch_shards <= 0:
|
| 443 |
+
return
|
| 444 |
+
if self._prefetch_thread is not None and self._prefetch_thread.is_alive():
|
| 445 |
+
return
|
| 446 |
+
self._prefetch_stop.clear()
|
| 447 |
+
self._prefetch_queue = queue.Queue(maxsize=max(self.prefetch_shards * 2, 1))
|
| 448 |
+
self._prefetch_thread = threading.Thread(
|
| 449 |
+
target=self._prefetch_worker,
|
| 450 |
+
daemon=True,
|
| 451 |
+
name=f"sequential-shard-prefetch-{os.getpid()}",
|
| 452 |
+
)
|
| 453 |
+
self._prefetch_thread.start()
|
| 454 |
+
|
| 455 |
+
def _prefetch_worker(self) -> None:
|
| 456 |
+
while not self._prefetch_stop.is_set():
|
| 457 |
+
try:
|
| 458 |
+
shard_idx = self._prefetch_queue.get(timeout=0.1)
|
| 459 |
+
except queue.Empty:
|
| 460 |
+
continue
|
| 461 |
+
if shard_idx is None:
|
| 462 |
+
continue
|
| 463 |
+
try:
|
| 464 |
+
with self._cache_lock:
|
| 465 |
+
if shard_idx in self._shard_cache:
|
| 466 |
+
self._prefetch_hits += 1
|
| 467 |
+
continue
|
| 468 |
+
payload = torch.load(self.shard_files[shard_idx], map_location="cpu", weights_only=False)
|
| 469 |
+
self._store_shard(shard_idx, payload["records"])
|
| 470 |
+
self._prefetch_hits += 1
|
| 471 |
+
finally:
|
| 472 |
+
with self._cache_lock:
|
| 473 |
+
self._prefetch_requested.discard(shard_idx)
|
| 474 |
+
|
| 475 |
+
def _stop_prefetch_thread(self) -> None:
|
| 476 |
+
self._prefetch_stop.set()
|
| 477 |
+
if self._prefetch_thread is not None and self._prefetch_thread.is_alive():
|
| 478 |
+
self._prefetch_thread.join(timeout=1.0)
|
| 479 |
+
self._prefetch_thread = None
|
| 480 |
+
self._prefetch_queue = None
|
| 481 |
+
|
| 482 |
+
def _schedule_prefetch_from_position(self, local_pos: int) -> None:
|
| 483 |
+
if self.prefetch_shards <= 0:
|
| 484 |
+
return
|
| 485 |
+
self._ensure_prefetch_thread()
|
| 486 |
+
if self._prefetch_queue is None:
|
| 487 |
+
return
|
| 488 |
+
for next_pos in range(local_pos + 1, min(len(self._local_shard_indices), local_pos + 1 + self.prefetch_shards)):
|
| 489 |
+
shard_idx = self._local_shard_indices[next_pos]
|
| 490 |
+
with self._cache_lock:
|
| 491 |
+
if shard_idx in self._shard_cache or shard_idx in self._prefetch_requested:
|
| 492 |
+
continue
|
| 493 |
+
self._prefetch_requested.add(shard_idx)
|
| 494 |
+
try:
|
| 495 |
+
self._prefetch_queue.put_nowait(shard_idx)
|
| 496 |
+
except queue.Full:
|
| 497 |
+
with self._cache_lock:
|
| 498 |
+
self._prefetch_requested.discard(shard_idx)
|
| 499 |
+
break
|
| 500 |
+
|
| 501 |
+
def _build_local_shard_order(self, epoch: int) -> List[int]:
|
| 502 |
+
shard_indices = list(self._all_shard_indices)
|
| 503 |
+
if self.shuffle:
|
| 504 |
+
random.Random(42 + epoch).shuffle(shard_indices)
|
| 505 |
+
local_shards = shard_indices[self.rank::self.world_size]
|
| 506 |
+
max_shards = math.ceil(len(shard_indices) / self.world_size)
|
| 507 |
+
if not local_shards:
|
| 508 |
+
raise ValueError(f"Rank {self.rank} received no shards from {self.cache_dir}")
|
| 509 |
+
while len(local_shards) < max_shards:
|
| 510 |
+
local_shards.append(local_shards[len(local_shards) % len(local_shards)])
|
| 511 |
+
return local_shards
|
| 512 |
+
|
| 513 |
+
def _load_records_for_shard(self, shard_idx: int) -> List[Dict[str, Any]]:
|
| 514 |
+
cached = None
|
| 515 |
+
with self._cache_lock:
|
| 516 |
+
cached = self._shard_cache.get(shard_idx)
|
| 517 |
+
if cached is not None:
|
| 518 |
+
self._shard_cache.move_to_end(shard_idx)
|
| 519 |
+
if cached is not None:
|
| 520 |
+
return cached
|
| 521 |
+
|
| 522 |
+
self._prefetch_misses += 1
|
| 523 |
+
payload = torch.load(self.shard_files[shard_idx], map_location="cpu", weights_only=False)
|
| 524 |
+
records = payload["records"]
|
| 525 |
+
self._store_shard(shard_idx, records)
|
| 526 |
+
with self._cache_lock:
|
| 527 |
+
self._prefetch_requested.discard(shard_idx)
|
| 528 |
+
return records
|
| 529 |
+
|
| 530 |
+
def _pad_records(self, records: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 531 |
+
if len(records) >= self.target_shard_size:
|
| 532 |
+
return records
|
| 533 |
+
repeat = math.ceil(self.target_shard_size / len(records))
|
| 534 |
+
return (records * repeat)[: self.target_shard_size]
|
| 535 |
+
|
| 536 |
+
def reset(self, epoch: int) -> bool:
|
| 537 |
+
self._stop_prefetch_thread()
|
| 538 |
+
self._local_shard_indices = self._build_local_shard_order(epoch)
|
| 539 |
+
self.current_local_shard_pos = -1
|
| 540 |
+
self.current_records = None
|
| 541 |
+
with self._cache_lock:
|
| 542 |
+
self._prefetch_requested.clear()
|
| 543 |
+
if self.prefetch_shards > 0:
|
| 544 |
+
self._ensure_prefetch_thread()
|
| 545 |
+
return self.next_shard()
|
| 546 |
+
|
| 547 |
+
def next_shard(self) -> bool:
|
| 548 |
+
self.current_local_shard_pos += 1
|
| 549 |
+
if self.current_local_shard_pos >= len(self._local_shard_indices):
|
| 550 |
+
self.current_records = None
|
| 551 |
+
return False
|
| 552 |
+
shard_idx = self._local_shard_indices[self.current_local_shard_pos]
|
| 553 |
+
records = self._load_records_for_shard(shard_idx)
|
| 554 |
+
self.current_records = self._pad_records(records)
|
| 555 |
+
self._schedule_prefetch_from_position(self.current_local_shard_pos)
|
| 556 |
+
return True
|
| 557 |
+
|
| 558 |
+
def __len__(self) -> int:
|
| 559 |
+
return len(self.current_records or [])
|
| 560 |
+
|
| 561 |
+
def __getitem__(self, idx: int) -> TrainRecord:
|
| 562 |
+
if self.current_records is None:
|
| 563 |
+
raise IndexError(idx)
|
| 564 |
+
raw = self.current_records[idx]
|
| 565 |
+
return TrainRecord(
|
| 566 |
+
query=self._deserialize_item(raw["query"]),
|
| 567 |
+
positive=self._deserialize_item(raw["positive"]),
|
| 568 |
+
negative=self._deserialize_item(raw.get("negative")),
|
| 569 |
+
)
|
| 570 |
+
|
| 571 |
+
def estimated_num_batches(self, batch_size: int, drop_last: bool) -> int:
|
| 572 |
+
shard_batches = self.target_shard_size // batch_size if drop_last else math.ceil(self.target_shard_size / batch_size)
|
| 573 |
+
return shard_batches * max(len(self._build_local_shard_order(0)), 1)
|
| 574 |
+
|
| 575 |
+
def get_prefetch_stats(self) -> Dict[str, int]:
|
| 576 |
+
with self._cache_lock:
|
| 577 |
+
return {
|
| 578 |
+
"cache_size": len(self._shard_cache),
|
| 579 |
+
"cache_limit": self.shard_cache_limit,
|
| 580 |
+
"prefetch_shards": self.prefetch_shards,
|
| 581 |
+
"prefetch_hits": self._prefetch_hits,
|
| 582 |
+
"prefetch_misses": self._prefetch_misses,
|
| 583 |
+
"prefetch_pending": len(self._prefetch_requested),
|
| 584 |
+
"local_shards": len(self._local_shard_indices),
|
| 585 |
+
"target_shard_size": self.target_shard_size,
|
| 586 |
+
}
|
| 587 |
+
|
| 588 |
+
def close(self) -> None:
|
| 589 |
+
self._stop_prefetch_thread()
|
| 590 |
+
|
| 591 |
+
def __del__(self):
|
| 592 |
+
self.close()
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
def _process_shard() -> tuple[int, int]:
|
| 596 |
+
rank = int(os.environ.get("ACCELERATE_PROCESS_INDEX") or os.environ.get("RANK") or 0)
|
| 597 |
+
world_size = int(os.environ.get("WORLD_SIZE") or os.environ.get("ACCELERATE_NUM_PROCESSES") or 1)
|
| 598 |
+
worker_info = torch.utils.data.get_worker_info()
|
| 599 |
+
if worker_info is None:
|
| 600 |
+
return rank, max(world_size, 1)
|
| 601 |
+
|
| 602 |
+
total_shards = max(world_size, 1) * worker_info.num_workers
|
| 603 |
+
shard_id = rank * worker_info.num_workers + worker_info.id
|
| 604 |
+
return shard_id, max(total_shards, 1)
|
| 605 |
+
|
| 606 |
+
|
| 607 |
+
def iter_sentence_transformers_rows(
|
| 608 |
+
manifest_path: str,
|
| 609 |
+
image_root: Optional[str],
|
| 610 |
+
audio_root: Optional[str],
|
| 611 |
+
allow_missing_negative: bool,
|
| 612 |
+
allowed_modalities: Optional[List[str]],
|
| 613 |
+
query_modalities: Optional[List[str]],
|
| 614 |
+
positive_modalities: Optional[List[str]],
|
| 615 |
+
negative_modalities: Optional[List[str]],
|
| 616 |
+
use_negative_column: bool,
|
| 617 |
+
):
|
| 618 |
+
allowed = set(allowed_modalities or [])
|
| 619 |
+
allowed_query = set(query_modalities or [])
|
| 620 |
+
allowed_positive = set(positive_modalities or [])
|
| 621 |
+
allowed_negative = set(negative_modalities or [])
|
| 622 |
+
shard_id, total_shards = _process_shard()
|
| 623 |
+
matched_index = 0
|
| 624 |
+
|
| 625 |
+
for record in iter_manifest_records(
|
| 626 |
+
manifest_path=manifest_path,
|
| 627 |
+
image_root=image_root,
|
| 628 |
+
audio_root=audio_root,
|
| 629 |
+
allow_missing_negative=allow_missing_negative,
|
| 630 |
+
):
|
| 631 |
+
if not record_matches_filters(
|
| 632 |
+
record,
|
| 633 |
+
allowed=allowed,
|
| 634 |
+
allowed_query=allowed_query,
|
| 635 |
+
allowed_positive=allowed_positive,
|
| 636 |
+
allowed_negative=allowed_negative,
|
| 637 |
+
):
|
| 638 |
+
continue
|
| 639 |
+
|
| 640 |
+
if matched_index % total_shards == shard_id:
|
| 641 |
+
yield record_to_sentence_transformers_row(record, include_negative=use_negative_column)
|
| 642 |
+
matched_index += 1
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
def collate_records(batch: List[TrainRecord]) -> Dict[str, List[PairItem]]:
|
| 646 |
+
return {
|
| 647 |
+
"query": [r.query for r in batch],
|
| 648 |
+
"positive": [r.positive for r in batch],
|
| 649 |
+
"negative": [r.negative for r in batch],
|
| 650 |
+
}
|
| 651 |
+
|
| 652 |
+
|
| 653 |
+
def sentence_transformers_input(item: PairItem) -> Any:
|
| 654 |
+
payload: Dict[str, Any] = {}
|
| 655 |
+
if item.modality == "text":
|
| 656 |
+
payload["text"] = item.value
|
| 657 |
+
return payload
|
| 658 |
+
if item.modality == "image":
|
| 659 |
+
payload["image"] = item.value
|
| 660 |
+
return payload
|
| 661 |
+
if item.modality == "audio":
|
| 662 |
+
payload["audio"] = item.value
|
| 663 |
+
return payload
|
| 664 |
+
return item.value
|
| 665 |
+
|
| 666 |
+
|
| 667 |
+
def resolve_media(item: PairItem, image_root: Optional[str], audio_root: Optional[str]) -> PairItem:
|
| 668 |
+
if item.modality == "image" and image_root and not os.path.isabs(item.value):
|
| 669 |
+
return PairItem(item.modality, os.path.join(image_root, item.value))
|
| 670 |
+
if item.modality == "audio" and audio_root and not os.path.isabs(item.value):
|
| 671 |
+
return PairItem(item.modality, os.path.join(audio_root, item.value))
|
| 672 |
+
return item
|
| 673 |
+
|
| 674 |
+
|
| 675 |
+
def iter_manifest_records(
|
| 676 |
+
manifest_path: str,
|
| 677 |
+
image_root: Optional[str] = None,
|
| 678 |
+
audio_root: Optional[str] = None,
|
| 679 |
+
allow_missing_negative: bool = True,
|
| 680 |
+
) -> Iterable[TrainRecord]:
|
| 681 |
+
if not os.path.exists(manifest_path):
|
| 682 |
+
raise FileNotFoundError(f"Manifest not found: {manifest_path}")
|
| 683 |
+
|
| 684 |
+
with open(manifest_path, "r", encoding="utf-8") as handle:
|
| 685 |
+
for line_no, line in enumerate(handle, start=1):
|
| 686 |
+
line = line.strip()
|
| 687 |
+
if not line:
|
| 688 |
+
continue
|
| 689 |
+
raw = json.loads(line)
|
| 690 |
+
record = parse_record(raw)
|
| 691 |
+
record = TrainRecord(
|
| 692 |
+
query=resolve_media(record.query, image_root, audio_root),
|
| 693 |
+
positive=resolve_media(record.positive, image_root, audio_root),
|
| 694 |
+
negative=resolve_media(record.negative, image_root, audio_root) if record.negative else None,
|
| 695 |
+
)
|
| 696 |
+
if record.negative is None and not allow_missing_negative:
|
| 697 |
+
raise ValueError(f"Missing negative at line {line_no}")
|
| 698 |
+
yield record
|
| 699 |
+
|
| 700 |
+
|
| 701 |
+
def record_matches_filters(
|
| 702 |
+
record: TrainRecord,
|
| 703 |
+
allowed: set[str],
|
| 704 |
+
allowed_query: set[str],
|
| 705 |
+
allowed_positive: set[str],
|
| 706 |
+
allowed_negative: set[str],
|
| 707 |
+
) -> bool:
|
| 708 |
+
record_modalities = {record.query.modality, record.positive.modality}
|
| 709 |
+
if record.negative is not None:
|
| 710 |
+
record_modalities.add(record.negative.modality)
|
| 711 |
+
if allowed and not record_modalities.issubset(allowed):
|
| 712 |
+
return False
|
| 713 |
+
if allowed_query and record.query.modality not in allowed_query:
|
| 714 |
+
return False
|
| 715 |
+
if allowed_positive and record.positive.modality not in allowed_positive:
|
| 716 |
+
return False
|
| 717 |
+
if record.negative is not None and allowed_negative and record.negative.modality not in allowed_negative:
|
| 718 |
+
return False
|
| 719 |
+
return True
|
| 720 |
+
|
| 721 |
+
|
| 722 |
+
def record_to_sentence_transformers_row(record: TrainRecord, include_negative: bool) -> Dict[str, Any]:
|
| 723 |
+
row = {
|
| 724 |
+
"query": sentence_transformers_input(record.query),
|
| 725 |
+
"positive": sentence_transformers_input(record.positive),
|
| 726 |
+
}
|
| 727 |
+
if include_negative and record.negative is not None:
|
| 728 |
+
row["negative_0"] = sentence_transformers_input(record.negative)
|
| 729 |
+
return row
|
| 730 |
+
|
| 731 |
+
|
| 732 |
+
def summarize_manifest_records(
|
| 733 |
+
manifest_path: str,
|
| 734 |
+
image_root: Optional[str] = None,
|
| 735 |
+
audio_root: Optional[str] = None,
|
| 736 |
+
allow_missing_negative: bool = True,
|
| 737 |
+
allowed_modalities: Optional[List[str]] = None,
|
| 738 |
+
query_modalities: Optional[List[str]] = None,
|
| 739 |
+
positive_modalities: Optional[List[str]] = None,
|
| 740 |
+
negative_modalities: Optional[List[str]] = None,
|
| 741 |
+
max_records: Optional[int] = None,
|
| 742 |
+
) -> Dict[str, Any]:
|
| 743 |
+
modalities = set()
|
| 744 |
+
negatives_present = 0
|
| 745 |
+
negatives_missing = 0
|
| 746 |
+
skipped_rows = 0
|
| 747 |
+
num_rows = 0
|
| 748 |
+
allowed = set(allowed_modalities or [])
|
| 749 |
+
allowed_query = set(query_modalities or [])
|
| 750 |
+
allowed_positive = set(positive_modalities or [])
|
| 751 |
+
allowed_negative = set(negative_modalities or [])
|
| 752 |
+
|
| 753 |
+
for record in iter_manifest_records(
|
| 754 |
+
manifest_path=manifest_path,
|
| 755 |
+
image_root=image_root,
|
| 756 |
+
audio_root=audio_root,
|
| 757 |
+
allow_missing_negative=allow_missing_negative,
|
| 758 |
+
):
|
| 759 |
+
if not record_matches_filters(
|
| 760 |
+
record,
|
| 761 |
+
allowed=allowed,
|
| 762 |
+
allowed_query=allowed_query,
|
| 763 |
+
allowed_positive=allowed_positive,
|
| 764 |
+
allowed_negative=allowed_negative,
|
| 765 |
+
):
|
| 766 |
+
skipped_rows += 1
|
| 767 |
+
continue
|
| 768 |
+
|
| 769 |
+
modalities.add(record.query.modality)
|
| 770 |
+
modalities.add(record.positive.modality)
|
| 771 |
+
if record.negative is not None:
|
| 772 |
+
modalities.add(record.negative.modality)
|
| 773 |
+
negatives_present += 1
|
| 774 |
+
else:
|
| 775 |
+
negatives_missing += 1
|
| 776 |
+
num_rows += 1
|
| 777 |
+
if max_records is not None and num_rows >= max_records:
|
| 778 |
+
break
|
| 779 |
+
|
| 780 |
+
if num_rows == 0:
|
| 781 |
+
raise ValueError(f"No records loaded from {manifest_path}")
|
| 782 |
+
|
| 783 |
+
return {
|
| 784 |
+
"modalities": sorted(modalities),
|
| 785 |
+
"num_rows": num_rows,
|
| 786 |
+
"has_uniform_negatives": negatives_present > 0 and negatives_missing == 0,
|
| 787 |
+
"num_negatives_present": negatives_present,
|
| 788 |
+
"num_negatives_missing": negatives_missing,
|
| 789 |
+
"skipped_rows": skipped_rows,
|
| 790 |
+
}
|
| 791 |
+
|
| 792 |
+
|
| 793 |
+
def manifest_to_sentence_transformers_dataset(
|
| 794 |
+
manifest_path: str,
|
| 795 |
+
image_root: Optional[str] = None,
|
| 796 |
+
audio_root: Optional[str] = None,
|
| 797 |
+
allow_missing_negative: bool = True,
|
| 798 |
+
allowed_modalities: Optional[List[str]] = None,
|
| 799 |
+
query_modalities: Optional[List[str]] = None,
|
| 800 |
+
positive_modalities: Optional[List[str]] = None,
|
| 801 |
+
negative_modalities: Optional[List[str]] = None,
|
| 802 |
+
as_iterable: bool = False,
|
| 803 |
+
max_records: Optional[int] = None,
|
| 804 |
+
) -> tuple[Dataset | IterableDataset, Dict[str, Any]]:
|
| 805 |
+
info = summarize_manifest_records(
|
| 806 |
+
manifest_path=manifest_path,
|
| 807 |
+
image_root=image_root,
|
| 808 |
+
audio_root=audio_root,
|
| 809 |
+
allow_missing_negative=allow_missing_negative,
|
| 810 |
+
allowed_modalities=allowed_modalities,
|
| 811 |
+
query_modalities=query_modalities,
|
| 812 |
+
positive_modalities=positive_modalities,
|
| 813 |
+
negative_modalities=negative_modalities,
|
| 814 |
+
max_records=max_records,
|
| 815 |
+
)
|
| 816 |
+
|
| 817 |
+
dataset_out: Dataset | IterableDataset
|
| 818 |
+
if as_iterable:
|
| 819 |
+
column_names = ["query", "positive"]
|
| 820 |
+
if info["has_uniform_negatives"]:
|
| 821 |
+
column_names.append("negative_0")
|
| 822 |
+
dataset_out = IterableDataset.from_generator(
|
| 823 |
+
iter_sentence_transformers_rows,
|
| 824 |
+
features=Features({key: Value("null") for key in column_names}),
|
| 825 |
+
gen_kwargs={
|
| 826 |
+
"manifest_path": manifest_path,
|
| 827 |
+
"image_root": image_root,
|
| 828 |
+
"audio_root": audio_root,
|
| 829 |
+
"allow_missing_negative": allow_missing_negative,
|
| 830 |
+
"allowed_modalities": allowed_modalities,
|
| 831 |
+
"query_modalities": query_modalities,
|
| 832 |
+
"positive_modalities": positive_modalities,
|
| 833 |
+
"negative_modalities": negative_modalities,
|
| 834 |
+
"use_negative_column": info["has_uniform_negatives"],
|
| 835 |
+
},
|
| 836 |
+
)
|
| 837 |
+
else:
|
| 838 |
+
dataset = JsonlManifestDataset(
|
| 839 |
+
manifest_path=manifest_path,
|
| 840 |
+
image_root=image_root,
|
| 841 |
+
audio_root=audio_root,
|
| 842 |
+
allow_missing_negative=allow_missing_negative,
|
| 843 |
+
)
|
| 844 |
+
allowed = set(allowed_modalities or [])
|
| 845 |
+
allowed_query = set(query_modalities or [])
|
| 846 |
+
allowed_positive = set(positive_modalities or [])
|
| 847 |
+
allowed_negative = set(negative_modalities or [])
|
| 848 |
+
rows: List[Dict[str, Any]] = []
|
| 849 |
+
for record in dataset.records:
|
| 850 |
+
if not record_matches_filters(
|
| 851 |
+
record,
|
| 852 |
+
allowed=allowed,
|
| 853 |
+
allowed_query=allowed_query,
|
| 854 |
+
allowed_positive=allowed_positive,
|
| 855 |
+
allowed_negative=allowed_negative,
|
| 856 |
+
):
|
| 857 |
+
continue
|
| 858 |
+
rows.append(record_to_sentence_transformers_row(record, include_negative=info["has_uniform_negatives"]))
|
| 859 |
+
if max_records is not None and len(rows) >= max_records:
|
| 860 |
+
break
|
| 861 |
+
dataset_out = Dataset.from_list(rows)
|
| 862 |
+
|
| 863 |
+
return dataset_out, info
|
src/hf_st_mm/model.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
from typing import Any, Dict, List
|
| 3 |
+
|
| 4 |
+
import librosa
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from sentence_transformers import SentenceTransformer
|
| 11 |
+
from transformers import AutoModel, AutoProcessor, WhisperFeatureExtractor, WhisperModel
|
| 12 |
+
|
| 13 |
+
from .data import PairItem
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class MultiModalSentenceEmbedder(nn.Module):
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
text_encoder_name: str,
|
| 20 |
+
image_encoder_name: str,
|
| 21 |
+
audio_encoder_name: str,
|
| 22 |
+
embedding_dim: int,
|
| 23 |
+
max_text_length: int,
|
| 24 |
+
) -> None:
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.text_model = SentenceTransformer(text_encoder_name)
|
| 27 |
+
self.text_model.max_seq_length = max_text_length
|
| 28 |
+
|
| 29 |
+
self.image_model = AutoModel.from_pretrained(image_encoder_name, trust_remote_code=True)
|
| 30 |
+
self.image_processor = AutoProcessor.from_pretrained(image_encoder_name, trust_remote_code=True)
|
| 31 |
+
|
| 32 |
+
whisper = WhisperModel.from_pretrained(audio_encoder_name)
|
| 33 |
+
self.audio_model = whisper.encoder
|
| 34 |
+
self.audio_processor = WhisperFeatureExtractor.from_pretrained(audio_encoder_name)
|
| 35 |
+
|
| 36 |
+
text_dim = self.text_model.get_sentence_embedding_dimension()
|
| 37 |
+
image_dim = self._get_vision_dim(self.image_model)
|
| 38 |
+
audio_dim = whisper.config.d_model
|
| 39 |
+
|
| 40 |
+
self.text_proj = nn.Linear(text_dim, embedding_dim) if text_dim != embedding_dim else nn.Identity()
|
| 41 |
+
self.image_proj = nn.Linear(image_dim, embedding_dim) if image_dim != embedding_dim else nn.Identity()
|
| 42 |
+
self.audio_proj = nn.Linear(audio_dim, embedding_dim) if audio_dim != embedding_dim else nn.Identity()
|
| 43 |
+
|
| 44 |
+
@staticmethod
|
| 45 |
+
def _get_vision_dim(model: nn.Module) -> int:
|
| 46 |
+
if hasattr(model, "vision_model") and hasattr(model.config, "vision_config"):
|
| 47 |
+
return int(model.config.vision_config.hidden_size)
|
| 48 |
+
if hasattr(model.config, "hidden_size"):
|
| 49 |
+
return int(model.config.hidden_size)
|
| 50 |
+
raise ValueError("Could not infer image hidden size")
|
| 51 |
+
|
| 52 |
+
def _encode_text(self, texts: List[Any]) -> torch.Tensor:
|
| 53 |
+
device = next(self.parameters()).device
|
| 54 |
+
normalized: List[torch.Tensor | None] = [None] * len(texts)
|
| 55 |
+
|
| 56 |
+
dict_positions = [idx for idx, item in enumerate(texts) if isinstance(item, dict)]
|
| 57 |
+
if dict_positions:
|
| 58 |
+
pad_values = {
|
| 59 |
+
"input_ids": 0,
|
| 60 |
+
"attention_mask": 0,
|
| 61 |
+
"token_type_ids": 0,
|
| 62 |
+
}
|
| 63 |
+
dict_items = [texts[idx] for idx in dict_positions]
|
| 64 |
+
features = {
|
| 65 |
+
key: pad_sequence(
|
| 66 |
+
[item[key].detach().cpu() for item in dict_items],
|
| 67 |
+
batch_first=True,
|
| 68 |
+
padding_value=pad_values.get(key, 0),
|
| 69 |
+
).to(device)
|
| 70 |
+
for key in dict_items[0].keys()
|
| 71 |
+
}
|
| 72 |
+
out = self.text_model(features)
|
| 73 |
+
emb = F.normalize(self.text_proj(out["sentence_embedding"]), p=2, dim=-1)
|
| 74 |
+
for loc, row in zip(dict_positions, emb):
|
| 75 |
+
normalized[loc] = row
|
| 76 |
+
|
| 77 |
+
raw_positions = [idx for idx, item in enumerate(texts) if not isinstance(item, dict)]
|
| 78 |
+
if raw_positions:
|
| 79 |
+
raw_texts = [texts[idx] for idx in raw_positions]
|
| 80 |
+
features = self.text_model.tokenize(raw_texts)
|
| 81 |
+
features = {
|
| 82 |
+
k: (v.to(device) if hasattr(v, "to") else v)
|
| 83 |
+
for k, v in features.items()
|
| 84 |
+
}
|
| 85 |
+
out = self.text_model(features)
|
| 86 |
+
emb = F.normalize(self.text_proj(out["sentence_embedding"]), p=2, dim=-1)
|
| 87 |
+
for loc, row in zip(raw_positions, emb):
|
| 88 |
+
normalized[loc] = row
|
| 89 |
+
|
| 90 |
+
return torch.stack([row for row in normalized if row is not None], dim=0)
|
| 91 |
+
|
| 92 |
+
def _encode_image_paths(self, paths: List[str]) -> torch.Tensor:
|
| 93 |
+
images = [Image.open(path).convert("RGB") for path in paths]
|
| 94 |
+
proc = self.image_processor(images=images, return_tensors="pt")
|
| 95 |
+
device = next(self.parameters()).device
|
| 96 |
+
proc = {k: v.to(device) for k, v in proc.items()}
|
| 97 |
+
return self._encode_image_pixel_values(proc["pixel_values"])
|
| 98 |
+
|
| 99 |
+
def _encode_image_pixel_values(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
| 100 |
+
device = next(self.parameters()).device
|
| 101 |
+
proc = {"pixel_values": pixel_values.to(device)}
|
| 102 |
+
if hasattr(self.image_model, "vision_model"):
|
| 103 |
+
out = self.image_model.vision_model(**proc, output_hidden_states=False)
|
| 104 |
+
hidden = out.last_hidden_state
|
| 105 |
+
else:
|
| 106 |
+
out = self.image_model(**proc, output_hidden_states=False)
|
| 107 |
+
hidden = out.last_hidden_state
|
| 108 |
+
pooled = hidden[:, 1:].mean(dim=1) if hidden.shape[1] > 1 else hidden.mean(dim=1)
|
| 109 |
+
emb = self.image_proj(pooled)
|
| 110 |
+
return F.normalize(emb, p=2, dim=-1)
|
| 111 |
+
|
| 112 |
+
def _encode_audio_paths(self, paths: List[str]) -> torch.Tensor:
|
| 113 |
+
waves = [librosa.load(path, sr=16000, mono=True)[0] for path in paths]
|
| 114 |
+
proc = self.audio_processor(waves, sampling_rate=16000, return_tensors="pt")
|
| 115 |
+
return self._encode_audio_features(proc["input_features"])
|
| 116 |
+
|
| 117 |
+
def _encode_audio_features(self, input_features: torch.Tensor) -> torch.Tensor:
|
| 118 |
+
device = next(self.parameters()).device
|
| 119 |
+
input_features = input_features.to(device)
|
| 120 |
+
input_features = input_features.to(self.audio_model.conv1.weight.dtype)
|
| 121 |
+
out = self.audio_model(input_features=input_features, output_hidden_states=False)
|
| 122 |
+
pooled = out.last_hidden_state.mean(dim=1)
|
| 123 |
+
emb = self.audio_proj(pooled)
|
| 124 |
+
return F.normalize(emb, p=2, dim=-1)
|
| 125 |
+
|
| 126 |
+
@staticmethod
|
| 127 |
+
def _stack_tensor_values(values: List[Any]) -> torch.Tensor:
|
| 128 |
+
tensors = []
|
| 129 |
+
for value in values:
|
| 130 |
+
if not torch.is_tensor(value):
|
| 131 |
+
raise TypeError("Expected tensor payload in cached item")
|
| 132 |
+
tensor = value.detach().cpu()
|
| 133 |
+
if tensor.dim() > 0 and tensor.shape[0] == 1:
|
| 134 |
+
tensor = tensor.squeeze(0)
|
| 135 |
+
tensors.append(tensor)
|
| 136 |
+
return torch.stack(tensors, dim=0)
|
| 137 |
+
|
| 138 |
+
def encode_items(self, items: List[PairItem]) -> torch.Tensor:
|
| 139 |
+
grouped = defaultdict(list)
|
| 140 |
+
for idx, item in enumerate(items):
|
| 141 |
+
grouped[item.modality].append((idx, item.value))
|
| 142 |
+
|
| 143 |
+
device = next(self.parameters()).device
|
| 144 |
+
out = [None] * len(items)
|
| 145 |
+
|
| 146 |
+
if grouped["text"]:
|
| 147 |
+
idxs, vals = zip(*grouped["text"])
|
| 148 |
+
embs = self._encode_text(list(vals))
|
| 149 |
+
for loc, emb in zip(idxs, embs):
|
| 150 |
+
out[loc] = emb
|
| 151 |
+
|
| 152 |
+
if grouped["image"]:
|
| 153 |
+
idxs, vals = zip(*grouped["image"])
|
| 154 |
+
tensor_pairs = [(idx, val) for idx, val in zip(idxs, vals) if torch.is_tensor(val)]
|
| 155 |
+
path_pairs = [(idx, val) for idx, val in zip(idxs, vals) if not torch.is_tensor(val)]
|
| 156 |
+
if path_pairs:
|
| 157 |
+
p_idxs, p_vals = zip(*path_pairs)
|
| 158 |
+
embs = self._encode_image_paths(list(p_vals))
|
| 159 |
+
for loc, emb in zip(p_idxs, embs):
|
| 160 |
+
out[loc] = emb
|
| 161 |
+
if tensor_pairs:
|
| 162 |
+
t_idxs, t_vals = zip(*tensor_pairs)
|
| 163 |
+
embs = self._encode_image_pixel_values(self._stack_tensor_values(list(t_vals)))
|
| 164 |
+
for loc, emb in zip(t_idxs, embs):
|
| 165 |
+
out[loc] = emb
|
| 166 |
+
|
| 167 |
+
if grouped["audio"]:
|
| 168 |
+
idxs, vals = zip(*grouped["audio"])
|
| 169 |
+
tensor_pairs = [(idx, val) for idx, val in zip(idxs, vals) if torch.is_tensor(val)]
|
| 170 |
+
path_pairs = [(idx, val) for idx, val in zip(idxs, vals) if not torch.is_tensor(val)]
|
| 171 |
+
if path_pairs:
|
| 172 |
+
p_idxs, p_vals = zip(*path_pairs)
|
| 173 |
+
embs = self._encode_audio_paths(list(p_vals))
|
| 174 |
+
for loc, emb in zip(p_idxs, embs):
|
| 175 |
+
out[loc] = emb
|
| 176 |
+
if tensor_pairs:
|
| 177 |
+
t_idxs, t_vals = zip(*tensor_pairs)
|
| 178 |
+
embs = self._encode_audio_features(self._stack_tensor_values(list(t_vals)))
|
| 179 |
+
for loc, emb in zip(t_idxs, embs):
|
| 180 |
+
out[loc] = emb
|
| 181 |
+
|
| 182 |
+
stacked = torch.stack(out, dim=0).to(device=device, dtype=torch.float32)
|
| 183 |
+
return F.normalize(stacked, p=2, dim=-1)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def multiple_negatives_ranking_loss(anchor: torch.Tensor, positive: torch.Tensor, scale: float = 20.0) -> torch.Tensor:
|
| 187 |
+
scores = torch.matmul(anchor, positive.T) * scale
|
| 188 |
+
labels = torch.arange(scores.shape[0], device=scores.device)
|
| 189 |
+
loss_a = torch.nn.functional.cross_entropy(scores, labels)
|
| 190 |
+
loss_b = torch.nn.functional.cross_entropy(scores.T, labels)
|
| 191 |
+
return (loss_a + loss_b) * 0.5
|