| import json |
| import math |
| import os |
| import queue |
| import random |
| import threading |
| from bisect import bisect_right |
| from collections import OrderedDict |
| from dataclasses import dataclass |
| from typing import Any, Dict, Iterable, List, Optional |
|
|
| import torch |
| from datasets import Dataset, Features, IterableDataset, Value |
|
|
|
|
# Modalities accepted in a manifest item's "type" field; _parse_item rejects
# anything outside this set.
SUPPORTED_MODALITIES = {
    "text",
    "image",
    "audio",
}
|
|
|
|
@dataclass
class PairItem:
    """One side of a training pair: a modality tag plus its payload.

    ``value`` holds whatever the producer supplied for that modality, e.g. raw
    text, a media path, or the ``tokens`` / ``tensor`` entry of a cached shard
    record.
    """

    # Modality tag; validated against SUPPORTED_MODALITIES when parsed from a
    # manifest item.
    modality: str
    # Modality-specific payload.
    value: Any
|
|
|
|
@dataclass
class TrainRecord:
    """A single training example: a query, its positive, and an optional negative."""

    query: PairItem
    positive: PairItem
    # None when the source record carries no negative item.
    negative: Optional[PairItem] = None
|
|
|
|
def _parse_item(obj: Any, prefix: str) -> PairItem:
    """Build a PairItem from a ``{"type": ..., "value": ...}`` manifest entry.

    Args:
        obj: The decoded JSON value for one side of a record.
        prefix: Field name used in error messages (e.g. ``"query"``).

    Returns:
        A PairItem whose modality is one of SUPPORTED_MODALITIES.

    Raises:
        ValueError: If ``obj`` is not a dict, has a falsy ``type`` or ``value``,
            or its ``type`` is not a supported modality string.
    """
    if isinstance(obj, dict):
        modality = obj.get("type")
        value = obj.get("value")
    else:
        modality = None
        value = None

    if not modality or not value:
        raise ValueError(f"{prefix} must include type/value")
    # Check the type first: a non-string "type" such as a JSON list is
    # unhashable and would otherwise escape as a TypeError from the set lookup.
    if not isinstance(modality, str) or modality not in SUPPORTED_MODALITIES:
        raise ValueError(f"Unsupported modality '{modality}' in {prefix}")
    return PairItem(modality=modality, value=value)
|
|
|
|
def parse_record(raw: Dict[str, Any]) -> TrainRecord:
    """Convert one decoded manifest line into a TrainRecord.

    Schemas are tried in this order:

    * ``query`` + ``positive`` (+ optional truthy ``negative``): typed items
      validated by ``_parse_item``.
    * ``texts_a`` + ``texts_b``: text query, text positive.
    * ``image_path`` + ``caption``: image query, text positive.
    * ``audio_path`` + ``caption``: audio query, text positive.

    Raises:
        ValueError: If no schema matches, or a typed item is malformed.
    """
    if "query" in raw and "positive" in raw:
        negative_raw = raw.get("negative")
        return TrainRecord(
            query=_parse_item(raw["query"], "query"),
            positive=_parse_item(raw["positive"], "positive"),
            negative=_parse_item(negative_raw, "negative") if negative_raw else None,
        )

    # (query key, query modality, positive key, positive modality)
    simple_schemas = (
        ("texts_a", "text", "texts_b", "text"),
        ("image_path", "image", "caption", "text"),
        ("audio_path", "audio", "caption", "text"),
    )
    for query_key, query_modality, positive_key, positive_modality in simple_schemas:
        if query_key in raw and positive_key in raw:
            return TrainRecord(
                query=PairItem(query_modality, raw[query_key]),
                positive=PairItem(positive_modality, raw[positive_key]),
            )

    raise ValueError("Record does not match supported schemas")
|
|
|
|
class JsonlManifestDataset:
    """In-memory, map-style dataset holding every TrainRecord of a JSONL manifest.

    All records are read eagerly through ``iter_manifest_records`` (which also
    joins relative media paths onto ``image_root`` / ``audio_root``) and kept
    in ``self.records``.

    Raises:
        ValueError: If the manifest yields no records.
    """

    def __init__(
        self,
        manifest_path: str,
        image_root: Optional[str] = None,
        audio_root: Optional[str] = None,
        allow_missing_negative: bool = True,
    ) -> None:
        self.manifest_path = manifest_path
        self.image_root = image_root
        self.audio_root = audio_root
        self.allow_missing_negative = allow_missing_negative

        record_stream = iter_manifest_records(
            manifest_path=manifest_path,
            image_root=image_root,
            audio_root=audio_root,
            allow_missing_negative=allow_missing_negative,
        )
        self.records = list(record_stream)
        if not self.records:
            raise ValueError(f"No records loaded from {self.manifest_path}")

    def __len__(self) -> int:
        """Return the number of loaded records."""
        return len(self.records)

    def __getitem__(self, idx: int) -> TrainRecord:
        """Return the record at ``idx`` (plain list indexing)."""
        return self.records[idx]
|
|
|
|
class CachedShardDataset:
    """Map-style dataset over pre-serialized ``shard_*.pt`` files in ``cache_dir``.

    Each shard is read with ``torch.load`` and must contain a ``"records"``
    list. A global index is mapped to (shard, local index) through cumulative
    shard sizes. Up to ``shard_cache_limit`` shards stay in an LRU cache. When
    ``prefetch_shards > 0``, a background thread loads the shards that follow
    the one just accessed.

    The prefetch thread's target is a bound method, so a running worker keeps
    this object alive. Call ``close()`` explicitly rather than relying on
    ``__del__``.

    Note: shards are loaded with ``weights_only=False``, which unpickles
    arbitrary objects. Only point ``cache_dir`` at trusted data.
    """

    def __init__(self, cache_dir: str, shard_cache_limit: int = 2, prefetch_shards: int = 0) -> None:
        # Create the threading state first so close()/__del__ still work when
        # shard discovery below raises.
        self._init_runtime_state()
        self.cache_dir = cache_dir
        self.shard_cache_limit = max(int(shard_cache_limit), 1)
        self.prefetch_shards = max(int(prefetch_shards), 0)
        self.metadata = self._load_metadata()
        self.shard_files = self._discover_shards()
        self.shard_sizes = self._resolve_shard_sizes()
        # shard_offsets[i] is the exclusive global end index of shard i.
        self.shard_offsets = self._build_offsets(self.shard_sizes)
        self.total_rows = sum(self.shard_sizes)
        self._shard_cache: OrderedDict[int, List[Dict[str, Any]]] = OrderedDict()

    def _init_runtime_state(self) -> None:
        """(Re)create the unpicklable threading state and reset prefetch counters."""
        self._cache_lock = threading.Lock()
        self._prefetch_queue = None
        self._prefetch_thread = None
        self._prefetch_stop = threading.Event()
        self._prefetch_requested: set[int] = set()
        self._prefetch_hits = 0
        self._prefetch_misses = 0

    def __getstate__(self):
        """Return a picklable copy of the state with threading objects removed."""
        state = self.__dict__.copy()
        state["_shard_cache"] = OrderedDict(state.get("_shard_cache", OrderedDict()))
        for key in ("_cache_lock", "_prefetch_queue", "_prefetch_thread", "_prefetch_stop"):
            state[key] = None
        state["_prefetch_requested"] = set()
        return state

    def __setstate__(self, state):
        """Restore pickled state and rebuild fresh threading objects."""
        self.__dict__.update(state)
        self._shard_cache = OrderedDict(self._shard_cache)
        self._init_runtime_state()

    def _load_metadata(self) -> Dict[str, Any]:
        """Return ``cache_dir/metadata.json`` as a dict, or ``{}`` if it is absent."""
        metadata_path = os.path.join(self.cache_dir, "metadata.json")
        if not os.path.exists(metadata_path):
            return {}
        with open(metadata_path, "r", encoding="utf-8") as handle:
            return json.load(handle)

    def _discover_shards(self) -> List[str]:
        """Return sorted paths of ``shard_*.pt`` files in ``cache_dir``.

        Raises:
            FileNotFoundError: If ``cache_dir`` is not a directory.
            ValueError: If no shard files are found.
        """
        if not os.path.isdir(self.cache_dir):
            raise FileNotFoundError(f"Cache directory not found: {self.cache_dir}")
        shards: List[str] = []
        for name in sorted(os.listdir(self.cache_dir)):
            if not (name.startswith("shard_") and name.endswith(".pt")):
                continue
            shards.append(os.path.join(self.cache_dir, name))
        if not shards:
            raise ValueError(f"No cache shards found under {self.cache_dir}")
        return shards

    @staticmethod
    def _build_offsets(shard_sizes: List[int]) -> List[int]:
        """Return the running totals of ``shard_sizes`` (exclusive end of each shard)."""
        offsets: List[int] = []
        running_total = 0
        for shard_size in shard_sizes:
            running_total += shard_size
            offsets.append(running_total)
        return offsets

    def _resolve_shard_sizes(self) -> List[int]:
        """Return the record count of every shard.

        When metadata provides a consistent ``num_shards``/``num_records``/
        ``shard_size``, all shards but the last are assumed full. Otherwise
        every shard is loaded and counted.

        Raises:
            ValueError: If metadata implies a non-positive last shard, or a
                shard lacks a ``records`` list.
        """
        num_shards = len(self.shard_files)
        metadata_num_shards = self.metadata.get("num_shards")
        metadata_num_records = self.metadata.get("num_records")
        shard_size = self.metadata.get("shard_size")

        if (
            isinstance(metadata_num_shards, int)
            and isinstance(metadata_num_records, int)
            and isinstance(shard_size, int)
            and metadata_num_shards == num_shards
            and metadata_num_records > 0
            and shard_size > 0
        ):
            shard_sizes = [shard_size] * num_shards
            full_rows_before_last = shard_size * max(num_shards - 1, 0)
            shard_sizes[-1] = metadata_num_records - full_rows_before_last
            if shard_sizes[-1] <= 0:
                raise ValueError(f"Invalid metadata in {self.cache_dir}: last shard size computed as {shard_sizes[-1]}")
            return shard_sizes

        shard_sizes = []
        for shard_path in self.shard_files:
            payload = torch.load(shard_path, map_location="cpu", weights_only=False)
            records = payload.get("records")
            if not isinstance(records, list):
                raise ValueError(f"Invalid shard format in {shard_path}")
            shard_sizes.append(len(records))
        return shard_sizes

    def _store_shard(self, shard_idx: int, records: List[Dict[str, Any]]) -> None:
        """Insert ``records`` as most recently used, evicting the oldest beyond the limit."""
        with self._cache_lock:
            self._shard_cache[shard_idx] = records
            self._shard_cache.move_to_end(shard_idx)
            while len(self._shard_cache) > self.shard_cache_limit:
                self._shard_cache.popitem(last=False)

    def _ensure_prefetch_thread(self) -> None:
        """Start a prefetch worker if prefetching is enabled and none is alive."""
        if self.prefetch_shards <= 0:
            return
        if self._prefetch_thread is not None and self._prefetch_thread.is_alive():
            return

        # Each worker gets its own queue and stop event. A worker that outlived
        # close()'s join timeout keeps watching its own (already set) event and
        # can neither consume from nor crash on the replacement queue.
        work_queue = queue.Queue(maxsize=max(self.prefetch_shards * 2, 1))
        stop_event = threading.Event()
        with self._cache_lock:
            # Requests handed to a previous (closed or crashed) worker were lost
            # with its queue; forget them so they can be scheduled again.
            self._prefetch_requested.clear()
        self._prefetch_queue = work_queue
        self._prefetch_stop = stop_event
        self._prefetch_thread = threading.Thread(
            target=self._prefetch_worker,
            args=(work_queue, stop_event),
            daemon=True,
            name=f"cached-shard-prefetch-{os.getpid()}",
        )
        self._prefetch_thread.start()

    def _prefetch_worker(self, work_queue: queue.Queue, stop_event: threading.Event) -> None:
        """Load shard indices from ``work_queue`` into the cache until ``stop_event`` is set."""
        while not stop_event.is_set():
            try:
                shard_idx = work_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            if shard_idx is None:
                continue

            try:
                with self._cache_lock:
                    already_cached = shard_idx in self._shard_cache
                    if already_cached:
                        self._prefetch_hits += 1
                if already_cached:
                    continue
                payload = torch.load(self.shard_files[shard_idx], map_location="cpu", weights_only=False)
                self._store_shard(shard_idx, payload["records"])
                with self._cache_lock:
                    self._prefetch_hits += 1
            finally:
                with self._cache_lock:
                    self._prefetch_requested.discard(shard_idx)

    def _schedule_prefetch(self, shard_idx: int) -> None:
        """Queue up to ``prefetch_shards`` shards after ``shard_idx`` that are not cached or pending."""
        if self.prefetch_shards <= 0:
            return

        self._ensure_prefetch_thread()
        work_queue = self._prefetch_queue
        if work_queue is None:
            return

        stop = min(len(self.shard_files), shard_idx + 1 + self.prefetch_shards)
        for next_idx in range(shard_idx + 1, stop):
            with self._cache_lock:
                if next_idx in self._shard_cache or next_idx in self._prefetch_requested:
                    continue
                self._prefetch_requested.add(next_idx)
            try:
                work_queue.put_nowait(next_idx)
            except queue.Full:
                with self._cache_lock:
                    self._prefetch_requested.discard(next_idx)
                break

    def _load_shard(self, shard_idx: int) -> List[Dict[str, Any]]:
        """Return the records of ``shard_idx``, loading synchronously on a cache miss."""
        with self._cache_lock:
            cached = self._shard_cache.get(shard_idx)
            if cached is not None:
                self._shard_cache.move_to_end(shard_idx)
            else:
                self._prefetch_misses += 1
        if cached is not None:
            self._schedule_prefetch(shard_idx)
            return cached

        payload = torch.load(self.shard_files[shard_idx], map_location="cpu", weights_only=False)
        records = payload["records"]
        self._store_shard(shard_idx, records)
        with self._cache_lock:
            self._prefetch_requested.discard(shard_idx)
        self._schedule_prefetch(shard_idx)
        return records

    @staticmethod
    def _deserialize_item(raw: Optional[Dict[str, Any]]) -> Optional[PairItem]:
        """Rebuild a PairItem from a cached record entry; ``None`` passes through.

        Text items use ``tokens`` when present and otherwise require ``value``.
        Other modalities use ``tensor`` when present, else ``value`` (or None).
        """
        if raw is None:
            return None
        modality = raw["type"]
        if modality == "text" and "tokens" in raw:
            value = raw["tokens"]
        elif modality == "text":
            value = raw["value"]
        elif "tensor" in raw:
            value = raw["tensor"]
        else:
            value = raw.get("value")
        return PairItem(modality=modality, value=value)

    def __len__(self) -> int:
        """Return the total number of records across all shards."""
        return self.total_rows

    def __getitem__(self, idx: int) -> TrainRecord:
        """Return the record at global index ``idx`` (non-negative only).

        Raises:
            IndexError: If ``idx`` is outside ``[0, len(self))``.
        """
        if idx < 0 or idx >= self.total_rows:
            raise IndexError(idx)
        # offsets are exclusive ends, so bisect_right finds the owning shard
        # (also skipping any zero-sized shards).
        shard_idx = bisect_right(self.shard_offsets, idx)
        shard_start = 0 if shard_idx == 0 else self.shard_offsets[shard_idx - 1]
        raw = self._load_shard(shard_idx)[idx - shard_start]
        return TrainRecord(
            query=self._deserialize_item(raw["query"]),
            positive=self._deserialize_item(raw["positive"]),
            negative=self._deserialize_item(raw.get("negative")),
        )

    def get_prefetch_stats(self) -> Dict[str, int]:
        """Return a snapshot of cache and prefetch counters."""
        with self._cache_lock:
            return {
                "cache_size": len(self._shard_cache),
                "cache_limit": self.shard_cache_limit,
                "prefetch_shards": self.prefetch_shards,
                "prefetch_hits": self._prefetch_hits,
                "prefetch_misses": self._prefetch_misses,
                "prefetch_pending": len(self._prefetch_requested),
            }

    def close(self) -> None:
        """Signal the prefetch worker to stop, wait up to 1s for it, and drop its queue.

        Safe to call repeatedly and on a partially constructed instance.
        """
        stop_event = getattr(self, "_prefetch_stop", None)
        if stop_event is not None:
            stop_event.set()
        worker = getattr(self, "_prefetch_thread", None)
        if worker is not None and worker.is_alive():
            worker.join(timeout=1.0)
        self._prefetch_thread = None
        self._prefetch_queue = None

    def __del__(self):
        self.close()
|
|
|
|
class SequentialShardDataset:
    """Shard-at-a-time dataset over ``shard_*.pt`` files for one rank.

    ``reset(epoch)`` builds this rank's shard order and loads the first shard;
    ``next_shard()`` advances to the next one. The order is optionally shuffled
    with ``seed + epoch`` and then strided by ``rank`` / ``world_size``.

    Only the current shard is exposed through ``__len__`` / ``__getitem__``.
    Shards shorter than ``target_shard_size`` are padded by repeating their
    records. With ``prefetch_shards > 0``, a background thread loads upcoming
    shards into an LRU cache of ``shard_cache_limit`` entries.

    The prefetch thread's target is a bound method, so a running worker keeps
    this object alive. Call ``close()`` explicitly rather than relying on
    ``__del__``. Shards are loaded with ``weights_only=False``, which
    unpickles arbitrary objects, so only use trusted cache directories.

    Args:
        cache_dir: Directory holding ``shard_*.pt`` and optional ``metadata.json``.
        shuffle: Shuffle the global shard order each epoch.
        rank: This process's rank.
        world_size: Number of ranks (clamped to at least 1).
        prefetch_shards: How many upcoming shards to prefetch (0 disables).
        shard_cache_limit: Maximum cached shards (clamped to at least 1).
        seed: Base seed for the per-epoch shuffle (``seed + epoch``).
    """

    def __init__(
        self,
        cache_dir: str,
        shuffle: bool = True,
        rank: int = 0,
        world_size: int = 1,
        prefetch_shards: int = 2,
        shard_cache_limit: int = 4,
        seed: int = 42,
    ) -> None:
        # Create the threading state first so close()/__del__ still work when
        # shard discovery below raises.
        self._init_runtime_state()
        self.cache_dir = cache_dir
        self.shuffle = shuffle
        self.rank = rank
        self.world_size = max(world_size, 1)
        self.prefetch_shards = max(int(prefetch_shards), 0)
        self.shard_cache_limit = max(int(shard_cache_limit), 1)
        self.seed = seed

        self.metadata = self._load_metadata()
        self.shard_files = self._discover_shards()
        self.shard_sizes = self._resolve_shard_sizes()
        self.total_rows = sum(self.shard_sizes)
        self.target_shard_size = int(self.metadata.get("shard_size") or max(self.shard_sizes))

        self._shard_cache: OrderedDict[int, List[Dict[str, Any]]] = OrderedDict()

        self._all_shard_indices = list(range(len(self.shard_files)))
        self._local_shard_indices: List[int] = []
        self.current_local_shard_pos = -1
        self.current_records: Optional[List[Dict[str, Any]]] = None

    def _init_runtime_state(self) -> None:
        """(Re)create the unpicklable threading state and reset prefetch counters."""
        self._cache_lock = threading.Lock()
        self._prefetch_queue = None
        self._prefetch_thread = None
        self._prefetch_stop = threading.Event()
        self._prefetch_requested: set[int] = set()
        self._prefetch_hits = 0
        self._prefetch_misses = 0

    def __getstate__(self):
        """Return a picklable copy of the state; Lock/Event/Thread/Queue cannot be pickled."""
        state = self.__dict__.copy()
        state["_shard_cache"] = OrderedDict(state.get("_shard_cache", OrderedDict()))
        for key in ("_cache_lock", "_prefetch_queue", "_prefetch_thread", "_prefetch_stop"):
            state[key] = None
        state["_prefetch_requested"] = set()
        return state

    def __setstate__(self, state):
        """Restore pickled state and rebuild fresh threading objects."""
        self.__dict__.update(state)
        self._shard_cache = OrderedDict(self._shard_cache)
        self._init_runtime_state()

    def _load_metadata(self) -> Dict[str, Any]:
        """Return ``cache_dir/metadata.json`` as a dict, or ``{}`` if it is absent."""
        metadata_path = os.path.join(self.cache_dir, "metadata.json")
        if not os.path.exists(metadata_path):
            return {}
        with open(metadata_path, "r", encoding="utf-8") as handle:
            return json.load(handle)

    def _discover_shards(self) -> List[str]:
        """Return sorted paths of ``shard_*.pt`` files in ``cache_dir``.

        Raises:
            FileNotFoundError: If ``cache_dir`` is not a directory.
            ValueError: If no shard files are found.
        """
        if not os.path.isdir(self.cache_dir):
            raise FileNotFoundError(f"Cache directory not found: {self.cache_dir}")
        shards: List[str] = []
        for name in sorted(os.listdir(self.cache_dir)):
            if not (name.startswith("shard_") and name.endswith(".pt")):
                continue
            shards.append(os.path.join(self.cache_dir, name))
        if not shards:
            raise ValueError(f"No cache shards found under {self.cache_dir}")
        return shards

    def _resolve_shard_sizes(self) -> List[int]:
        """Return the record count of every shard.

        When metadata provides a consistent ``num_shards``/``num_records``/
        ``shard_size``, all shards but the last are assumed full. Otherwise
        every shard is loaded and counted.

        Raises:
            ValueError: If metadata implies a non-positive last shard, or a
                shard lacks a ``records`` list.
        """
        num_shards = len(self.shard_files)
        metadata_num_shards = self.metadata.get("num_shards")
        metadata_num_records = self.metadata.get("num_records")
        shard_size = self.metadata.get("shard_size")

        if (
            isinstance(metadata_num_shards, int)
            and isinstance(metadata_num_records, int)
            and isinstance(shard_size, int)
            and metadata_num_shards == num_shards
            and metadata_num_records > 0
            and shard_size > 0
        ):
            shard_sizes = [shard_size] * num_shards
            shard_sizes[-1] = metadata_num_records - shard_size * max(num_shards - 1, 0)
            # Same sanity check as CachedShardDataset: inconsistent metadata
            # must not yield a zero/negative shard length.
            if shard_sizes[-1] <= 0:
                raise ValueError(f"Invalid metadata in {self.cache_dir}: last shard size computed as {shard_sizes[-1]}")
            return shard_sizes

        shard_sizes = []
        for shard_path in self.shard_files:
            payload = torch.load(shard_path, map_location="cpu", weights_only=False)
            records = payload.get("records")
            if not isinstance(records, list):
                raise ValueError(f"Invalid shard format in {shard_path}")
            shard_sizes.append(len(records))
        return shard_sizes

    @staticmethod
    def _deserialize_item(raw: Optional[Dict[str, Any]]) -> Optional[PairItem]:
        """Rebuild a PairItem from a cached record entry; ``None`` passes through.

        Text items use ``tokens`` when present and otherwise require ``value``.
        Other modalities use ``tensor`` when present, else ``value`` (or None).
        """
        if raw is None:
            return None
        modality = raw["type"]
        if modality == "text" and "tokens" in raw:
            value = raw["tokens"]
        elif modality == "text":
            value = raw["value"]
        elif "tensor" in raw:
            value = raw["tensor"]
        else:
            value = raw.get("value")
        return PairItem(modality=modality, value=value)

    def _store_shard(self, shard_idx: int, records: List[Dict[str, Any]]) -> None:
        """Insert ``records`` as most recently used, evicting the oldest beyond the limit."""
        with self._cache_lock:
            self._shard_cache[shard_idx] = records
            self._shard_cache.move_to_end(shard_idx)
            while len(self._shard_cache) > self.shard_cache_limit:
                self._shard_cache.popitem(last=False)

    def _ensure_prefetch_thread(self) -> None:
        """Start a prefetch worker if prefetching is enabled and none is alive."""
        if self.prefetch_shards <= 0:
            return
        if self._prefetch_thread is not None and self._prefetch_thread.is_alive():
            return
        # Per-worker queue and stop event: a worker that outlived the join
        # timeout in _stop_prefetch_thread keeps its own (already set) event and
        # cannot consume from, or crash on, the replacement queue.
        work_queue = queue.Queue(maxsize=max(self.prefetch_shards * 2, 1))
        stop_event = threading.Event()
        with self._cache_lock:
            # Requests handed to a previous worker were lost with its queue.
            self._prefetch_requested.clear()
        self._prefetch_queue = work_queue
        self._prefetch_stop = stop_event
        self._prefetch_thread = threading.Thread(
            target=self._prefetch_worker,
            args=(work_queue, stop_event),
            daemon=True,
            name=f"sequential-shard-prefetch-{os.getpid()}",
        )
        self._prefetch_thread.start()

    def _prefetch_worker(self, work_queue: queue.Queue, stop_event: threading.Event) -> None:
        """Load shard indices from ``work_queue`` into the cache until ``stop_event`` is set."""
        while not stop_event.is_set():
            try:
                shard_idx = work_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            if shard_idx is None:
                continue
            try:
                with self._cache_lock:
                    already_cached = shard_idx in self._shard_cache
                    if already_cached:
                        self._prefetch_hits += 1
                if already_cached:
                    continue
                payload = torch.load(self.shard_files[shard_idx], map_location="cpu", weights_only=False)
                self._store_shard(shard_idx, payload["records"])
                with self._cache_lock:
                    self._prefetch_hits += 1
            finally:
                with self._cache_lock:
                    self._prefetch_requested.discard(shard_idx)

    def _stop_prefetch_thread(self) -> None:
        """Signal the worker to stop, wait up to 1s, and drop its queue.

        Safe on a partially constructed instance.
        """
        stop_event = getattr(self, "_prefetch_stop", None)
        if stop_event is not None:
            stop_event.set()
        worker = getattr(self, "_prefetch_thread", None)
        if worker is not None and worker.is_alive():
            worker.join(timeout=1.0)
        self._prefetch_thread = None
        self._prefetch_queue = None

    def _schedule_prefetch_from_position(self, local_pos: int) -> None:
        """Queue up to ``prefetch_shards`` local shards after ``local_pos`` not cached or pending."""
        if self.prefetch_shards <= 0:
            return
        self._ensure_prefetch_thread()
        work_queue = self._prefetch_queue
        if work_queue is None:
            return
        stop = min(len(self._local_shard_indices), local_pos + 1 + self.prefetch_shards)
        for next_pos in range(local_pos + 1, stop):
            shard_idx = self._local_shard_indices[next_pos]
            with self._cache_lock:
                if shard_idx in self._shard_cache or shard_idx in self._prefetch_requested:
                    continue
                self._prefetch_requested.add(shard_idx)
            try:
                work_queue.put_nowait(shard_idx)
            except queue.Full:
                with self._cache_lock:
                    self._prefetch_requested.discard(shard_idx)
                break

    def _build_local_shard_order(self, epoch: int) -> List[int]:
        """Return this rank's shard indices for ``epoch``.

        The global order is shuffled with ``seed + epoch`` when ``shuffle`` is
        set and then strided by rank. The result is padded to
        ``ceil(num_shards / world_size)`` entries by cycling through this
        rank's own shards.

        Raises:
            ValueError: If this rank receives no shards.
        """
        shard_indices = list(self._all_shard_indices)
        if self.shuffle:
            random.Random(self.seed + epoch).shuffle(shard_indices)
        local_shards = shard_indices[self.rank :: self.world_size]
        if not local_shards:
            raise ValueError(f"Rank {self.rank} received no shards from {self.cache_dir}")
        max_shards = math.ceil(len(shard_indices) / self.world_size)
        owned = len(local_shards)
        local_shards.extend([local_shards[i % owned] for i in range(owned, max_shards)])
        return local_shards

    def _load_records_for_shard(self, shard_idx: int) -> List[Dict[str, Any]]:
        """Return the raw records of ``shard_idx``, loading synchronously on a cache miss."""
        with self._cache_lock:
            cached = self._shard_cache.get(shard_idx)
            if cached is not None:
                self._shard_cache.move_to_end(shard_idx)
            else:
                self._prefetch_misses += 1
        if cached is not None:
            return cached

        payload = torch.load(self.shard_files[shard_idx], map_location="cpu", weights_only=False)
        records = payload["records"]
        self._store_shard(shard_idx, records)
        with self._cache_lock:
            self._prefetch_requested.discard(shard_idx)
        return records

    def _pad_records(self, records: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Repeat ``records`` to exactly ``target_shard_size``; longer shards are returned as-is.

        Raises:
            ValueError: If the shard is empty but padding is required.
        """
        if len(records) >= self.target_shard_size:
            return records
        if not records:
            raise ValueError(f"Cannot pad an empty shard from {self.cache_dir} to {self.target_shard_size} records")
        repeat = math.ceil(self.target_shard_size / len(records))
        return (records * repeat)[: self.target_shard_size]

    def reset(self, epoch: int) -> bool:
        """Restart iteration for ``epoch`` and load its first shard.

        Returns:
            The result of ``next_shard()`` for the first shard.
        """
        self._stop_prefetch_thread()
        self._local_shard_indices = self._build_local_shard_order(epoch)
        self.current_local_shard_pos = -1
        self.current_records = None
        with self._cache_lock:
            self._prefetch_requested.clear()
        if self.prefetch_shards > 0:
            self._ensure_prefetch_thread()
        return self.next_shard()

    def next_shard(self) -> bool:
        """Advance to the next local shard.

        Returns:
            False (and clears ``current_records``) when no shards remain,
            otherwise True.
        """
        self.current_local_shard_pos += 1
        if self.current_local_shard_pos >= len(self._local_shard_indices):
            self.current_records = None
            return False
        shard_idx = self._local_shard_indices[self.current_local_shard_pos]
        records = self._load_records_for_shard(shard_idx)
        self.current_records = self._pad_records(records)
        self._schedule_prefetch_from_position(self.current_local_shard_pos)
        return True

    def __len__(self) -> int:
        """Return the number of records in the current (padded) shard, or 0."""
        return len(self.current_records or [])

    def __getitem__(self, idx: int) -> TrainRecord:
        """Return record ``idx`` of the current shard.

        Raises:
            IndexError: If no shard is loaded or ``idx`` is out of range.
        """
        if self.current_records is None:
            raise IndexError(idx)
        raw = self.current_records[idx]
        return TrainRecord(
            query=self._deserialize_item(raw["query"]),
            positive=self._deserialize_item(raw["positive"]),
            negative=self._deserialize_item(raw.get("negative")),
        )

    def estimated_num_batches(self, batch_size: int, drop_last: bool) -> int:
        """Estimate batches per epoch for this rank.

        The estimate is batches per ``target_shard_size`` shard multiplied by
        the local shard count. Shards longer than ``target_shard_size`` are not
        truncated by ``_pad_records``, so they are not fully accounted for.
        """
        if drop_last:
            shard_batches = self.target_shard_size // batch_size
        else:
            shard_batches = math.ceil(self.target_shard_size / batch_size)
        return shard_batches * max(len(self._build_local_shard_order(0)), 1)

    def get_prefetch_stats(self) -> Dict[str, int]:
        """Return a snapshot of cache and prefetch counters."""
        with self._cache_lock:
            return {
                "cache_size": len(self._shard_cache),
                "cache_limit": self.shard_cache_limit,
                "prefetch_shards": self.prefetch_shards,
                "prefetch_hits": self._prefetch_hits,
                "prefetch_misses": self._prefetch_misses,
                "prefetch_pending": len(self._prefetch_requested),
                "local_shards": len(self._local_shard_indices),
                "target_shard_size": self.target_shard_size,
            }

    def close(self) -> None:
        """Stop the prefetch worker; safe to call repeatedly."""
        self._stop_prefetch_thread()

    def __del__(self):
        self.close()
|
|
|
|
def _process_shard() -> tuple[int, int]:
    """Return ``(shard_id, total_shards)`` for the current process / DataLoader worker.

    The rank comes from ``ACCELERATE_PROCESS_INDEX``, then ``RANK`` (default 0).
    The world size comes from ``WORLD_SIZE``, then ``ACCELERATE_NUM_PROCESSES``
    (default 1, clamped to at least 1). Inside a ``torch.utils.data`` worker,
    each rank is further divided across its workers, so every (rank, worker)
    pair gets its own shard id.
    """
    env = os.environ
    rank = int(env.get("ACCELERATE_PROCESS_INDEX") or env.get("RANK") or 0)
    world_size = max(int(env.get("WORLD_SIZE") or env.get("ACCELERATE_NUM_PROCESSES") or 1), 1)

    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        return rank, world_size

    num_workers = worker_info.num_workers
    return rank * num_workers + worker_info.id, max(world_size * num_workers, 1)
|
|
|
|
def iter_sentence_transformers_rows(
    manifest_path: str,
    image_root: Optional[str],
    audio_root: Optional[str],
    allow_missing_negative: bool,
    allowed_modalities: Optional[List[str]],
    query_modalities: Optional[List[str]],
    positive_modalities: Optional[List[str]],
    negative_modalities: Optional[List[str]],
    use_negative_column: bool,
):
    """Yield sentence-transformers rows for this process/worker's slice of a manifest.

    Records that fail ``record_matches_filters`` are skipped. The remaining
    matches are dealt round-robin over the shards from ``_process_shard()``, so
    each (rank, worker) yields every ``total_shards``-th matching record.

    NOTE(review): the split relies on torch DataLoader worker info directly.
    If the caller's dataset wrapper also assigns generator shards to workers
    (e.g. a single-shard streaming dataset running the generator in only one
    worker), the other workers' slices could be dropped. Confirm against the
    DataLoader setup.
    """
    filters = {
        "allowed": set(allowed_modalities or []),
        "allowed_query": set(query_modalities or []),
        "allowed_positive": set(positive_modalities or []),
        "allowed_negative": set(negative_modalities or []),
    }
    shard_id, total_shards = _process_shard()

    matching = (
        record
        for record in iter_manifest_records(
            manifest_path=manifest_path,
            image_root=image_root,
            audio_root=audio_root,
            allow_missing_negative=allow_missing_negative,
        )
        if record_matches_filters(record, **filters)
    )
    for position, record in enumerate(matching):
        if position % total_shards == shard_id:
            yield record_to_sentence_transformers_row(record, include_negative=use_negative_column)
|
|
|
|
def collate_records(batch: List[TrainRecord]) -> Dict[str, List[PairItem]]:
    """Transpose a batch of TrainRecords into per-field lists.

    The ``"negative"`` list keeps one entry per record, including ``None`` for
    records without a negative.
    """
    field_names = ("query", "positive", "negative")
    return {name: [getattr(record, name) for record in batch] for name in field_names}
|
|
|
|
def sentence_transformers_input(item: PairItem) -> Any:
    """Wrap an item's value in a one-key dict named after its modality.

    Text, image and audio items become ``{"text": v}``, ``{"image": v}`` or
    ``{"audio": v}``. Any other modality returns the bare value.
    """
    if item.modality in ("text", "image", "audio"):
        return {item.modality: item.value}
    return item.value
|
|
|
|
def resolve_media(item: PairItem, image_root: Optional[str], audio_root: Optional[str]) -> PairItem:
    """Join a relative image/audio path onto its configured root directory.

    Items are returned unchanged (the same object) in three cases: the item is
    not image/audio, its root is falsy, or its path is already absolute.
    """
    root = {"image": image_root, "audio": audio_root}.get(item.modality)
    if root and not os.path.isabs(item.value):
        return PairItem(item.modality, os.path.join(root, item.value))
    return item
|
|
|
|
def iter_manifest_records(
    manifest_path: str,
    image_root: Optional[str] = None,
    audio_root: Optional[str] = None,
    allow_missing_negative: bool = True,
) -> Iterable[TrainRecord]:
    """Lazily yield TrainRecords from a JSONL manifest.

    Blank lines are skipped. Every other line must be a JSON object matching a
    ``parse_record`` schema. Relative image/audio paths are joined onto
    ``image_root`` / ``audio_root`` via ``resolve_media``.

    Raises:
        FileNotFoundError: If ``manifest_path`` does not exist.
        json.JSONDecodeError: If a line is not valid JSON; the message names
            the manifest line.
        ValueError: If a line is not a JSON object, matches no schema, or lacks
            a negative while ``allow_missing_negative`` is False.
    """
    if not os.path.exists(manifest_path):
        raise FileNotFoundError(f"Manifest not found: {manifest_path}")

    with open(manifest_path, "r", encoding="utf-8") as handle:
        for line_no, line in enumerate(handle, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                raw = json.loads(line)
            except json.JSONDecodeError as exc:
                # The decoder's position is relative to this single line; add
                # the manifest line number but keep the exception type.
                raise json.JSONDecodeError(
                    f"Invalid JSON at line {line_no} of {manifest_path}: {exc.msg}", exc.doc, exc.pos
                ) from exc
            # parse_record uses `in` / indexing, which misbehaves (TypeError)
            # on JSON scalars and arrays.
            if not isinstance(raw, dict):
                raise ValueError(f"Line {line_no} of {manifest_path} is not a JSON object")
            try:
                record = parse_record(raw)
            except ValueError as exc:
                raise ValueError(f"Line {line_no} of {manifest_path}: {exc}") from exc
            record = TrainRecord(
                query=resolve_media(record.query, image_root, audio_root),
                positive=resolve_media(record.positive, image_root, audio_root),
                negative=resolve_media(record.negative, image_root, audio_root) if record.negative else None,
            )
            if record.negative is None and not allow_missing_negative:
                raise ValueError(f"Missing negative at line {line_no}")
            yield record
|
|
|
|
def record_matches_filters(
    record: TrainRecord,
    allowed: set[str],
    allowed_query: set[str],
    allowed_positive: set[str],
    allowed_negative: set[str],
) -> bool:
    """Return True when the record passes every non-empty modality filter.

    An empty filter set disables that check. ``allowed`` must contain every
    modality used in the record. The per-role sets constrain the query, the
    positive, and (only if present) the negative.
    """
    query_modality = record.query.modality
    positive_modality = record.positive.modality
    negative = record.negative

    used = {query_modality, positive_modality}
    if negative is not None:
        used.add(negative.modality)

    return (
        (not allowed or used.issubset(allowed))
        and (not allowed_query or query_modality in allowed_query)
        and (not allowed_positive or positive_modality in allowed_positive)
        and (negative is None or not allowed_negative or negative.modality in allowed_negative)
    )
|
|
|
|
def record_to_sentence_transformers_row(record: TrainRecord, include_negative: bool) -> Dict[str, Any]:
    """Convert a TrainRecord into a row with ``query``/``positive`` and optional ``negative_0``.

    ``negative_0`` is added only when ``include_negative`` is set and the record
    has a negative.
    """
    row: Dict[str, Any] = {}
    row["query"] = sentence_transformers_input(record.query)
    row["positive"] = sentence_transformers_input(record.positive)
    if include_negative and record.negative is not None:
        row["negative_0"] = sentence_transformers_input(record.negative)
    return row
|
|
|
|
def summarize_manifest_records(
    manifest_path: str,
    image_root: Optional[str] = None,
    audio_root: Optional[str] = None,
    allow_missing_negative: bool = True,
    allowed_modalities: Optional[List[str]] = None,
    query_modalities: Optional[List[str]] = None,
    positive_modalities: Optional[List[str]] = None,
    negative_modalities: Optional[List[str]] = None,
    max_records: Optional[int] = None,
) -> Dict[str, Any]:
    """Scan a manifest and summarize the records that pass the modality filters.

    Scanning stops once ``max_records`` matching records have been kept.

    Returns:
        A dict with these keys:

        * ``modalities``: sorted modalities seen in kept records.
        * ``num_rows``: number of kept records.
        * ``has_uniform_negatives``: True only if at least one kept record has
          a negative and none lacks one.
        * ``num_negatives_present`` / ``num_negatives_missing``: negative counts.
        * ``skipped_rows``: records rejected by the filters before scanning
          stopped.

    Raises:
        ValueError: If no record passes the filters.
    """
    filters = {
        "allowed": set(allowed_modalities or []),
        "allowed_query": set(query_modalities or []),
        "allowed_positive": set(positive_modalities or []),
        "allowed_negative": set(negative_modalities or []),
    }
    seen_modalities: set = set()
    kept = skipped = with_negative = without_negative = 0

    records = iter_manifest_records(
        manifest_path=manifest_path,
        image_root=image_root,
        audio_root=audio_root,
        allow_missing_negative=allow_missing_negative,
    )
    for record in records:
        if not record_matches_filters(record, **filters):
            skipped += 1
            continue

        seen_modalities.update((record.query.modality, record.positive.modality))
        if record.negative is None:
            without_negative += 1
        else:
            seen_modalities.add(record.negative.modality)
            with_negative += 1
        kept += 1
        if max_records is not None and kept >= max_records:
            break

    if kept == 0:
        raise ValueError(f"No records loaded from {manifest_path}")

    return {
        "modalities": sorted(seen_modalities),
        "num_rows": kept,
        "has_uniform_negatives": with_negative > 0 and without_negative == 0,
        "num_negatives_present": with_negative,
        "num_negatives_missing": without_negative,
        "skipped_rows": skipped,
    }
|
|
|
|
def manifest_to_sentence_transformers_dataset(
    manifest_path: str,
    image_root: Optional[str] = None,
    audio_root: Optional[str] = None,
    allow_missing_negative: bool = True,
    allowed_modalities: Optional[List[str]] = None,
    query_modalities: Optional[List[str]] = None,
    positive_modalities: Optional[List[str]] = None,
    negative_modalities: Optional[List[str]] = None,
    as_iterable: bool = False,
    max_records: Optional[int] = None,
) -> tuple[Dataset | IterableDataset, Dict[str, Any]]:
    """Build a Hugging Face dataset of sentence-transformers rows from a manifest.

    A first pass (``summarize_manifest_records``) does three things: validates
    the manifest, applies the modality filters and ``max_records`` cap, and
    decides whether every kept row has a negative. Only in that case is a
    ``negative_0`` column emitted.

    Args:
        manifest_path: JSONL manifest to read.
        image_root / audio_root: Roots for relative media paths.
        allow_missing_negative: Passed through to manifest parsing.
        allowed_modalities / query_modalities / positive_modalities /
            negative_modalities: Optional modality filters (empty or None
            disables a filter).
        as_iterable: Stream rows through ``IterableDataset.from_generator``
            instead of materializing a ``Dataset``.
        max_records: Cap on matching rows. NOTE(review): the streaming path's
            generator does not receive this cap and yields every matching row.

    Returns:
        ``(dataset, info)`` where ``info`` is the summary dict.
    """
    info = summarize_manifest_records(
        manifest_path=manifest_path,
        image_root=image_root,
        audio_root=audio_root,
        allow_missing_negative=allow_missing_negative,
        allowed_modalities=allowed_modalities,
        query_modalities=query_modalities,
        positive_modalities=positive_modalities,
        negative_modalities=negative_modalities,
        max_records=max_records,
    )
    include_negative = info["has_uniform_negatives"]

    dataset_out: Dataset | IterableDataset
    if as_iterable:
        column_names = ["query", "positive"]
        if include_negative:
            column_names.append("negative_0")

        def _unsharded(values: Optional[List[str]]) -> Optional[tuple]:
            # from_generator treats list-valued gen_kwargs as data-source
            # shards: each generator call gets only a slice of every list, and
            # lists of different lengths raise. A tuple is passed through
            # whole, so every call sees the complete modality filter.
            return None if values is None else tuple(values)

        dataset_out = IterableDataset.from_generator(
            iter_sentence_transformers_rows,
            # NOTE(review): every column is declared Value("null") although the
            # rows hold dicts; presumably a placeholder schema — confirm.
            features=Features({key: Value("null") for key in column_names}),
            gen_kwargs={
                "manifest_path": manifest_path,
                "image_root": image_root,
                "audio_root": audio_root,
                "allow_missing_negative": allow_missing_negative,
                "allowed_modalities": _unsharded(allowed_modalities),
                "query_modalities": _unsharded(query_modalities),
                "positive_modalities": _unsharded(positive_modalities),
                "negative_modalities": _unsharded(negative_modalities),
                "use_negative_column": include_negative,
            },
        )
    else:
        dataset = JsonlManifestDataset(
            manifest_path=manifest_path,
            image_root=image_root,
            audio_root=audio_root,
            allow_missing_negative=allow_missing_negative,
        )
        filters = {
            "allowed": set(allowed_modalities or []),
            "allowed_query": set(query_modalities or []),
            "allowed_positive": set(positive_modalities or []),
            "allowed_negative": set(negative_modalities or []),
        }
        rows: List[Dict[str, Any]] = []
        for record in dataset.records:
            if not record_matches_filters(record, **filters):
                continue
            rows.append(record_to_sentence_transformers_row(record, include_negative=include_negative))
            if max_records is not None and len(rows) >= max_records:
                break
        dataset_out = Dataset.from_list(rows)

    return dataset_out, info
|
|