Commit 17bbccd · Xunzhuo committed · verified · 1 Parent(s): 954f2c4

Mirror agentic-intelligence-lab/elephant-embeddings-v1-multimodal-large from ModelScope

Mirrored from https://modelscope.cn/models/agentic-intelligence-lab/elephant-embeddings-v1-multimodal-large/summary

.msc ADDED
Binary file (821 Bytes).

.mv ADDED
@@ -0,0 +1 @@
+ Revision:master,CreatedAt:1778776244
README.md ADDED
@@ -0,0 +1,191 @@
+ ---
+ license: apache-2.0
+ library_name: pytorch
+ pipeline_tag: sentence-similarity
+ language:
+ - multilingual
+ tags:
+ - agentic-intelligence-lab
+ - elephant
+ - embeddings
+ - multimodal
+ - retrieval
+ - rag
+ - agents
+ - routing
+ - image-text
+ - audio-text
+ - tri-encoder
+ - pytorch
+ model-index:
+ - name: elephant-embeddings-v1-multimodal-large
+   results:
+   - task:
+       type: sentence-similarity
+     dataset:
+       name: Internal cached validation set
+       type: cached_retrieval_validation
+     metrics:
+     - name: Eval loss
+       type: eval_loss
+       value: 0.389702
+     - name: Eval top1
+       type: eval_top1
+       value: 0.861707
+ ---
+
+ # Elephant Embeddings V1 Multimodal Large
+
+ `elephant-embeddings-v1-multimodal-large` is the large multimodal embedding model in the **Agentic Intelligence Lab Elephant Embeddings V1** family.
+
+ This ModelScope release is maintained by `agentic-intelligence-lab` to make Elephant embedding models easier to download and deploy in mainland China. It mirrors and renames the upstream Hugging Face model `llm-semantic-router/multi-modal-embed-large` under a consistent Elephant model namespace.
+
+ ## Positioning
+
+ This model is a production-oriented multimodal embedding model for semantic routing, retrieval, and cross-modal matching across text, image, and audio.
+
+ It is not a generative chat or captioning model. Instead, it maps different modalities into one shared embedding space so agent systems can compare requests, screenshots, documents, and audio records with the same retrieval interface.
+
+ ## Model at a glance
+
+ | Item | Value |
+ | --- | --- |
+ | Family | Elephant Embeddings V1 |
+ | Maintainer | Agentic Intelligence Lab |
+ | Model type | Multimodal embedding model |
+ | Modalities | Text, image, audio |
+ | Architecture | Custom PyTorch tri-encoder |
+ | Text encoder | `llm-semantic-router/mmbert-embed-32k-2d-matryoshka` |
+ | Image encoder | `google/siglip2-so400m-patch14-384` |
+ | Audio encoder | `openai/whisper-medium` |
+ | Embedding dimension | 768 |
+ | Max text length | 32,768 tokens |
+ | Objective | Cached multiple negatives ranking loss |
+ | Upstream source | `llm-semantic-router/multi-modal-embed-large` |
+ | License | Apache 2.0 |
+
+ ## Why it fits agentic workloads
+
+ Agentic products increasingly need to retrieve and route over mixed inputs: user text, screenshots, UI states, documents, voice notes, support calls, and multimodal memory. This model is designed for that operating pattern.
+
+ Key advantages:
+
+ - **Shared semantic space**: text, images, and audio can be compared with cosine similarity.
+ - **Routing-grade representation**: optimized for retrieval, matching, and routing rather than generation.
+ - **Strong modality towers**: uses dedicated text, image, and audio encoders instead of forcing all modalities through a single monolithic checkpoint.
+ - **Long-context text path**: supports long tool descriptions, traces, and knowledge chunks through the text encoder.
+ - **Production packaging**: includes the custom source package needed to construct and run the tri-encoder.
+
+ ## Recommended use cases
+
+ | Scenario | Example |
+ | --- | --- |
+ | Multimodal RAG | Retrieve text notes using an image or audio query |
+ | Agent routing | Route screenshots, user text, or voice requests to the right tool or workflow |
+ | Memory search | Search mixed text/image/audio memory stores in one vector space |
+ | Support and operations | Match tickets, screenshots, logs, and recorded calls semantically |
+ | Offline indexing | Build high-quality 768d multimodal indexes |
+
+ ## Quick start on ModelScope
+
+ ```bash
+ pip install modelscope torch sentence-transformers transformers accelerate safetensors pillow librosa soundfile
+ ```
+
+ ```python
+ import json
+ import os
+ import sys
+
+ import torch
+ import torch.nn.functional as F
+ from modelscope import snapshot_download
+
+ repo_id = "agentic-intelligence-lab/elephant-embeddings-v1-multimodal-large"
+ local_dir = snapshot_download(repo_id)
+
+ sys.path.insert(0, os.path.join(local_dir, "src"))
+
+ from hf_st_mm.data import PairItem
+ from hf_st_mm.model import MultiModalSentenceEmbedder
+
+ with open(os.path.join(local_dir, "config.json"), "r", encoding="utf-8") as handle:
+     cfg = json.load(handle)
+
+ model = MultiModalSentenceEmbedder(
+     text_encoder_name=cfg["model"]["text_encoder_name"],
+     image_encoder_name=cfg["model"]["image_encoder_name"],
+     audio_encoder_name=cfg["model"]["audio_encoder_name"],
+     embedding_dim=int(cfg["model"]["embedding_dim"]),
+     max_text_length=int(cfg["model"]["max_text_length"]),
+ )
+
+ state_dict = torch.load(os.path.join(local_dir, "model.pt"), map_location="cpu")
+ model.load_state_dict(state_dict)
+ model.eval()
+
+ items = [
+     PairItem(modality="text", value="route this request to the billing workflow"),
+     PairItem(modality="image", value="/path/to/screenshot.png"),
+     PairItem(modality="audio", value="/path/to/call.wav"),
+ ]
+
+ with torch.no_grad():
+     embeddings = model.encode_items(items)
+
+ print(embeddings.shape)  # [3, 768]
+
+ query = PairItem(modality="text", value="refund request for a wrong charge")
+ candidate = PairItem(modality="audio", value="/path/to/refund_call.wav")
+
+ with torch.no_grad():
+     embs = model.encode_items([query, candidate])
+
+ similarity = F.cosine_similarity(embs[0:1], embs[1:2]).item()
+ print(f"similarity={similarity:.4f}")
+ ```
+
+ ## Evaluation snapshot
+
+ | Metric | Value |
+ | --- | ---: |
+ | Eval loss | 0.389702 |
+ | Eval top1 | 0.861707 |
+
+ The validation metrics come from the tri-encoder cached retrieval validation path used during export. They are intended as a release sanity snapshot rather than a public leaderboard claim.
+
+ ## Files
+
+ | File | Description |
+ | --- | --- |
+ | `model.pt` | Exported PyTorch weights |
+ | `config.json` | Tri-encoder and training/export configuration |
+ | `src/hf_st_mm/` | Python package used to construct and run the model |
+ | `README.md` | This model card |
+
+ ## Lineage
+
+ This ModelScope package is published by `agentic-intelligence-lab` as part of the Elephant model release line. It mirrors the upstream Hugging Face model `llm-semantic-router/multi-modal-embed-large` and keeps the model artifacts unchanged except for the repository naming and model card presentation.
+
+ ## Limitations
+
+ - This is a custom PyTorch tri-encoder export, not a standard Transformers auto-class checkpoint.
+ - Inference relies on the packaged `hf_st_mm` source code.
+ - Image and audio inputs are expected as local file paths in the simple inference path.
+ - The model is optimized for retrieval, routing, and similarity, not generation or captioning.
+ - Reported validation metrics come from an internal cached retrieval validation set.
+
+ ## Citation
+
+ ```bibtex
+ @misc{elephant-embeddings-v1-multimodal-large,
+   title={Elephant Embeddings V1 Multimodal Large},
+   author={Agentic Intelligence Lab},
+   year={2026},
+   url={https://modelscope.cn/models/agentic-intelligence-lab/elephant-embeddings-v1-multimodal-large}
+ }
+ ```
+
+ ## License
+
+ Apache 2.0
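
Reviewer note: the model card says text, image, and audio embeddings live in one space and are compared with cosine similarity. The following is a minimal sketch of ranking candidates against a query once embeddings exist. The 3-dimensional vectors, the file names, and the helper `rank_by_cosine` are hypothetical stand-ins; real embeddings from this model are 768-dimensional according to the card.

```python
import math
from typing import Dict, List, Sequence, Tuple


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def rank_by_cosine(
    query: Sequence[float], candidates: Dict[str, Sequence[float]]
) -> List[Tuple[str, float]]:
    """Return (name, score) pairs sorted from most to least similar."""
    scored = [(name, cosine(query, vec)) for name, vec in candidates.items()]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)


# Toy vectors standing in for embeddings of a text query and mixed-modality items.
query = [1.0, 0.0, 0.0]
candidates = {
    "refund_call.wav": [0.9, 0.1, 0.0],
    "screenshot.png": [0.0, 1.0, 0.0],
    "billing_note.txt": [0.5, 0.5, 0.0],
}

ranking = rank_by_cosine(query, candidates)
for name, score in ranking:
    print(f"{name}: {score:.4f}")
```

Because every modality is scored with the same function, a router can keep one index for all item types and apply a single similarity threshold, which would need to be tuned on real data.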
config.json ADDED
@@ -0,0 +1,44 @@
+ {
+   "seed": 42,
+   "output_dir": "/scratch/hf_st_mm_outputs/server_datacenter_8gpu_tri_encoder",
+   "model": {
+     "text_encoder_name": "llm-semantic-router/mmbert-embed-32k-2d-matryoshka",
+     "image_encoder_name": "google/siglip2-so400m-patch14-384",
+     "audio_encoder_name": "openai/whisper-medium",
+     "embedding_dim": 768,
+     "max_text_length": 32768
+   },
+   "training": {
+     "epochs": 10,
+     "batch_size": 12,
+     "grad_accum_steps": 8,
+     "num_workers": 4,
+     "prefetch_factor": 4,
+     "shard_prefetch": 2,
+     "shard_cache_limit": 4,
+     "sequential_shard_loading": true,
+     "shuffle": false,
+     "modality_homogeneous_batches": false,
+     "learning_rate": 1e-05,
+     "weight_decay": 0.01,
+     "warmup_ratio": 0.1,
+     "max_grad_norm": 1.0,
+     "mixed_precision": "bf16",
+     "log_every": 10,
+     "save_every": 2000,
+     "hard_negative_ratio": 0.5
+   },
+   "loss": {
+     "type": "cached_mnrl",
+     "scale": 20.0
+   },
+   "data": {
+     "cache_dir": "/scratch/2dmse-data/server_full_datacenter_cache/train"
+   },
+   "validation": {
+     "cache_dir": "/scratch/2dmse-data/server_full_datacenter_cache/val",
+     "num_workers": 2,
+     "shard_prefetch": 1,
+     "shard_cache_limit": 2
+   }
+ }
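
Reviewer note: the config names the loss `cached_mnrl` with `scale` 20.0, and the model card describes it as a cached multiple negatives ranking loss. The loss code itself is not part of the files shown here, so the following is only a sketch of the common in-batch formulation, assuming cosine similarity as the score: each query's logits are `scale * cos(query, positive_j)` over all positives in the batch, and the target is the query's own positive. Gradient caching and explicit hard negatives are left out, and `mnrl_loss` is a hypothetical helper name.

```python
import math
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def mnrl_loss(
    queries: List[Sequence[float]],
    positives: List[Sequence[float]],
    scale: float = 20.0,
) -> float:
    """Mean cross-entropy where positives[i] is the correct match for queries[i]."""
    losses = []
    for i, query in enumerate(queries):
        logits = [scale * cosine(query, positive) for positive in positives]
        # Numerically stable log-sum-exp over the in-batch candidates.
        max_logit = max(logits)
        log_sum_exp = max_logit + math.log(sum(math.exp(l - max_logit) for l in logits))
        losses.append(log_sum_exp - logits[i])
    return sum(losses) / len(losses)


queries = [[1.0, 0.0], [0.0, 1.0]]
aligned = mnrl_loss(queries, [[1.0, 0.0], [0.0, 1.0]])
swapped = mnrl_loss(queries, [[0.0, 1.0], [1.0, 0.0]])
print(f"aligned={aligned:.6f} swapped={swapped:.6f}")
```

Cosine scores are bounded in [-1, 1], so the scale acts like an inverse temperature: with a scale of 1.0 even perfectly aligned pairs leave a sizeable loss, while 20.0 lets the softmax become sharp.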
configuration.json ADDED
File without changes
model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f5fe61d4864fffb703f53860234a657a2f51f71e393e2dc1b7f635b284cb48c4
+ size 6393990436
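
Reviewer note: `model.pt` is stored as a Git LFS pointer, which records the SHA-256 digest and byte size of the real file (about 6.4 GB here). The following sketch shows how a downloaded file could be checked against such a pointer. `parse_lfs_pointer` and `matches_pointer` are hypothetical helper names, not part of this repository.

```python
import hashlib
import os


def parse_lfs_pointer(text: str) -> dict:
    """Parse the 'key value' lines of a Git LFS pointer file."""
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    algo, digest = fields["oid"].split(":", 1)
    return {"algo": algo, "digest": digest, "size": int(fields["size"])}


def matches_pointer(path: str, pointer: dict, chunk_size: int = 1 << 20) -> bool:
    """Check a local file's size and digest against a parsed pointer."""
    if os.path.getsize(path) != pointer["size"]:
        return False
    hasher = hashlib.new(pointer["algo"])
    with open(path, "rb") as handle:
        # Read in chunks so multi-gigabyte files do not need to fit in memory.
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            hasher.update(chunk)
    return hasher.hexdigest() == pointer["digest"]


pointer = parse_lfs_pointer(
    "version https://git-lfs.github.com/spec/v1\n"
    "oid sha256:f5fe61d4864fffb703f53860234a657a2f51f71e393e2dc1b7f635b284cb48c4\n"
    "size 6393990436\n"
)
print(pointer["algo"], pointer["size"])
```

After a download, `matches_pointer("/path/to/model.pt", pointer)` would return `True` only for a complete, uncorrupted file; the path is a placeholder.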
src/hf_st_mm/__init__.py ADDED
@@ -0,0 +1 @@
+ """Standalone HF Sentence-Transformers multimodal training package."""
src/hf_st_mm/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (200 Bytes).

src/hf_st_mm/__pycache__/data.cpython-312.pyc ADDED
Binary file (45.7 kB).

src/hf_st_mm/__pycache__/model.cpython-312.pyc ADDED
Binary file (14.7 kB).

src/hf_st_mm/data.py ADDED
@@ -0,0 +1,863 @@
+ import json
+ import math
+ import os
+ import queue
+ import random
+ import threading
+ from bisect import bisect_right
+ from collections import OrderedDict
+ from dataclasses import dataclass
+ from typing import Any, Dict, Iterable, List, Optional
+
+ import torch
+ from datasets import Dataset, Features, IterableDataset, Value
+
+
+ SUPPORTED_MODALITIES = {"text", "image", "audio"}
+
+
+ @dataclass
+ class PairItem:
+     modality: str
+     value: Any
+
+
+ @dataclass
+ class TrainRecord:
+     query: PairItem
+     positive: PairItem
+     negative: Optional[PairItem] = None
+
+
+ def _parse_item(obj: Any, prefix: str) -> PairItem:
+     if isinstance(obj, dict):
+         modality = obj.get("type")
+         value = obj.get("value")
+     else:
+         modality = None
+         value = None
+
+     if not modality or not value:
+         raise ValueError(f"{prefix} must include type/value")
+     if modality not in SUPPORTED_MODALITIES:
+         raise ValueError(f"Unsupported modality '{modality}' in {prefix}")
+     return PairItem(modality=modality, value=value)
+
+
+ def parse_record(raw: Dict[str, Any]) -> TrainRecord:
+     if "query" in raw and "positive" in raw:
+         query = _parse_item(raw["query"], "query")
+         positive = _parse_item(raw["positive"], "positive")
+         negative = _parse_item(raw["negative"], "negative") if raw.get("negative") else None
+         return TrainRecord(query=query, positive=positive, negative=negative)
+
+     # Compatibility with common pair formats in existing repos
+     if "texts_a" in raw and "texts_b" in raw:
+         query = PairItem("text", raw["texts_a"])
+         positive = PairItem("text", raw["texts_b"])
+         return TrainRecord(query=query, positive=positive)
+
+     if "image_path" in raw and "caption" in raw:
+         query = PairItem("image", raw["image_path"])
+         positive = PairItem("text", raw["caption"])
+         return TrainRecord(query=query, positive=positive)
+
+     if "audio_path" in raw and "caption" in raw:
+         query = PairItem("audio", raw["audio_path"])
+         positive = PairItem("text", raw["caption"])
+         return TrainRecord(query=query, positive=positive)
+
+     raise ValueError("Record does not match supported schemas")
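
Reviewer note: `parse_record` accepts four record shapes, checked in a fixed order. The sketch below lists one hypothetical record per shape and mirrors only the key checks, so it runs without the `hf_st_mm` package. It assumes the manifest is JSON Lines, as the name `JsonlManifestDataset` suggests; the paths, captions, and the helper `schema_of` are placeholders for illustration.

```python
import json

# One JSON object per line; all values are made-up examples.
manifest_lines = [
    '{"query": {"type": "text", "value": "refund request"}, '
    '"positive": {"type": "audio", "value": "calls/refund.wav"}}',
    '{"texts_a": "how do I reset my password", "texts_b": "password reset steps"}',
    '{"image_path": "shots/login.png", "caption": "login screen with error banner"}',
    '{"audio_path": "calls/greeting.wav", "caption": "agent greets the caller"}',
]


def schema_of(raw: dict) -> str:
    """Name the branch that a raw record would take, using the same key checks."""
    if "query" in raw and "positive" in raw:
        return "explicit"
    if "texts_a" in raw and "texts_b" in raw:
        return "text-pair"
    if "image_path" in raw and "caption" in raw:
        return "image-caption"
    if "audio_path" in raw and "caption" in raw:
        return "audio-caption"
    raise ValueError("Record does not match supported schemas")


schemas = [schema_of(json.loads(line)) for line in manifest_lines]
print(schemas)
```

Because the checks run in order, a record that carries both `image_path` and `audio_path` alongside a `caption` is treated as an image-caption pair.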
+
+
+ class JsonlManifestDataset:
+     def __init__(
+         self,
+         manifest_path: str,
+         image_root: Optional[str] = None,
+         audio_root: Optional[str] = None,
+         allow_missing_negative: bool = True,
+     ) -> None:
+         self.manifest_path = manifest_path
+         self.image_root = image_root
+         self.audio_root = audio_root
+         self.allow_missing_negative = allow_missing_negative
+         self.records = list(
+             iter_manifest_records(
+                 manifest_path=self.manifest_path,
+                 image_root=self.image_root,
+                 audio_root=self.audio_root,
+                 allow_missing_negative=self.allow_missing_negative,
+             )
+         )
+         if not self.records:
+             raise ValueError(f"No records loaded from {self.manifest_path}")
+
+     def __len__(self) -> int:
+         return len(self.records)
+
+     def __getitem__(self, idx: int) -> TrainRecord:
+         return self.records[idx]
+
+
+ class CachedShardDataset:
+     def __init__(self, cache_dir: str, shard_cache_limit: int = 2, prefetch_shards: int = 0) -> None:
+         self.cache_dir = cache_dir
+         self.shard_cache_limit = max(int(shard_cache_limit), 1)
+         self.prefetch_shards = max(int(prefetch_shards), 0)
+         self.metadata = self._load_metadata()
+         self.shard_files = self._discover_shards()
+         self.shard_sizes = self._resolve_shard_sizes()
+         self.shard_offsets = self._build_offsets(self.shard_sizes)
+         self.total_rows = sum(self.shard_sizes)
+         self._shard_cache: OrderedDict[int, List[Dict[str, Any]]] = OrderedDict()
+         self._init_runtime_state()
+
+     def _init_runtime_state(self) -> None:
+         self._cache_lock = threading.Lock()
+         self._prefetch_queue = None
+         self._prefetch_thread = None
+         self._prefetch_stop = threading.Event()
+         self._prefetch_requested: set[int] = set()
+         self._prefetch_hits = 0
+         self._prefetch_misses = 0
+
+     def __getstate__(self):
+         state = self.__dict__.copy()
+         state["_shard_cache"] = OrderedDict(state.get("_shard_cache", OrderedDict()))
+         state["_cache_lock"] = None
+         state["_prefetch_queue"] = None
+         state["_prefetch_thread"] = None
+         state["_prefetch_stop"] = None
+         state["_prefetch_requested"] = set()
+         return state
+
+     def __setstate__(self, state):
+         self.__dict__.update(state)
+         self._shard_cache = OrderedDict(self._shard_cache)
+         self._init_runtime_state()
+
+     def _load_metadata(self) -> Dict[str, Any]:
+         metadata_path = os.path.join(self.cache_dir, "metadata.json")
+         if not os.path.exists(metadata_path):
+             return {}
+         with open(metadata_path, "r", encoding="utf-8") as handle:
+             return json.load(handle)
+
+     def _discover_shards(self) -> List[str]:
+         if not os.path.isdir(self.cache_dir):
+             raise FileNotFoundError(f"Cache directory not found: {self.cache_dir}")
+         shards: List[str] = []
+         for name in sorted(os.listdir(self.cache_dir)):
+             if not (name.startswith("shard_") and name.endswith(".pt")):
+                 continue
+             shard_path = os.path.join(self.cache_dir, name)
+             shards.append(shard_path)
+         if not shards:
+             raise ValueError(f"No cache shards found under {self.cache_dir}")
+         return shards
+
+     @staticmethod
+     def _build_offsets(shard_sizes: List[int]) -> List[int]:
+         offsets: List[int] = []
+         running_total = 0
+         for shard_size in shard_sizes:
+             running_total += shard_size
+             offsets.append(running_total)
+         return offsets
+
+     def _resolve_shard_sizes(self) -> List[int]:
+         num_shards = len(self.shard_files)
+         metadata_num_shards = self.metadata.get("num_shards")
+         metadata_num_records = self.metadata.get("num_records")
+         shard_size = self.metadata.get("shard_size")
+
+         if (
+             isinstance(metadata_num_shards, int)
+             and isinstance(metadata_num_records, int)
+             and isinstance(shard_size, int)
+             and metadata_num_shards == num_shards
+             and metadata_num_records > 0
+             and shard_size > 0
+         ):
+             shard_sizes = [shard_size] * num_shards
+             full_rows_before_last = shard_size * max(num_shards - 1, 0)
+             shard_sizes[-1] = metadata_num_records - full_rows_before_last
+             if shard_sizes[-1] <= 0:
+                 raise ValueError(f"Invalid metadata in {self.cache_dir}: last shard size computed as {shard_sizes[-1]}")
+             return shard_sizes
+
+         shard_sizes: List[int] = []
+         for shard_path in self.shard_files:
+             payload = torch.load(shard_path, map_location="cpu", weights_only=False)
+             records = payload.get("records")
+             if not isinstance(records, list):
+                 raise ValueError(f"Invalid shard format in {shard_path}")
+             shard_sizes.append(len(records))
+         return shard_sizes
+
+     def _store_shard(self, shard_idx: int, records: List[Dict[str, Any]]) -> None:
+         with self._cache_lock:
+             self._shard_cache[shard_idx] = records
+             self._shard_cache.move_to_end(shard_idx)
+             while len(self._shard_cache) > self.shard_cache_limit:
+                 self._shard_cache.popitem(last=False)
+
+     def _ensure_prefetch_thread(self) -> None:
+         if self.prefetch_shards <= 0:
+             return
+         if self._prefetch_thread is not None and self._prefetch_thread.is_alive():
+             return
+
+         self._prefetch_stop.clear()
+         self._prefetch_queue = queue.Queue(maxsize=max(self.prefetch_shards * 2, 1))
+         self._prefetch_thread = threading.Thread(
+             target=self._prefetch_worker,
+             daemon=True,
+             name=f"cached-shard-prefetch-{os.getpid()}",
+         )
+         self._prefetch_thread.start()
+
+     def _prefetch_worker(self) -> None:
+         while not self._prefetch_stop.is_set():
+             try:
+                 shard_idx = self._prefetch_queue.get(timeout=0.1)
+             except queue.Empty:
+                 continue
+
+             if shard_idx is None:
+                 continue
+
+             try:
+                 with self._cache_lock:
+                     if shard_idx in self._shard_cache:
+                         self._prefetch_hits += 1
+                         continue
+                 payload = torch.load(self.shard_files[shard_idx], map_location="cpu", weights_only=False)
+                 records = payload["records"]
+                 self._store_shard(shard_idx, records)
+                 self._prefetch_hits += 1
+             finally:
+                 with self._cache_lock:
+                     self._prefetch_requested.discard(shard_idx)
+
+     def _schedule_prefetch(self, shard_idx: int) -> None:
+         if self.prefetch_shards <= 0:
+             return
+
+         self._ensure_prefetch_thread()
+         if self._prefetch_queue is None:
+             return
+
+         for next_idx in range(shard_idx + 1, min(len(self.shard_files), shard_idx + 1 + self.prefetch_shards)):
+             with self._cache_lock:
+                 if next_idx in self._shard_cache or next_idx in self._prefetch_requested:
+                     continue
+                 self._prefetch_requested.add(next_idx)
+             try:
+                 self._prefetch_queue.put_nowait(next_idx)
+             except queue.Full:
+                 with self._cache_lock:
+                     self._prefetch_requested.discard(next_idx)
+                 break
+
+     def _load_shard(self, shard_idx: int) -> List[Dict[str, Any]]:
+         cached = None
+         with self._cache_lock:
+             cached = self._shard_cache.get(shard_idx)
+             if cached is not None:
+                 self._shard_cache.move_to_end(shard_idx)
+         if cached is not None:
+             self._schedule_prefetch(shard_idx)
+             return cached
+
+         self._prefetch_misses += 1
+         payload = torch.load(self.shard_files[shard_idx], map_location="cpu", weights_only=False)
+         records = payload["records"]
+         self._store_shard(shard_idx, records)
+         with self._cache_lock:
+             self._prefetch_requested.discard(shard_idx)
+         self._schedule_prefetch(shard_idx)
+         return records
+
+     @staticmethod
+     def _deserialize_item(raw: Optional[Dict[str, Any]]) -> Optional[PairItem]:
+         if raw is None:
+             return None
+         modality = raw["type"]
+         if modality == "text" and "tokens" in raw:
+             value = raw["tokens"]
+         elif modality == "text":
+             value = raw["value"]
+         elif "tensor" in raw:
+             value = raw["tensor"]
+         else:
+             value = raw.get("value")
+         return PairItem(modality=modality, value=value)
+
+     def __len__(self) -> int:
+         return self.total_rows
+
+     def __getitem__(self, idx: int) -> TrainRecord:
+         if idx < 0 or idx >= self.total_rows:
+             raise IndexError(idx)
+         shard_idx = bisect_right(self.shard_offsets, idx)
+         shard_start = 0 if shard_idx == 0 else self.shard_offsets[shard_idx - 1]
+         local_idx = idx - shard_start
+         raw = self._load_shard(shard_idx)[local_idx]
+         return TrainRecord(
+             query=self._deserialize_item(raw["query"]),
+             positive=self._deserialize_item(raw["positive"]),
+             negative=self._deserialize_item(raw.get("negative")),
+         )
+
+     def get_prefetch_stats(self) -> Dict[str, int]:
+         with self._cache_lock:
+             return {
+                 "cache_size": len(self._shard_cache),
+                 "cache_limit": self.shard_cache_limit,
+                 "prefetch_shards": self.prefetch_shards,
+                 "prefetch_hits": self._prefetch_hits,
+                 "prefetch_misses": self._prefetch_misses,
+                 "prefetch_pending": len(self._prefetch_requested),
+             }
+
+     def close(self) -> None:
+         self._prefetch_stop.set()
+         if self._prefetch_thread is not None and self._prefetch_thread.is_alive():
+             self._prefetch_thread.join(timeout=1.0)
+         self._prefetch_thread = None
+         self._prefetch_queue = None
+
+     def __del__(self):
+         self.close()
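
Reviewer note: `CachedShardDataset.__getitem__` maps a global row index to a shard and a position inside that shard by bisecting the cumulative offsets from `_build_offsets`. The standalone sketch below reproduces only that arithmetic; the shard sizes and the helper names `build_offsets` and `locate` are illustrative assumptions.

```python
from bisect import bisect_right
from typing import List, Tuple


def build_offsets(shard_sizes: List[int]) -> List[int]:
    """Cumulative end offsets: offsets[i] is the number of rows in shards 0..i."""
    offsets: List[int] = []
    running_total = 0
    for size in shard_sizes:
        running_total += size
        offsets.append(running_total)
    return offsets


def locate(offsets: List[int], idx: int) -> Tuple[int, int]:
    """Return (shard index, index within that shard) for a global row index."""
    if idx < 0 or idx >= offsets[-1]:
        raise IndexError(idx)
    # bisect_right finds the first shard whose end offset is strictly greater than idx.
    shard_idx = bisect_right(offsets, idx)
    shard_start = 0 if shard_idx == 0 else offsets[shard_idx - 1]
    return shard_idx, idx - shard_start


offsets = build_offsets([4, 4, 2])
for row in (0, 3, 4, 9):
    print(row, locate(offsets, row))
```

Using `bisect_right` rather than `bisect_left` matters at shard boundaries: row 4 equals the first end offset and must fall into the second shard, not the first.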
+
+
+ class SequentialShardDataset:
+     def __init__(
+         self,
+         cache_dir: str,
+         shuffle: bool = True,
+         rank: int = 0,
+         world_size: int = 1,
+         prefetch_shards: int = 2,
+         shard_cache_limit: int = 4,
+     ) -> None:
+         self.cache_dir = cache_dir
+         self.shuffle = shuffle
+         self.rank = rank
+         self.world_size = max(world_size, 1)
+         self.prefetch_shards = max(int(prefetch_shards), 0)
+         self.shard_cache_limit = max(int(shard_cache_limit), 1)
+
+         self.metadata = self._load_metadata()
+         self.shard_files = self._discover_shards()
+         self.shard_sizes = self._resolve_shard_sizes()
+         self.total_rows = sum(self.shard_sizes)
+         self.target_shard_size = int(self.metadata.get("shard_size") or max(self.shard_sizes))
+
+         self._shard_cache: OrderedDict[int, List[Dict[str, Any]]] = OrderedDict()
+         self._cache_lock = threading.Lock()
+         self._prefetch_queue = None
+         self._prefetch_thread = None
+         self._prefetch_stop = threading.Event()
+         self._prefetch_requested: set[int] = set()
+         self._prefetch_hits = 0
+         self._prefetch_misses = 0
+
+         self._all_shard_indices = list(range(len(self.shard_files)))
+         self._local_shard_indices: List[int] = []
+         self.current_local_shard_pos = -1
+         self.current_records: Optional[List[Dict[str, Any]]] = None
+
+     def _load_metadata(self) -> Dict[str, Any]:
+         metadata_path = os.path.join(self.cache_dir, "metadata.json")
+         if not os.path.exists(metadata_path):
+             return {}
+         with open(metadata_path, "r", encoding="utf-8") as handle:
+             return json.load(handle)
+
+     def _discover_shards(self) -> List[str]:
+         if not os.path.isdir(self.cache_dir):
+             raise FileNotFoundError(f"Cache directory not found: {self.cache_dir}")
+         shards: List[str] = []
+         for name in sorted(os.listdir(self.cache_dir)):
+             if not (name.startswith("shard_") and name.endswith(".pt")):
+                 continue
+             shards.append(os.path.join(self.cache_dir, name))
+         if not shards:
+             raise ValueError(f"No cache shards found under {self.cache_dir}")
+         return shards
+
+     def _resolve_shard_sizes(self) -> List[int]:
+         num_shards = len(self.shard_files)
+         metadata_num_shards = self.metadata.get("num_shards")
+         metadata_num_records = self.metadata.get("num_records")
+         shard_size = self.metadata.get("shard_size")
+
+         if (
+             isinstance(metadata_num_shards, int)
+             and isinstance(metadata_num_records, int)
+             and isinstance(shard_size, int)
+             and metadata_num_shards == num_shards
+             and metadata_num_records > 0
+             and shard_size > 0
+         ):
+             shard_sizes = [shard_size] * num_shards
+             shard_sizes[-1] = metadata_num_records - shard_size * max(num_shards - 1, 0)
+             return shard_sizes
+
+         shard_sizes: List[int] = []
+         for shard_path in self.shard_files:
+             payload = torch.load(shard_path, map_location="cpu", weights_only=False)
+             records = payload.get("records")
+             if not isinstance(records, list):
+                 raise ValueError(f"Invalid shard format in {shard_path}")
+             shard_sizes.append(len(records))
+         return shard_sizes
+
+     @staticmethod
+     def _deserialize_item(raw: Optional[Dict[str, Any]]) -> Optional[PairItem]:
+         if raw is None:
+             return None
+         modality = raw["type"]
+         if modality == "text" and "tokens" in raw:
+             value = raw["tokens"]
+         elif modality == "text":
+             value = raw["value"]
+         elif "tensor" in raw:
+             value = raw["tensor"]
+         else:
+             value = raw.get("value")
+         return PairItem(modality=modality, value=value)
+
+     def _store_shard(self, shard_idx: int, records: List[Dict[str, Any]]) -> None:
+         with self._cache_lock:
+             self._shard_cache[shard_idx] = records
+             self._shard_cache.move_to_end(shard_idx)
+             while len(self._shard_cache) > self.shard_cache_limit:
+                 self._shard_cache.popitem(last=False)
+
+     def _ensure_prefetch_thread(self) -> None:
+         if self.prefetch_shards <= 0:
+             return
+         if self._prefetch_thread is not None and self._prefetch_thread.is_alive():
+             return
+         self._prefetch_stop.clear()
+         self._prefetch_queue = queue.Queue(maxsize=max(self.prefetch_shards * 2, 1))
+         self._prefetch_thread = threading.Thread(
+             target=self._prefetch_worker,
+             daemon=True,
+             name=f"sequential-shard-prefetch-{os.getpid()}",
+         )
+         self._prefetch_thread.start()
+
+     def _prefetch_worker(self) -> None:
+         while not self._prefetch_stop.is_set():
+             try:
+                 shard_idx = self._prefetch_queue.get(timeout=0.1)
+             except queue.Empty:
+                 continue
+             if shard_idx is None:
+                 continue
+             try:
+                 with self._cache_lock:
+                     if shard_idx in self._shard_cache:
+                         self._prefetch_hits += 1
+                         continue
+                 payload = torch.load(self.shard_files[shard_idx], map_location="cpu", weights_only=False)
+                 self._store_shard(shard_idx, payload["records"])
+                 self._prefetch_hits += 1
+             finally:
+                 with self._cache_lock:
+                     self._prefetch_requested.discard(shard_idx)
+
+     def _stop_prefetch_thread(self) -> None:
+         self._prefetch_stop.set()
+         if self._prefetch_thread is not None and self._prefetch_thread.is_alive():
+             self._prefetch_thread.join(timeout=1.0)
+         self._prefetch_thread = None
+         self._prefetch_queue = None
+
+     def _schedule_prefetch_from_position(self, local_pos: int) -> None:
+         if self.prefetch_shards <= 0:
+             return
+         self._ensure_prefetch_thread()
+         if self._prefetch_queue is None:
+             return
+         for next_pos in range(local_pos + 1, min(len(self._local_shard_indices), local_pos + 1 + self.prefetch_shards)):
+             shard_idx = self._local_shard_indices[next_pos]
+             with self._cache_lock:
+                 if shard_idx in self._shard_cache or shard_idx in self._prefetch_requested:
+                     continue
+                 self._prefetch_requested.add(shard_idx)
+             try:
+                 self._prefetch_queue.put_nowait(shard_idx)
+             except queue.Full:
+                 with self._cache_lock:
+                     self._prefetch_requested.discard(shard_idx)
+                 break
500
+
501
+ def _build_local_shard_order(self, epoch: int) -> List[int]:
502
+ shard_indices = list(self._all_shard_indices)
503
+ if self.shuffle:
504
+ random.Random(42 + epoch).shuffle(shard_indices)
505
+ local_shards = shard_indices[self.rank::self.world_size]
506
+ max_shards = math.ceil(len(shard_indices) / self.world_size)
507
+ if not local_shards:
508
+ raise ValueError(f"Rank {self.rank} received no shards from {self.cache_dir}")
509
+ while len(local_shards) < max_shards:
510
+ local_shards.append(local_shards[len(local_shards) % len(local_shards)])
511
+ return local_shards
512
+
513
+ def _load_records_for_shard(self, shard_idx: int) -> List[Dict[str, Any]]:
514
+ cached = None
515
+ with self._cache_lock:
516
+ cached = self._shard_cache.get(shard_idx)
517
+ if cached is not None:
518
+ self._shard_cache.move_to_end(shard_idx)
519
+ if cached is not None:
520
+ return cached
521
+
522
+ self._prefetch_misses += 1
523
+ payload = torch.load(self.shard_files[shard_idx], map_location="cpu", weights_only=False)
524
+ records = payload["records"]
525
+ self._store_shard(shard_idx, records)
526
+ with self._cache_lock:
527
+ self._prefetch_requested.discard(shard_idx)
528
+ return records
529
+
530
+ def _pad_records(self, records: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
531
+ if len(records) >= self.target_shard_size:
532
+ return records
533
+ repeat = math.ceil(self.target_shard_size / len(records))
534
+ return (records * repeat)[: self.target_shard_size]
535
+
+     def reset(self, epoch: int) -> bool:
+         self._stop_prefetch_thread()
+         self._local_shard_indices = self._build_local_shard_order(epoch)
+         self.current_local_shard_pos = -1
+         self.current_records = None
+         with self._cache_lock:
+             self._prefetch_requested.clear()
+         if self.prefetch_shards > 0:
+             self._ensure_prefetch_thread()
+         return self.next_shard()
+
+     def next_shard(self) -> bool:
+         self.current_local_shard_pos += 1
+         if self.current_local_shard_pos >= len(self._local_shard_indices):
+             self.current_records = None
+             return False
+         shard_idx = self._local_shard_indices[self.current_local_shard_pos]
+         records = self._load_records_for_shard(shard_idx)
+         self.current_records = self._pad_records(records)
+         self._schedule_prefetch_from_position(self.current_local_shard_pos)
+         return True
+
+     def __len__(self) -> int:
+         return len(self.current_records or [])
+
+     def __getitem__(self, idx: int) -> TrainRecord:
+         if self.current_records is None:
+             raise IndexError(idx)
+         raw = self.current_records[idx]
+         return TrainRecord(
+             query=self._deserialize_item(raw["query"]),
+             positive=self._deserialize_item(raw["positive"]),
+             negative=self._deserialize_item(raw.get("negative")),
+         )
+
+     def estimated_num_batches(self, batch_size: int, drop_last: bool) -> int:
+         shard_batches = self.target_shard_size // batch_size if drop_last else math.ceil(self.target_shard_size / batch_size)
+         return shard_batches * max(len(self._build_local_shard_order(0)), 1)
+
+     def get_prefetch_stats(self) -> Dict[str, int]:
+         with self._cache_lock:
+             return {
+                 "cache_size": len(self._shard_cache),
+                 "cache_limit": self.shard_cache_limit,
+                 "prefetch_shards": self.prefetch_shards,
+                 "prefetch_hits": self._prefetch_hits,
+                 "prefetch_misses": self._prefetch_misses,
+                 "prefetch_pending": len(self._prefetch_requested),
+                 "local_shards": len(self._local_shard_indices),
+                 "target_shard_size": self.target_shard_size,
+             }
+
+     def close(self) -> None:
+         self._stop_prefetch_thread()
+
+     def __del__(self):
+         self.close()
+
+
+ def _process_shard() -> tuple[int, int]:
+     rank = int(os.environ.get("ACCELERATE_PROCESS_INDEX") or os.environ.get("RANK") or 0)
+     world_size = int(os.environ.get("WORLD_SIZE") or os.environ.get("ACCELERATE_NUM_PROCESSES") or 1)
+     worker_info = torch.utils.data.get_worker_info()
+     if worker_info is None:
+         return rank, max(world_size, 1)
+
+     total_shards = max(world_size, 1) * worker_info.num_workers
+     shard_id = rank * worker_info.num_workers + worker_info.id
+     return shard_id, max(total_shards, 1)
+
+
+ def iter_sentence_transformers_rows(
+     manifest_path: str,
+     image_root: Optional[str],
+     audio_root: Optional[str],
+     allow_missing_negative: bool,
+     allowed_modalities: Optional[List[str]],
+     query_modalities: Optional[List[str]],
+     positive_modalities: Optional[List[str]],
+     negative_modalities: Optional[List[str]],
+     use_negative_column: bool,
+ ):
+     allowed = set(allowed_modalities or [])
+     allowed_query = set(query_modalities or [])
+     allowed_positive = set(positive_modalities or [])
+     allowed_negative = set(negative_modalities or [])
+     shard_id, total_shards = _process_shard()
+     matched_index = 0
+
+     for record in iter_manifest_records(
+         manifest_path=manifest_path,
+         image_root=image_root,
+         audio_root=audio_root,
+         allow_missing_negative=allow_missing_negative,
+     ):
+         if not record_matches_filters(
+             record,
+             allowed=allowed,
+             allowed_query=allowed_query,
+             allowed_positive=allowed_positive,
+             allowed_negative=allowed_negative,
+         ):
+             continue
+
+         if matched_index % total_shards == shard_id:
+             yield record_to_sentence_transformers_row(record, include_negative=use_negative_column)
+         matched_index += 1
+
+
+ def collate_records(batch: List[TrainRecord]) -> Dict[str, List[PairItem]]:
+     return {
+         "query": [r.query for r in batch],
+         "positive": [r.positive for r in batch],
+         "negative": [r.negative for r in batch],
+     }
+
+
+ def sentence_transformers_input(item: PairItem) -> Any:
+     payload: Dict[str, Any] = {}
+     if item.modality == "text":
+         payload["text"] = item.value
+         return payload
+     if item.modality == "image":
+         payload["image"] = item.value
+         return payload
+     if item.modality == "audio":
+         payload["audio"] = item.value
+         return payload
+     return item.value
+
+
+ def resolve_media(item: PairItem, image_root: Optional[str], audio_root: Optional[str]) -> PairItem:
+     if item.modality == "image" and image_root and not os.path.isabs(item.value):
+         return PairItem(item.modality, os.path.join(image_root, item.value))
+     if item.modality == "audio" and audio_root and not os.path.isabs(item.value):
+         return PairItem(item.modality, os.path.join(audio_root, item.value))
+     return item
+
+
+ def iter_manifest_records(
+     manifest_path: str,
+     image_root: Optional[str] = None,
+     audio_root: Optional[str] = None,
+     allow_missing_negative: bool = True,
+ ) -> Iterable[TrainRecord]:
+     if not os.path.exists(manifest_path):
+         raise FileNotFoundError(f"Manifest not found: {manifest_path}")
+
+     with open(manifest_path, "r", encoding="utf-8") as handle:
+         for line_no, line in enumerate(handle, start=1):
+             line = line.strip()
+             if not line:
+                 continue
+             raw = json.loads(line)
+             record = parse_record(raw)
+             record = TrainRecord(
+                 query=resolve_media(record.query, image_root, audio_root),
+                 positive=resolve_media(record.positive, image_root, audio_root),
+                 negative=resolve_media(record.negative, image_root, audio_root) if record.negative else None,
+             )
+             if record.negative is None and not allow_missing_negative:
+                 raise ValueError(f"Missing negative at line {line_no}")
+             yield record
+
+
+ def record_matches_filters(
+     record: TrainRecord,
+     allowed: set[str],
+     allowed_query: set[str],
+     allowed_positive: set[str],
+     allowed_negative: set[str],
+ ) -> bool:
+     record_modalities = {record.query.modality, record.positive.modality}
+     if record.negative is not None:
+         record_modalities.add(record.negative.modality)
+     if allowed and not record_modalities.issubset(allowed):
+         return False
+     if allowed_query and record.query.modality not in allowed_query:
+         return False
+     if allowed_positive and record.positive.modality not in allowed_positive:
+         return False
+     if record.negative is not None and allowed_negative and record.negative.modality not in allowed_negative:
+         return False
+     return True
+
+
+ def record_to_sentence_transformers_row(record: TrainRecord, include_negative: bool) -> Dict[str, Any]:
+     row = {
+         "query": sentence_transformers_input(record.query),
+         "positive": sentence_transformers_input(record.positive),
+     }
+     if include_negative and record.negative is not None:
+         row["negative_0"] = sentence_transformers_input(record.negative)
+     return row
+
+
+ def summarize_manifest_records(
+     manifest_path: str,
+     image_root: Optional[str] = None,
+     audio_root: Optional[str] = None,
+     allow_missing_negative: bool = True,
+     allowed_modalities: Optional[List[str]] = None,
+     query_modalities: Optional[List[str]] = None,
+     positive_modalities: Optional[List[str]] = None,
+     negative_modalities: Optional[List[str]] = None,
+     max_records: Optional[int] = None,
+ ) -> Dict[str, Any]:
+     modalities = set()
+     negatives_present = 0
+     negatives_missing = 0
+     skipped_rows = 0
+     num_rows = 0
+     allowed = set(allowed_modalities or [])
+     allowed_query = set(query_modalities or [])
+     allowed_positive = set(positive_modalities or [])
+     allowed_negative = set(negative_modalities or [])
+
+     for record in iter_manifest_records(
+         manifest_path=manifest_path,
+         image_root=image_root,
+         audio_root=audio_root,
+         allow_missing_negative=allow_missing_negative,
+     ):
+         if not record_matches_filters(
+             record,
+             allowed=allowed,
+             allowed_query=allowed_query,
+             allowed_positive=allowed_positive,
+             allowed_negative=allowed_negative,
+         ):
+             skipped_rows += 1
+             continue
+
+         modalities.add(record.query.modality)
+         modalities.add(record.positive.modality)
+         if record.negative is not None:
+             modalities.add(record.negative.modality)
+             negatives_present += 1
+         else:
+             negatives_missing += 1
+         num_rows += 1
+         if max_records is not None and num_rows >= max_records:
+             break
+
+     if num_rows == 0:
+         raise ValueError(f"No records loaded from {manifest_path}")
+
+     return {
+         "modalities": sorted(modalities),
+         "num_rows": num_rows,
+         "has_uniform_negatives": negatives_present > 0 and negatives_missing == 0,
+         "num_negatives_present": negatives_present,
+         "num_negatives_missing": negatives_missing,
+         "skipped_rows": skipped_rows,
+     }
+
+
+ def manifest_to_sentence_transformers_dataset(
+     manifest_path: str,
+     image_root: Optional[str] = None,
+     audio_root: Optional[str] = None,
+     allow_missing_negative: bool = True,
+     allowed_modalities: Optional[List[str]] = None,
+     query_modalities: Optional[List[str]] = None,
+     positive_modalities: Optional[List[str]] = None,
+     negative_modalities: Optional[List[str]] = None,
+     as_iterable: bool = False,
+     max_records: Optional[int] = None,
+ ) -> tuple[Dataset | IterableDataset, Dict[str, Any]]:
+     info = summarize_manifest_records(
+         manifest_path=manifest_path,
+         image_root=image_root,
+         audio_root=audio_root,
+         allow_missing_negative=allow_missing_negative,
+         allowed_modalities=allowed_modalities,
+         query_modalities=query_modalities,
+         positive_modalities=positive_modalities,
+         negative_modalities=negative_modalities,
+         max_records=max_records,
+     )
+
+     dataset_out: Dataset | IterableDataset
+     if as_iterable:
+         column_names = ["query", "positive"]
+         if info["has_uniform_negatives"]:
+             column_names.append("negative_0")
+         dataset_out = IterableDataset.from_generator(
+             iter_sentence_transformers_rows,
+             features=Features({key: Value("null") for key in column_names}),
+             gen_kwargs={
+                 "manifest_path": manifest_path,
+                 "image_root": image_root,
+                 "audio_root": audio_root,
+                 "allow_missing_negative": allow_missing_negative,
+                 "allowed_modalities": allowed_modalities,
+                 "query_modalities": query_modalities,
+                 "positive_modalities": positive_modalities,
+                 "negative_modalities": negative_modalities,
+                 "use_negative_column": info["has_uniform_negatives"],
+             },
+         )
+     else:
+         dataset = JsonlManifestDataset(
+             manifest_path=manifest_path,
+             image_root=image_root,
+             audio_root=audio_root,
+             allow_missing_negative=allow_missing_negative,
+         )
+         allowed = set(allowed_modalities or [])
+         allowed_query = set(query_modalities or [])
+         allowed_positive = set(positive_modalities or [])
+         allowed_negative = set(negative_modalities or [])
+         rows: List[Dict[str, Any]] = []
+         for record in dataset.records:
+             if not record_matches_filters(
+                 record,
+                 allowed=allowed,
+                 allowed_query=allowed_query,
+                 allowed_positive=allowed_positive,
+                 allowed_negative=allowed_negative,
+             ):
+                 continue
+             rows.append(record_to_sentence_transformers_row(record, include_negative=info["has_uniform_negatives"]))
+             if max_records is not None and len(rows) >= max_records:
+                 break
+         dataset_out = Dataset.from_list(rows)
+
+     return dataset_out, info
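
The rank-strided shard assignment in `_build_local_shard_order` can be tried in isolation. The following is a minimal sketch, not the module's own code: `local_shard_order` is a hypothetical standalone function that takes the shard count, rank, and world size as plain arguments instead of reading them from the dataset object. It assumes the same seeding (`42 + epoch`) and the same padding goal, namely that every rank ends up with `ceil(num_shards / world_size)` shards.

```python
import math
import random
from typing import List


def local_shard_order(
    num_shards: int,
    rank: int,
    world_size: int,
    epoch: int,
    shuffle: bool = True,
) -> List[int]:
    """Hypothetical standalone version of the per-rank shard assignment."""
    shard_indices = list(range(num_shards))
    if shuffle:
        # Every rank uses the same seed, so all ranks see the same permutation.
        random.Random(42 + epoch).shuffle(shard_indices)
    # Rank r takes positions r, r + world_size, r + 2 * world_size, ...
    local = shard_indices[rank::world_size]
    if not local:
        raise ValueError(f"Rank {rank} received no shards")
    # Pad short ranks by repeating their own shards, so every rank
    # iterates over the same number of shards per epoch.
    target = math.ceil(num_shards / world_size)
    num_assigned = len(local)
    while len(local) < target:
        local.append(local[(len(local) - num_assigned) % num_assigned])
    return local


if __name__ == "__main__":
    for rank in range(4):
        print(rank, local_shard_order(10, rank, 4, epoch=0, shuffle=False))
```

With 10 shards and 4 ranks, ranks 2 and 3 receive only two shards from the strided slice and are padded with a repeat of their first shard. Equal shard counts per rank keep the number of steps per epoch aligned across distributed processes.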
src/hf_st_mm/model.py ADDED
@@ -0,0 +1,191 @@
+ from collections import defaultdict
+ from typing import Any, Dict, List
+
+ import librosa
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn.utils.rnn import pad_sequence
+ from PIL import Image
+ from sentence_transformers import SentenceTransformer
+ from transformers import AutoModel, AutoProcessor, WhisperFeatureExtractor, WhisperModel
+
+ from .data import PairItem
+
+
+ class MultiModalSentenceEmbedder(nn.Module):
+     def __init__(
+         self,
+         text_encoder_name: str,
+         image_encoder_name: str,
+         audio_encoder_name: str,
+         embedding_dim: int,
+         max_text_length: int,
+     ) -> None:
+         super().__init__()
+         self.text_model = SentenceTransformer(text_encoder_name)
+         self.text_model.max_seq_length = max_text_length
+
+         self.image_model = AutoModel.from_pretrained(image_encoder_name, trust_remote_code=True)
+         self.image_processor = AutoProcessor.from_pretrained(image_encoder_name, trust_remote_code=True)
+
+         whisper = WhisperModel.from_pretrained(audio_encoder_name)
+         self.audio_model = whisper.encoder
+         self.audio_processor = WhisperFeatureExtractor.from_pretrained(audio_encoder_name)
+
+         text_dim = self.text_model.get_sentence_embedding_dimension()
+         image_dim = self._get_vision_dim(self.image_model)
+         audio_dim = whisper.config.d_model
+
+         self.text_proj = nn.Linear(text_dim, embedding_dim) if text_dim != embedding_dim else nn.Identity()
+         self.image_proj = nn.Linear(image_dim, embedding_dim) if image_dim != embedding_dim else nn.Identity()
+         self.audio_proj = nn.Linear(audio_dim, embedding_dim) if audio_dim != embedding_dim else nn.Identity()
+
+     @staticmethod
+     def _get_vision_dim(model: nn.Module) -> int:
+         if hasattr(model, "vision_model") and hasattr(model.config, "vision_config"):
+             return int(model.config.vision_config.hidden_size)
+         if hasattr(model.config, "hidden_size"):
+             return int(model.config.hidden_size)
+         raise ValueError("Could not infer image hidden size")
+
+     def _encode_text(self, texts: List[Any]) -> torch.Tensor:
+         device = next(self.parameters()).device
+         normalized: List[torch.Tensor | None] = [None] * len(texts)
+
+         dict_positions = [idx for idx, item in enumerate(texts) if isinstance(item, dict)]
+         if dict_positions:
+             pad_values = {
+                 "input_ids": 0,
+                 "attention_mask": 0,
+                 "token_type_ids": 0,
+             }
+             dict_items = [texts[idx] for idx in dict_positions]
+             features = {
+                 key: pad_sequence(
+                     [item[key].detach().cpu() for item in dict_items],
+                     batch_first=True,
+                     padding_value=pad_values.get(key, 0),
+                 ).to(device)
+                 for key in dict_items[0].keys()
+             }
+             out = self.text_model(features)
+             emb = F.normalize(self.text_proj(out["sentence_embedding"]), p=2, dim=-1)
+             for loc, row in zip(dict_positions, emb):
+                 normalized[loc] = row
+
+         raw_positions = [idx for idx, item in enumerate(texts) if not isinstance(item, dict)]
+         if raw_positions:
+             raw_texts = [texts[idx] for idx in raw_positions]
+             features = self.text_model.tokenize(raw_texts)
+             features = {
+                 k: (v.to(device) if hasattr(v, "to") else v)
+                 for k, v in features.items()
+             }
+             out = self.text_model(features)
+             emb = F.normalize(self.text_proj(out["sentence_embedding"]), p=2, dim=-1)
+             for loc, row in zip(raw_positions, emb):
+                 normalized[loc] = row
+
+         return torch.stack([row for row in normalized if row is not None], dim=0)
+
+     def _encode_image_paths(self, paths: List[str]) -> torch.Tensor:
+         images = [Image.open(path).convert("RGB") for path in paths]
+         proc = self.image_processor(images=images, return_tensors="pt")
+         device = next(self.parameters()).device
+         proc = {k: v.to(device) for k, v in proc.items()}
+         return self._encode_image_pixel_values(proc["pixel_values"])
+
+     def _encode_image_pixel_values(self, pixel_values: torch.Tensor) -> torch.Tensor:
+         device = next(self.parameters()).device
+         proc = {"pixel_values": pixel_values.to(device)}
+         if hasattr(self.image_model, "vision_model"):
+             out = self.image_model.vision_model(**proc, output_hidden_states=False)
+             hidden = out.last_hidden_state
+         else:
+             out = self.image_model(**proc, output_hidden_states=False)
+             hidden = out.last_hidden_state
+         pooled = hidden[:, 1:].mean(dim=1) if hidden.shape[1] > 1 else hidden.mean(dim=1)
+         emb = self.image_proj(pooled)
+         return F.normalize(emb, p=2, dim=-1)
+
+     def _encode_audio_paths(self, paths: List[str]) -> torch.Tensor:
+         waves = [librosa.load(path, sr=16000, mono=True)[0] for path in paths]
+         proc = self.audio_processor(waves, sampling_rate=16000, return_tensors="pt")
+         return self._encode_audio_features(proc["input_features"])
+
+     def _encode_audio_features(self, input_features: torch.Tensor) -> torch.Tensor:
+         device = next(self.parameters()).device
+         input_features = input_features.to(device)
+         input_features = input_features.to(self.audio_model.conv1.weight.dtype)
+         out = self.audio_model(input_features=input_features, output_hidden_states=False)
+         pooled = out.last_hidden_state.mean(dim=1)
+         emb = self.audio_proj(pooled)
+         return F.normalize(emb, p=2, dim=-1)
+
+     @staticmethod
+     def _stack_tensor_values(values: List[Any]) -> torch.Tensor:
+         tensors = []
+         for value in values:
+             if not torch.is_tensor(value):
+                 raise TypeError("Expected tensor payload in cached item")
+             tensor = value.detach().cpu()
+             if tensor.dim() > 0 and tensor.shape[0] == 1:
+                 tensor = tensor.squeeze(0)
+             tensors.append(tensor)
+         return torch.stack(tensors, dim=0)
+
+     def encode_items(self, items: List[PairItem]) -> torch.Tensor:
+         grouped = defaultdict(list)
+         for idx, item in enumerate(items):
+             grouped[item.modality].append((idx, item.value))
+
+         device = next(self.parameters()).device
+         out = [None] * len(items)
+
+         if grouped["text"]:
+             idxs, vals = zip(*grouped["text"])
+             embs = self._encode_text(list(vals))
+             for loc, emb in zip(idxs, embs):
+                 out[loc] = emb
+
+         if grouped["image"]:
+             idxs, vals = zip(*grouped["image"])
+             tensor_pairs = [(idx, val) for idx, val in zip(idxs, vals) if torch.is_tensor(val)]
+             path_pairs = [(idx, val) for idx, val in zip(idxs, vals) if not torch.is_tensor(val)]
+             if path_pairs:
+                 p_idxs, p_vals = zip(*path_pairs)
+                 embs = self._encode_image_paths(list(p_vals))
+                 for loc, emb in zip(p_idxs, embs):
+                     out[loc] = emb
+             if tensor_pairs:
+                 t_idxs, t_vals = zip(*tensor_pairs)
+                 embs = self._encode_image_pixel_values(self._stack_tensor_values(list(t_vals)))
+                 for loc, emb in zip(t_idxs, embs):
+                     out[loc] = emb
+
+         if grouped["audio"]:
+             idxs, vals = zip(*grouped["audio"])
+             tensor_pairs = [(idx, val) for idx, val in zip(idxs, vals) if torch.is_tensor(val)]
+             path_pairs = [(idx, val) for idx, val in zip(idxs, vals) if not torch.is_tensor(val)]
+             if path_pairs:
+                 p_idxs, p_vals = zip(*path_pairs)
+                 embs = self._encode_audio_paths(list(p_vals))
+                 for loc, emb in zip(p_idxs, embs):
+                     out[loc] = emb
+             if tensor_pairs:
+                 t_idxs, t_vals = zip(*tensor_pairs)
+                 embs = self._encode_audio_features(self._stack_tensor_values(list(t_vals)))
+                 for loc, emb in zip(t_idxs, embs):
+                     out[loc] = emb
+
+         # Items with a modality other than text, image, or audio are never encoded.
+         missing = [idx for idx, emb in enumerate(out) if emb is None]
+         if missing:
+             raise ValueError(f"Unsupported modality for items at positions {missing}")
+
+         stacked = torch.stack(out, dim=0).to(device=device, dtype=torch.float32)
+         return F.normalize(stacked, p=2, dim=-1)
+
+
+ def multiple_negatives_ranking_loss(anchor: torch.Tensor, positive: torch.Tensor, scale: float = 20.0) -> torch.Tensor:
+     scores = torch.matmul(anchor, positive.T) * scale
+     labels = torch.arange(scores.shape[0], device=scores.device)
+     loss_a = torch.nn.functional.cross_entropy(scores, labels)
+     loss_b = torch.nn.functional.cross_entropy(scores.T, labels)
+     return (loss_a + loss_b) * 0.5
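
To see what `multiple_negatives_ranking_loss` computes, the same arithmetic can be written out without PyTorch. This is an illustrative sketch only: `symmetric_in_batch_loss` and `_cross_entropy_diag` are hypothetical names, and the inputs are plain lists of unit-length vectors rather than tensors. Each anchor is scored against every positive in the batch, the matching pair sits on the diagonal, and cross-entropy is averaged over both the anchor-to-positive and positive-to-anchor directions.

```python
import math
from typing import List, Sequence


def _cross_entropy_diag(scores: Sequence[Sequence[float]]) -> float:
    """Mean cross-entropy where the correct class for row i is column i."""
    total = 0.0
    for i, row in enumerate(scores):
        # Log-sum-exp with the max subtracted for numerical stability.
        row_max = max(row)
        log_z = row_max + math.log(sum(math.exp(s - row_max) for s in row))
        total += log_z - row[i]
    return total / len(scores)


def symmetric_in_batch_loss(
    anchor: List[List[float]],
    positive: List[List[float]],
    scale: float = 20.0,
) -> float:
    """Pure-Python sketch of a symmetric in-batch contrastive loss."""
    # scores[i][j] = scale * <anchor_i, positive_j>
    scores = [
        [scale * sum(a * p for a, p in zip(a_vec, p_vec)) for p_vec in positive]
        for a_vec in anchor
    ]
    transposed = [list(col) for col in zip(*scores)]
    return 0.5 * (_cross_entropy_diag(scores) + _cross_entropy_diag(transposed))


if __name__ == "__main__":
    aligned = [[1.0, 0.0], [0.0, 1.0]]
    swapped = [[0.0, 1.0], [1.0, 0.0]]
    print(symmetric_in_batch_loss(aligned, aligned))  # close to zero
    print(symmetric_in_batch_loss(aligned, swapped))  # close to the scale value
```

When every anchor matches its own positive, the loss is near zero. When the positives are swapped, each row puts almost all probability on the wrong column and the loss approaches `scale`.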