Sentence Similarity
PyTorch
sentence-transformers
multimodal
embeddings
retrieval
image-text
audio-text
text-image-audio
tri-encoder
semantic-router
Eval Results (legacy)
Instructions to use llm-semantic-router/multi-modal-embed-large with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sentence-transformers
How to use llm-semantic-router/multi-modal-embed-large with sentence-transformers:
```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("llm-semantic-router/multi-modal-embed-large")
sentences = [
    "That is a happy person",
    "That is a happy dog",
    "That is a very happy person",
    "Today is a sunny day",
]
embeddings = model.encode(sentences)
similarities = model.similarity(embeddings, embeddings)
print(similarities.shape)  # [4, 4]
```
- Notebooks
- Google Colab
- Kaggle
| import json | |
| import math | |
| import os | |
| import queue | |
| import random | |
| import threading | |
| from bisect import bisect_right | |
| from collections import OrderedDict | |
| from dataclasses import dataclass | |
| from typing import Any, Dict, Iterable, List, Optional | |
| import torch | |
| from datasets import Dataset, Features, IterableDataset, Value | |
# Values accepted in the "type" field of a manifest item (checked by _parse_item).
SUPPORTED_MODALITIES = {"audio", "image", "text"}
@dataclass
class PairItem:
    """One side of a training pair: a modality tag plus its payload.

    ``PairItem(modality=..., value=...)`` and ``PairItem("text", ...)`` are both used
    in this module, so the class needs the dataclass-generated ``__init__``.
    """

    # Modality name, e.g. "text", "image" or "audio".
    modality: str
    # Raw text, a media path, or a pre-computed payload from a cache shard.
    value: Any
@dataclass
class TrainRecord:
    """A (query, positive[, negative]) training example.

    Constructed with keyword arguments throughout this module, which requires the
    dataclass-generated ``__init__``; ``negative`` defaults to ``None``.
    """

    query: "PairItem"
    positive: "PairItem"
    # Optional hard negative; None when the manifest row has no negative.
    negative: Optional["PairItem"] = None
| def _parse_item(obj: Any, prefix: str) -> PairItem: | |
| if isinstance(obj, dict): | |
| modality = obj.get("type") | |
| value = obj.get("value") | |
| else: | |
| modality = None | |
| value = None | |
| if not modality or not value: | |
| raise ValueError(f"{prefix} must include type/value") | |
| if modality not in SUPPORTED_MODALITIES: | |
| raise ValueError(f"Unsupported modality '{modality}' in {prefix}") | |
| return PairItem(modality=modality, value=value) | |
def parse_record(raw: Dict[str, Any]) -> TrainRecord:
    """Turn one decoded manifest line into a TrainRecord.

    The native schema (``query``/``positive`` objects, optional ``negative``) is tried
    first; otherwise a few legacy pair layouts are recognised, checked in order:
    ``texts_a``/``texts_b`` (text-text), ``image_path``/``caption`` (image-text) and
    ``audio_path``/``caption`` (audio-text). Legacy layouts never carry a negative.

    Raises:
        ValueError: If no schema matches, or a native item fails validation.
    """
    if "query" in raw and "positive" in raw:
        query = _parse_item(raw["query"], "query")
        positive = _parse_item(raw["positive"], "positive")
        negative_raw = raw.get("negative")
        negative = _parse_item(negative_raw, "negative") if negative_raw else None
        return TrainRecord(query=query, positive=positive, negative=negative)

    # Compatibility with common pair formats in existing repos:
    # (query key, positive key, query modality, positive modality)
    legacy_layouts = (
        ("texts_a", "texts_b", "text", "text"),
        ("image_path", "caption", "image", "text"),
        ("audio_path", "caption", "audio", "text"),
    )
    for query_key, positive_key, query_modality, positive_modality in legacy_layouts:
        if query_key in raw and positive_key in raw:
            return TrainRecord(
                query=PairItem(query_modality, raw[query_key]),
                positive=PairItem(positive_modality, raw[positive_key]),
            )

    raise ValueError("Record does not match supported schemas")
class JsonlManifestDataset:
    """Eagerly loaded, map-style view of a JSONL manifest.

    Every record is parsed up front with ``iter_manifest_records`` and stored in
    ``self.records``; indexing returns the stored TrainRecord unchanged.

    Raises:
        ValueError: If the manifest yields no records at all.
    """

    def __init__(
        self,
        manifest_path: str,
        image_root: Optional[str] = None,
        audio_root: Optional[str] = None,
        allow_missing_negative: bool = True,
    ) -> None:
        self.manifest_path = manifest_path
        self.image_root = image_root
        self.audio_root = audio_root
        self.allow_missing_negative = allow_missing_negative
        record_stream = iter_manifest_records(
            manifest_path=manifest_path,
            image_root=image_root,
            audio_root=audio_root,
            allow_missing_negative=allow_missing_negative,
        )
        self.records = list(record_stream)
        if len(self.records) == 0:
            raise ValueError(f"No records loaded from {self.manifest_path}")

    def __len__(self) -> int:
        """Number of parsed records."""
        return len(self.records)

    def __getitem__(self, idx: int) -> TrainRecord:
        """Return the record at ``idx`` (plain list indexing semantics)."""
        return self.records[idx]
class CachedShardDataset:
    """Map-style dataset over pre-serialized ``shard_*.pt`` files in ``cache_dir``.

    Each shard file holds a dict whose ``"records"`` entry is a list of raw record
    dicts. A global index is mapped to ``(shard, offset)`` through cumulative shard
    sizes. Decoded shards live in a small LRU cache (``shard_cache_limit`` entries).
    When ``prefetch_shards > 0``, a background thread loads the shards that follow
    the one being read.

    NOTE(review): shards are read with ``torch.load(..., weights_only=False)``, which
    unpickles arbitrary objects; only point this class at cache directories you trust.
    """

    def __init__(self, cache_dir: str, shard_cache_limit: int = 2, prefetch_shards: int = 0) -> None:
        self.cache_dir = cache_dir
        self.shard_cache_limit = max(int(shard_cache_limit), 1)
        self.prefetch_shards = max(int(prefetch_shards), 0)
        self.metadata = self._load_metadata()
        self.shard_files = self._discover_shards()
        self.shard_sizes = self._resolve_shard_sizes()
        # shard_offsets[i] is the cumulative row count through shard i (its exclusive end).
        self.shard_offsets = self._build_offsets(self.shard_sizes)
        self.total_rows = sum(self.shard_sizes)
        self._shard_cache: OrderedDict[int, List[Dict[str, Any]]] = OrderedDict()
        self._init_runtime_state()

    def _init_runtime_state(self) -> None:
        """(Re)create unpicklable runtime state (lock, event, thread, queue) and reset counters."""
        self._cache_lock = threading.Lock()
        self._prefetch_queue = None
        self._prefetch_thread = None
        self._prefetch_stop = threading.Event()
        self._prefetch_requested: set[int] = set()
        self._prefetch_hits = 0
        self._prefetch_misses = 0

    def __getstate__(self):
        """Pickle support: keep a copy of the shard cache, drop threading primitives."""
        state = self.__dict__.copy()
        state["_shard_cache"] = OrderedDict(state.get("_shard_cache", OrderedDict()))
        state["_cache_lock"] = None
        state["_prefetch_queue"] = None
        state["_prefetch_thread"] = None
        state["_prefetch_stop"] = None
        state["_prefetch_requested"] = set()
        return state

    def __setstate__(self, state):
        """Restore pickled attributes and rebuild runtime state (counters restart at 0)."""
        self.__dict__.update(state)
        self._shard_cache = OrderedDict(self._shard_cache)
        self._init_runtime_state()

    def _load_metadata(self) -> Dict[str, Any]:
        """Return ``cache_dir/metadata.json`` as a dict, or ``{}`` when it does not exist."""
        metadata_path = os.path.join(self.cache_dir, "metadata.json")
        if not os.path.exists(metadata_path):
            return {}
        with open(metadata_path, "r", encoding="utf-8") as handle:
            return json.load(handle)

    def _discover_shards(self) -> List[str]:
        """Return the sorted paths of ``shard_*.pt`` files in ``cache_dir``.

        Raises:
            FileNotFoundError: If ``cache_dir`` is not a directory.
            ValueError: If no shard files are present.
        """
        if not os.path.isdir(self.cache_dir):
            raise FileNotFoundError(f"Cache directory not found: {self.cache_dir}")
        shards: List[str] = []
        for name in sorted(os.listdir(self.cache_dir)):
            if not (name.startswith("shard_") and name.endswith(".pt")):
                continue
            shard_path = os.path.join(self.cache_dir, name)
            shards.append(shard_path)
        if not shards:
            raise ValueError(f"No cache shards found under {self.cache_dir}")
        return shards

    @staticmethod
    def _build_offsets(shard_sizes: List[int]) -> List[int]:
        """Return the running totals of ``shard_sizes`` (e.g. ``[2, 3] -> [2, 5]``).

        Must be a staticmethod: it is invoked as ``self._build_offsets(sizes)``.
        """
        offsets: List[int] = []
        running_total = 0
        for shard_size in shard_sizes:
            running_total += shard_size
            offsets.append(running_total)
        return offsets

    def _resolve_shard_sizes(self) -> List[int]:
        """Return the row count of every shard.

        If metadata.json agrees with the files on disk (same ``num_shards``, positive
        integer ``num_records`` and ``shard_size``), every shard but the last is taken
        to hold ``shard_size`` rows and the last holds the remainder. Otherwise each
        shard is loaded and its records counted.

        Raises:
            ValueError: If the metadata implies a non-positive last shard, or a shard
                has no ``records`` list.
        """
        num_shards = len(self.shard_files)
        metadata_num_shards = self.metadata.get("num_shards")
        metadata_num_records = self.metadata.get("num_records")
        shard_size = self.metadata.get("shard_size")
        if (
            isinstance(metadata_num_shards, int)
            and isinstance(metadata_num_records, int)
            and isinstance(shard_size, int)
            and metadata_num_shards == num_shards
            and metadata_num_records > 0
            and shard_size > 0
        ):
            shard_sizes = [shard_size] * num_shards
            full_rows_before_last = shard_size * max(num_shards - 1, 0)
            shard_sizes[-1] = metadata_num_records - full_rows_before_last
            if shard_sizes[-1] <= 0:
                raise ValueError(f"Invalid metadata in {self.cache_dir}: last shard size computed as {shard_sizes[-1]}")
            return shard_sizes
        shard_sizes: List[int] = []
        for shard_path in self.shard_files:
            payload = torch.load(shard_path, map_location="cpu", weights_only=False)
            records = payload.get("records")
            if not isinstance(records, list):
                raise ValueError(f"Invalid shard format in {shard_path}")
            shard_sizes.append(len(records))
        return shard_sizes

    def _store_shard(self, shard_idx: int, records: List[Dict[str, Any]]) -> None:
        """Insert a shard as most-recently-used and evict least-recently-used ones over the limit."""
        with self._cache_lock:
            self._shard_cache[shard_idx] = records
            self._shard_cache.move_to_end(shard_idx)
            while len(self._shard_cache) > self.shard_cache_limit:
                self._shard_cache.popitem(last=False)

    def _ensure_prefetch_thread(self) -> None:
        """Start the daemon prefetch worker (with a fresh queue) if prefetching is on and none is alive."""
        if self.prefetch_shards <= 0:
            return
        if self._prefetch_thread is not None and self._prefetch_thread.is_alive():
            return
        self._prefetch_stop.clear()
        # A new queue is about to replace the old one, so indices still marked as
        # requested (e.g. queued for a worker that died on a load error) would never
        # be served and could never be re-requested; forget them.
        with self._cache_lock:
            self._prefetch_requested.clear()
        self._prefetch_queue = queue.Queue(maxsize=max(self.prefetch_shards * 2, 1))
        self._prefetch_thread = threading.Thread(
            target=self._prefetch_worker,
            daemon=True,
            name=f"cached-shard-prefetch-{os.getpid()}",
        )
        self._prefetch_thread.start()

    def _prefetch_worker(self) -> None:
        """Background loop: load queued shard indices into the cache until stopped.

        ``_prefetch_hits`` is incremented both when a queued shard is already cached
        and when the worker finishes loading one.
        """
        while not self._prefetch_stop.is_set():
            try:
                # Short timeout so the stop event is re-checked regularly.
                shard_idx = self._prefetch_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            if shard_idx is None:
                continue
            try:
                with self._cache_lock:
                    if shard_idx in self._shard_cache:
                        self._prefetch_hits += 1
                        continue
                payload = torch.load(self.shard_files[shard_idx], map_location="cpu", weights_only=False)
                records = payload["records"]
                self._store_shard(shard_idx, records)
                self._prefetch_hits += 1
            finally:
                with self._cache_lock:
                    self._prefetch_requested.discard(shard_idx)

    def _schedule_prefetch(self, shard_idx: int) -> None:
        """Queue the next ``prefetch_shards`` shard indices after ``shard_idx`` that are neither cached nor pending."""
        if self.prefetch_shards <= 0:
            return
        self._ensure_prefetch_thread()
        if self._prefetch_queue is None:
            return
        for next_idx in range(shard_idx + 1, min(len(self.shard_files), shard_idx + 1 + self.prefetch_shards)):
            with self._cache_lock:
                if next_idx in self._shard_cache or next_idx in self._prefetch_requested:
                    continue
                self._prefetch_requested.add(next_idx)
            try:
                self._prefetch_queue.put_nowait(next_idx)
            except queue.Full:
                # Back off when the worker is saturated; later accesses will retry.
                with self._cache_lock:
                    self._prefetch_requested.discard(next_idx)
                break

    def _load_shard(self, shard_idx: int) -> List[Dict[str, Any]]:
        """Return a shard's raw records from the cache or disk, then schedule prefetch of following shards."""
        cached = None
        with self._cache_lock:
            cached = self._shard_cache.get(shard_idx)
            if cached is not None:
                self._shard_cache.move_to_end(shard_idx)
        if cached is not None:
            self._schedule_prefetch(shard_idx)
            return cached
        self._prefetch_misses += 1
        payload = torch.load(self.shard_files[shard_idx], map_location="cpu", weights_only=False)
        records = payload["records"]
        self._store_shard(shard_idx, records)
        with self._cache_lock:
            self._prefetch_requested.discard(shard_idx)
        self._schedule_prefetch(shard_idx)
        return records

    @staticmethod
    def _deserialize_item(raw: Optional[Dict[str, Any]]) -> Optional["PairItem"]:
        """Rebuild a PairItem from a cached item dict, or return None for a missing item.

        Text prefers pre-computed ``tokens`` over ``value``; other modalities prefer a
        cached ``tensor`` and fall back to ``value``. Must be a staticmethod: it is
        invoked as ``self._deserialize_item(raw)``.
        """
        if raw is None:
            return None
        modality = raw["type"]
        if modality == "text" and "tokens" in raw:
            value = raw["tokens"]
        elif modality == "text":
            value = raw["value"]
        elif "tensor" in raw:
            value = raw["tensor"]
        else:
            value = raw.get("value")
        return PairItem(modality=modality, value=value)

    def __len__(self) -> int:
        return self.total_rows

    def __getitem__(self, idx: int) -> "TrainRecord":
        """Return the record at global index ``idx`` (negative indices are rejected).

        Raises:
            IndexError: If ``idx`` is outside ``[0, total_rows)``.
        """
        if idx < 0 or idx >= self.total_rows:
            raise IndexError(idx)
        # shard_offsets holds exclusive ends, so bisect_right finds the owning shard.
        shard_idx = bisect_right(self.shard_offsets, idx)
        shard_start = 0 if shard_idx == 0 else self.shard_offsets[shard_idx - 1]
        local_idx = idx - shard_start
        raw = self._load_shard(shard_idx)[local_idx]
        return TrainRecord(
            query=self._deserialize_item(raw["query"]),
            positive=self._deserialize_item(raw["positive"]),
            negative=self._deserialize_item(raw.get("negative")),
        )

    def get_prefetch_stats(self) -> Dict[str, int]:
        """Return a snapshot of cache / prefetch counters."""
        with self._cache_lock:
            return {
                "cache_size": len(self._shard_cache),
                "cache_limit": self.shard_cache_limit,
                "prefetch_shards": self.prefetch_shards,
                "prefetch_hits": self._prefetch_hits,
                "prefetch_misses": self._prefetch_misses,
                "prefetch_pending": len(self._prefetch_requested),
            }

    def close(self) -> None:
        """Stop the prefetch worker (waiting up to 1s) and drop its queue.

        Safe on partially-constructed instances (e.g. when ``__init__`` raised before
        the runtime state existed), since ``__del__`` calls this unconditionally.
        """
        stop_event = getattr(self, "_prefetch_stop", None)
        if stop_event is not None:
            stop_event.set()
        thread = getattr(self, "_prefetch_thread", None)
        if thread is not None and thread.is_alive():
            thread.join(timeout=1.0)
        self._prefetch_thread = None
        self._prefetch_queue = None

    def __del__(self):
        self.close()
class SequentialShardDataset:
    """Per-epoch, shard-at-a-time dataset over ``shard_*.pt`` files in ``cache_dir``.

    Epoch lifecycle:
    - ``reset(epoch)`` optionally shuffles the shard order with seed ``42 + epoch``.
    - The order is split round-robin across ``world_size`` ranks, and every rank's
      list is padded (by repetition) to the same number of shards.
    - ``reset`` then loads this rank's first shard; ``next_shard()`` advances.

    While a shard is active, ``__len__``/``__getitem__`` index into it. The shard is
    padded by repetition to ``target_shard_size`` rows. Upcoming shards can be loaded
    by a background thread (``prefetch_shards``) into an LRU cache
    (``shard_cache_limit``).

    NOTE(review): shards are read with ``torch.load(..., weights_only=False)``, which
    unpickles arbitrary objects; only point this class at cache directories you trust.
    """

    def __init__(
        self,
        cache_dir: str,
        shuffle: bool = True,
        rank: int = 0,
        world_size: int = 1,
        prefetch_shards: int = 2,
        shard_cache_limit: int = 4,
    ) -> None:
        self.cache_dir = cache_dir
        self.shuffle = shuffle
        self.rank = rank
        self.world_size = max(world_size, 1)
        self.prefetch_shards = max(int(prefetch_shards), 0)
        self.shard_cache_limit = max(int(shard_cache_limit), 1)
        self.metadata = self._load_metadata()
        self.shard_files = self._discover_shards()
        self.shard_sizes = self._resolve_shard_sizes()
        self.total_rows = sum(self.shard_sizes)
        self.target_shard_size = int(self.metadata.get("shard_size") or max(self.shard_sizes))
        self._shard_cache: OrderedDict[int, List[Dict[str, Any]]] = OrderedDict()
        self._init_runtime_state()
        self._all_shard_indices = list(range(len(self.shard_files)))
        self._local_shard_indices: List[int] = []
        self.current_local_shard_pos = -1
        self.current_records: Optional[List[Dict[str, Any]]] = None

    def _init_runtime_state(self) -> None:
        """(Re)create unpicklable runtime state (lock, event, thread, queue) and reset counters."""
        self._cache_lock = threading.Lock()
        self._prefetch_queue = None
        self._prefetch_thread = None
        self._prefetch_stop = threading.Event()
        self._prefetch_requested: set[int] = set()
        self._prefetch_hits = 0
        self._prefetch_misses = 0

    def __getstate__(self):
        """Pickle support (e.g. for spawned DataLoader workers): drop threading primitives."""
        state = self.__dict__.copy()
        state["_shard_cache"] = OrderedDict(state.get("_shard_cache", OrderedDict()))
        state["_cache_lock"] = None
        state["_prefetch_queue"] = None
        state["_prefetch_thread"] = None
        state["_prefetch_stop"] = None
        state["_prefetch_requested"] = set()
        return state

    def __setstate__(self, state):
        """Restore pickled attributes and rebuild runtime state (counters restart at 0)."""
        self.__dict__.update(state)
        self._shard_cache = OrderedDict(self._shard_cache)
        self._init_runtime_state()

    def _load_metadata(self) -> Dict[str, Any]:
        """Return ``cache_dir/metadata.json`` as a dict, or ``{}`` when it does not exist."""
        metadata_path = os.path.join(self.cache_dir, "metadata.json")
        if not os.path.exists(metadata_path):
            return {}
        with open(metadata_path, "r", encoding="utf-8") as handle:
            return json.load(handle)

    def _discover_shards(self) -> List[str]:
        """Return the sorted paths of ``shard_*.pt`` files in ``cache_dir``.

        Raises:
            FileNotFoundError: If ``cache_dir`` is not a directory.
            ValueError: If no shard files are present.
        """
        if not os.path.isdir(self.cache_dir):
            raise FileNotFoundError(f"Cache directory not found: {self.cache_dir}")
        shards: List[str] = []
        for name in sorted(os.listdir(self.cache_dir)):
            if not (name.startswith("shard_") and name.endswith(".pt")):
                continue
            shards.append(os.path.join(self.cache_dir, name))
        if not shards:
            raise ValueError(f"No cache shards found under {self.cache_dir}")
        return shards

    def _resolve_shard_sizes(self) -> List[int]:
        """Return the row count of every shard.

        If metadata.json agrees with the files on disk (same ``num_shards``, positive
        integer ``num_records`` and ``shard_size``), every shard but the last is taken
        to hold ``shard_size`` rows and the last holds the remainder. Otherwise each
        shard is loaded and its records counted.

        Raises:
            ValueError: If the metadata implies a non-positive last shard (same check
                as CachedShardDataset), or a shard has no ``records`` list.
        """
        num_shards = len(self.shard_files)
        metadata_num_shards = self.metadata.get("num_shards")
        metadata_num_records = self.metadata.get("num_records")
        shard_size = self.metadata.get("shard_size")
        if (
            isinstance(metadata_num_shards, int)
            and isinstance(metadata_num_records, int)
            and isinstance(shard_size, int)
            and metadata_num_shards == num_shards
            and metadata_num_records > 0
            and shard_size > 0
        ):
            shard_sizes = [shard_size] * num_shards
            shard_sizes[-1] = metadata_num_records - shard_size * max(num_shards - 1, 0)
            if shard_sizes[-1] <= 0:
                raise ValueError(f"Invalid metadata in {self.cache_dir}: last shard size computed as {shard_sizes[-1]}")
            return shard_sizes
        shard_sizes: List[int] = []
        for shard_path in self.shard_files:
            payload = torch.load(shard_path, map_location="cpu", weights_only=False)
            records = payload.get("records")
            if not isinstance(records, list):
                raise ValueError(f"Invalid shard format in {shard_path}")
            shard_sizes.append(len(records))
        return shard_sizes

    @staticmethod
    def _deserialize_item(raw: Optional[Dict[str, Any]]) -> Optional["PairItem"]:
        """Rebuild a PairItem from a cached item dict, or return None for a missing item.

        Text prefers pre-computed ``tokens`` over ``value``; other modalities prefer a
        cached ``tensor`` and fall back to ``value``. Must be a staticmethod: it is
        invoked as ``self._deserialize_item(raw)``.
        """
        if raw is None:
            return None
        modality = raw["type"]
        if modality == "text" and "tokens" in raw:
            value = raw["tokens"]
        elif modality == "text":
            value = raw["value"]
        elif "tensor" in raw:
            value = raw["tensor"]
        else:
            value = raw.get("value")
        return PairItem(modality=modality, value=value)

    def _store_shard(self, shard_idx: int, records: List[Dict[str, Any]]) -> None:
        """Insert a shard as most-recently-used and evict least-recently-used ones over the limit."""
        with self._cache_lock:
            self._shard_cache[shard_idx] = records
            self._shard_cache.move_to_end(shard_idx)
            while len(self._shard_cache) > self.shard_cache_limit:
                self._shard_cache.popitem(last=False)

    def _ensure_prefetch_thread(self) -> None:
        """Start the daemon prefetch worker (with a fresh queue) if prefetching is on and none is alive."""
        if self.prefetch_shards <= 0:
            return
        if self._prefetch_thread is not None and self._prefetch_thread.is_alive():
            return
        self._prefetch_stop.clear()
        # A new queue replaces the old one, so stale "requested" marks (e.g. left by a
        # worker that died on a load error) would block those shards from being retried.
        with self._cache_lock:
            self._prefetch_requested.clear()
        self._prefetch_queue = queue.Queue(maxsize=max(self.prefetch_shards * 2, 1))
        self._prefetch_thread = threading.Thread(
            target=self._prefetch_worker,
            daemon=True,
            name=f"sequential-shard-prefetch-{os.getpid()}",
        )
        self._prefetch_thread.start()

    def _prefetch_worker(self) -> None:
        """Background loop: load queued shard indices into the cache until stopped."""
        while not self._prefetch_stop.is_set():
            try:
                # Short timeout so the stop event is re-checked regularly.
                shard_idx = self._prefetch_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            if shard_idx is None:
                continue
            try:
                with self._cache_lock:
                    if shard_idx in self._shard_cache:
                        self._prefetch_hits += 1
                        continue
                payload = torch.load(self.shard_files[shard_idx], map_location="cpu", weights_only=False)
                self._store_shard(shard_idx, payload["records"])
                self._prefetch_hits += 1
            finally:
                with self._cache_lock:
                    self._prefetch_requested.discard(shard_idx)

    def _stop_prefetch_thread(self) -> None:
        """Signal the worker to stop, wait up to 1s, and drop thread/queue references.

        Tolerates partially-constructed instances, since ``__del__`` reaches here via
        ``close()`` even when ``__init__`` raised early.
        """
        stop_event = getattr(self, "_prefetch_stop", None)
        if stop_event is not None:
            stop_event.set()
        thread = getattr(self, "_prefetch_thread", None)
        if thread is not None and thread.is_alive():
            thread.join(timeout=1.0)
        self._prefetch_thread = None
        self._prefetch_queue = None

    def _schedule_prefetch_from_position(self, local_pos: int) -> None:
        """Queue the next ``prefetch_shards`` local shards after ``local_pos`` that are neither cached nor pending."""
        if self.prefetch_shards <= 0:
            return
        self._ensure_prefetch_thread()
        if self._prefetch_queue is None:
            return
        for next_pos in range(local_pos + 1, min(len(self._local_shard_indices), local_pos + 1 + self.prefetch_shards)):
            shard_idx = self._local_shard_indices[next_pos]
            with self._cache_lock:
                if shard_idx in self._shard_cache or shard_idx in self._prefetch_requested:
                    continue
                self._prefetch_requested.add(shard_idx)
            try:
                self._prefetch_queue.put_nowait(shard_idx)
            except queue.Full:
                with self._cache_lock:
                    self._prefetch_requested.discard(shard_idx)
                break

    def _build_local_shard_order(self, epoch: int) -> List[int]:
        """Return this rank's shard indices for ``epoch``.

        The global order is optionally shuffled with seed ``42 + epoch``, then split
        round-robin by rank. Each rank's list is padded by repetition to
        ``ceil(num_shards / world_size)`` entries.

        Raises:
            ValueError: If this rank receives no shards.
        """
        shard_indices = list(self._all_shard_indices)
        if self.shuffle:
            random.Random(42 + epoch).shuffle(shard_indices)
        local_shards = shard_indices[self.rank::self.world_size]
        max_shards = math.ceil(len(shard_indices) / self.world_size)
        if not local_shards:
            raise ValueError(f"Rank {self.rank} received no shards from {self.cache_dir}")
        while len(local_shards) < max_shards:
            local_shards.append(local_shards[len(local_shards) % len(local_shards)])
        return local_shards

    def _load_records_for_shard(self, shard_idx: int) -> List[Dict[str, Any]]:
        """Return a shard's raw records from the LRU cache, loading (and caching) it on a miss."""
        cached = None
        with self._cache_lock:
            cached = self._shard_cache.get(shard_idx)
            if cached is not None:
                self._shard_cache.move_to_end(shard_idx)
        if cached is not None:
            return cached
        self._prefetch_misses += 1
        payload = torch.load(self.shard_files[shard_idx], map_location="cpu", weights_only=False)
        records = payload["records"]
        self._store_shard(shard_idx, records)
        with self._cache_lock:
            self._prefetch_requested.discard(shard_idx)
        return records

    def _pad_records(self, records: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Repeat ``records`` cyclically up to ``target_shard_size`` rows (longer lists are returned as-is).

        Raises:
            ValueError: If ``records`` is empty but padding is required (an empty
                shard cannot be repeated; previously this was a ZeroDivisionError).
        """
        if len(records) >= self.target_shard_size:
            return records
        if not records:
            raise ValueError(f"Cannot pad an empty shard to {self.target_shard_size} rows")
        repeat = math.ceil(self.target_shard_size / len(records))
        return (records * repeat)[: self.target_shard_size]

    def reset(self, epoch: int) -> bool:
        """Start ``epoch``: rebuild this rank's shard order and load its first shard.

        Returns:
            The result of ``next_shard()`` (True when a shard was loaded).
        """
        self._stop_prefetch_thread()
        self._local_shard_indices = self._build_local_shard_order(epoch)
        self.current_local_shard_pos = -1
        self.current_records = None
        with self._cache_lock:
            self._prefetch_requested.clear()
        if self.prefetch_shards > 0:
            self._ensure_prefetch_thread()
        return self.next_shard()

    def next_shard(self) -> bool:
        """Advance to the next local shard; return False (and clear records) when the epoch is exhausted."""
        self.current_local_shard_pos += 1
        if self.current_local_shard_pos >= len(self._local_shard_indices):
            self.current_records = None
            return False
        shard_idx = self._local_shard_indices[self.current_local_shard_pos]
        records = self._load_records_for_shard(shard_idx)
        self.current_records = self._pad_records(records)
        self._schedule_prefetch_from_position(self.current_local_shard_pos)
        return True

    def __len__(self) -> int:
        """Number of (padded) rows in the current shard; 0 when no shard is active."""
        return len(self.current_records or [])

    def __getitem__(self, idx: int) -> "TrainRecord":
        """Return row ``idx`` of the current shard.

        Raises:
            IndexError: If no shard is active or ``idx`` is out of range.
        """
        if self.current_records is None:
            raise IndexError(idx)
        raw = self.current_records[idx]
        return TrainRecord(
            query=self._deserialize_item(raw["query"]),
            positive=self._deserialize_item(raw["positive"]),
            negative=self._deserialize_item(raw.get("negative")),
        )

    def estimated_num_batches(self, batch_size: int, drop_last: bool) -> int:
        """Batches per epoch for this rank: per-shard batches times the local shard count (epoch 0 order)."""
        shard_batches = self.target_shard_size // batch_size if drop_last else math.ceil(self.target_shard_size / batch_size)
        return shard_batches * max(len(self._build_local_shard_order(0)), 1)

    def get_prefetch_stats(self) -> Dict[str, int]:
        """Return a snapshot of cache / prefetch counters and epoch layout."""
        with self._cache_lock:
            return {
                "cache_size": len(self._shard_cache),
                "cache_limit": self.shard_cache_limit,
                "prefetch_shards": self.prefetch_shards,
                "prefetch_hits": self._prefetch_hits,
                "prefetch_misses": self._prefetch_misses,
                "prefetch_pending": len(self._prefetch_requested),
                "local_shards": len(self._local_shard_indices),
                "target_shard_size": self.target_shard_size,
            }

    def close(self) -> None:
        """Stop the prefetch worker, if any."""
        self._stop_prefetch_thread()

    def __del__(self):
        self.close()
| def _process_shard() -> tuple[int, int]: | |
| rank = int(os.environ.get("ACCELERATE_PROCESS_INDEX") or os.environ.get("RANK") or 0) | |
| world_size = int(os.environ.get("WORLD_SIZE") or os.environ.get("ACCELERATE_NUM_PROCESSES") or 1) | |
| worker_info = torch.utils.data.get_worker_info() | |
| if worker_info is None: | |
| return rank, max(world_size, 1) | |
| total_shards = max(world_size, 1) * worker_info.num_workers | |
| shard_id = rank * worker_info.num_workers + worker_info.id | |
| return shard_id, max(total_shards, 1) | |
def iter_sentence_transformers_rows(
    manifest_path: str,
    image_root: Optional[str],
    audio_root: Optional[str],
    allow_missing_negative: bool,
    allowed_modalities: Optional[List[str]],
    query_modalities: Optional[List[str]],
    positive_modalities: Optional[List[str]],
    negative_modalities: Optional[List[str]],
    use_negative_column: bool,
):
    """Yield sentence-transformers rows for this process/worker's slice of a manifest.

    Records failing the modality filters are dropped first. The survivors are dealt
    round-robin over ``total_shards`` (see ``_process_shard``), and only this shard's
    share is converted with ``record_to_sentence_transformers_row``.
    """
    filters = {
        "allowed": set(allowed_modalities or []),
        "allowed_query": set(query_modalities or []),
        "allowed_positive": set(positive_modalities or []),
        "allowed_negative": set(negative_modalities or []),
    }
    shard_id, total_shards = _process_shard()
    records = iter_manifest_records(
        manifest_path=manifest_path,
        image_root=image_root,
        audio_root=audio_root,
        allow_missing_negative=allow_missing_negative,
    )
    matching = (record for record in records if record_matches_filters(record, **filters))
    for position, record in enumerate(matching):
        if position % total_shards == shard_id:
            yield record_to_sentence_transformers_row(record, include_negative=use_negative_column)
def collate_records(batch: List[TrainRecord]) -> Dict[str, List[PairItem]]:
    """Transpose a batch of TrainRecords into per-field lists.

    Keys are ``query``, ``positive`` and ``negative``; ``negative`` entries may be None.
    """
    field_names = ("query", "positive", "negative")
    return {name: [getattr(record, name) for record in batch] for name in field_names}
def sentence_transformers_input(item: PairItem) -> Any:
    """Wrap an item's value in a single-key dict named after its modality.

    ``text``, ``image`` and ``audio`` items become ``{"<modality>": value}``; any
    other modality is passed through as the bare value.
    """
    for key in ("text", "image", "audio"):
        if item.modality == key:
            return {key: item.value}
    return item.value
def resolve_media(item: PairItem, image_root: Optional[str], audio_root: Optional[str]) -> PairItem:
    """Anchor a relative image/audio path at its media root.

    A new PairItem is returned only when the item is image/audio, the matching root
    is truthy and the path is relative. Otherwise the same item object is returned.
    """
    if item.modality == "image":
        root = image_root
    elif item.modality == "audio":
        root = audio_root
    else:
        return item
    if not root or os.path.isabs(item.value):
        return item
    return PairItem(item.modality, os.path.join(root, item.value))
def iter_manifest_records(
    manifest_path: str,
    image_root: Optional[str] = None,
    audio_root: Optional[str] = None,
    allow_missing_negative: bool = True,
) -> Iterable["TrainRecord"]:
    """Lazily parse a JSONL manifest into TrainRecords.

    Blank lines are skipped. Each remaining line is decoded as JSON and parsed with
    ``parse_record``. Image/audio paths are then resolved against ``image_root`` /
    ``audio_root`` via ``resolve_media``.

    Raises (on iteration, since this is a generator):
        FileNotFoundError: If ``manifest_path`` does not exist.
        ValueError: If a line is not valid JSON, is not a JSON object, or does not
            match a supported schema (message names the manifest and 1-based line);
            or if a negative is missing while ``allow_missing_negative`` is False.
    """
    if not os.path.exists(manifest_path):
        raise FileNotFoundError(f"Manifest not found: {manifest_path}")
    with open(manifest_path, "r", encoding="utf-8") as handle:
        for line_no, line in enumerate(handle, start=1):
            line = line.strip()
            if not line:
                continue
            # Attach file/line context: a bare JSON or schema error is useless when
            # the manifest has millions of lines.
            try:
                raw = json.loads(line)
            except json.JSONDecodeError as err:
                raise ValueError(f"Invalid JSON in {manifest_path} at line {line_no}: {err}") from err
            if not isinstance(raw, dict):
                raise ValueError(f"Expected a JSON object in {manifest_path} at line {line_no}")
            try:
                record = parse_record(raw)
            except ValueError as err:
                raise ValueError(f"{manifest_path} line {line_no}: {err}") from err
            record = TrainRecord(
                query=resolve_media(record.query, image_root, audio_root),
                positive=resolve_media(record.positive, image_root, audio_root),
                negative=resolve_media(record.negative, image_root, audio_root) if record.negative else None,
            )
            if record.negative is None and not allow_missing_negative:
                raise ValueError(f"Missing negative at line {line_no}")
            yield record
def record_matches_filters(
    record: TrainRecord,
    allowed: set[str],
    allowed_query: set[str],
    allowed_positive: set[str],
    allowed_negative: set[str],
) -> bool:
    """Return True when ``record`` passes every non-empty modality filter.

    An empty filter set imposes no constraint. ``allowed`` must contain every
    modality present in the record. The role-specific sets constrain only their own
    item, and ``allowed_negative`` is ignored for records without a negative.
    """
    negative = record.negative
    present = {record.query.modality, record.positive.modality}
    if negative is not None:
        present.add(negative.modality)
    return (
        (not allowed or present.issubset(allowed))
        and (not allowed_query or record.query.modality in allowed_query)
        and (not allowed_positive or record.positive.modality in allowed_positive)
        and (negative is None or not allowed_negative or negative.modality in allowed_negative)
    )
def record_to_sentence_transformers_row(record: TrainRecord, include_negative: bool) -> Dict[str, Any]:
    """Build a sentence-transformers row dict from a TrainRecord.

    Always has ``query`` and ``positive`` columns. ``negative_0`` is added only when
    ``include_negative`` is set and the record actually has a negative.
    """
    columns = [("query", record.query), ("positive", record.positive)]
    if include_negative and record.negative is not None:
        columns.append(("negative_0", record.negative))
    return {column: sentence_transformers_input(item) for column, item in columns}
def summarize_manifest_records(
    manifest_path: str,
    image_root: Optional[str] = None,
    audio_root: Optional[str] = None,
    allow_missing_negative: bool = True,
    allowed_modalities: Optional[List[str]] = None,
    query_modalities: Optional[List[str]] = None,
    positive_modalities: Optional[List[str]] = None,
    negative_modalities: Optional[List[str]] = None,
    max_records: Optional[int] = None,
) -> Dict[str, Any]:
    """Scan a manifest and describe the records that pass the modality filters.

    Records rejected by ``record_matches_filters`` are counted as skipped. The scan
    stops once ``max_records`` matching records have been kept.

    Returns:
        A dict with these keys:
        - ``modalities``: sorted modalities seen in kept records.
        - ``num_rows``: number of kept records.
        - ``has_uniform_negatives``: True only if at least one kept record has a
          negative and none lacks one.
        - ``num_negatives_present``, ``num_negatives_missing``, ``skipped_rows``.

    Raises:
        ValueError: If no record passes the filters.
    """
    filters = {
        "allowed": set(allowed_modalities or []),
        "allowed_query": set(query_modalities or []),
        "allowed_positive": set(positive_modalities or []),
        "allowed_negative": set(negative_modalities or []),
    }
    seen_modalities = set()
    with_negative = 0
    without_negative = 0
    skipped = 0
    kept = 0
    records = iter_manifest_records(
        manifest_path=manifest_path,
        image_root=image_root,
        audio_root=audio_root,
        allow_missing_negative=allow_missing_negative,
    )
    for record in records:
        if not record_matches_filters(record, **filters):
            skipped += 1
            continue
        seen_modalities.update((record.query.modality, record.positive.modality))
        if record.negative is None:
            without_negative += 1
        else:
            seen_modalities.add(record.negative.modality)
            with_negative += 1
        kept += 1
        if max_records is not None and kept >= max_records:
            break
    if not kept:
        raise ValueError(f"No records loaded from {manifest_path}")
    return {
        "modalities": sorted(seen_modalities),
        "num_rows": kept,
        "has_uniform_negatives": with_negative > 0 and without_negative == 0,
        "num_negatives_present": with_negative,
        "num_negatives_missing": without_negative,
        "skipped_rows": skipped,
    }
def manifest_to_sentence_transformers_dataset(
    manifest_path: str,
    image_root: Optional[str] = None,
    audio_root: Optional[str] = None,
    allow_missing_negative: bool = True,
    allowed_modalities: Optional[List[str]] = None,
    query_modalities: Optional[List[str]] = None,
    positive_modalities: Optional[List[str]] = None,
    negative_modalities: Optional[List[str]] = None,
    as_iterable: bool = False,
    max_records: Optional[int] = None,
) -> tuple[Dataset | IterableDataset, Dict[str, Any]]:
    """Convert a JSONL manifest into a Hugging Face dataset of sentence-transformers rows.

    The manifest is first scanned with ``summarize_manifest_records`` (same filters and
    ``max_records``); that summary is returned as the second element. Each row has
    ``query`` and ``positive`` columns, plus ``negative_0`` when the summary reports
    that every matching record carries a negative.

    Args:
        manifest_path: Path to the JSONL manifest.
        image_root: Base directory for relative image paths.
        audio_root: Base directory for relative audio paths.
        allow_missing_negative: When False, a record without a negative raises.
        allowed_modalities: Modalities every item of a record must belong to.
        query_modalities: Allowed modalities for the query item.
        positive_modalities: Allowed modalities for the positive item.
        negative_modalities: Allowed modalities for the negative item, if present.
            (For all four filters, None or empty means "no filter".)
        as_iterable: Return a streaming ``IterableDataset`` built from
            ``iter_sentence_transformers_rows`` instead of an in-memory ``Dataset``.
        max_records: Stop after this many matching records (see the NOTE in the
            iterable branch for its limits there).

    Returns:
        ``(dataset, info)`` where ``info`` is the manifest summary dict.

    Raises:
        FileNotFoundError: If the manifest does not exist.
        ValueError: On malformed records, or when no record matches the filters.
    """
    info = summarize_manifest_records(
        manifest_path=manifest_path,
        image_root=image_root,
        audio_root=audio_root,
        allow_missing_negative=allow_missing_negative,
        allowed_modalities=allowed_modalities,
        query_modalities=query_modalities,
        positive_modalities=positive_modalities,
        negative_modalities=negative_modalities,
        max_records=max_records,
    )
    use_negative_column = info["has_uniform_negatives"]
    dataset_out: Dataset | IterableDataset
    if as_iterable:
        column_names = ["query", "positive"]
        if use_negative_column:
            column_names.append("negative_0")
        # NOTE(review): max_records only bounds the summary scan above. The generator
        # below does not receive it and streams every matching record. With
        # max_records set, use_negative_column is derived from a prefix of the data,
        # so later rows without negatives would not fit the declared columns.
        # Confirm before combining as_iterable with max_records.
        dataset_out = IterableDataset.from_generator(
            iter_sentence_transformers_rows,
            features=Features({key: Value("null") for key in column_names}),
            gen_kwargs={
                "manifest_path": manifest_path,
                "image_root": image_root,
                "audio_root": audio_root,
                "allow_missing_negative": allow_missing_negative,
                "allowed_modalities": allowed_modalities,
                "query_modalities": query_modalities,
                "positive_modalities": positive_modalities,
                "negative_modalities": negative_modalities,
                "use_negative_column": use_negative_column,
            },
        )
    else:
        allowed = set(allowed_modalities or [])
        allowed_query = set(query_modalities or [])
        allowed_positive = set(positive_modalities or [])
        allowed_negative = set(negative_modalities or [])
        rows: List[Dict[str, Any]] = []
        # Stream records instead of materializing the whole manifest first. A small
        # max_records then stops reading early, and only the kept rows are held.
        # The summary above already guarantees at least one matching record.
        for record in iter_manifest_records(
            manifest_path=manifest_path,
            image_root=image_root,
            audio_root=audio_root,
            allow_missing_negative=allow_missing_negative,
        ):
            if not record_matches_filters(
                record,
                allowed=allowed,
                allowed_query=allowed_query,
                allowed_positive=allowed_positive,
                allowed_negative=allowed_negative,
            ):
                continue
            rows.append(record_to_sentence_transformers_row(record, include_negative=use_negative_column))
            if max_records is not None and len(rows) >= max_records:
                break
        dataset_out = Dataset.from_list(rows)
    return dataset_out, info