HuaminChen's picture
Upload multi-modal-embed-large final model
e21cde3 verified
import json
import math
import os
import queue
import random
import threading
from bisect import bisect_right
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional
import torch
from datasets import Dataset, Features, IterableDataset, Value
# Modalities a manifest item may declare in its "type" field; enforced by _parse_item.
SUPPORTED_MODALITIES = set(("text", "image", "audio"))
@dataclass
class PairItem:
    """One side of a training example: a modality tag plus its payload.

    ``value`` holds whatever the producer stored for that modality, for
    example raw text, a media path, cached ``tokens`` or a cached ``tensor``.
    """

    # Modality name such as "text", "image" or "audio".
    modality: str
    # Payload for that modality; its concrete type depends on the producer.
    value: Any
@dataclass
class TrainRecord:
    """A training example: query, matching positive, and an optional negative."""

    query: PairItem
    positive: PairItem
    # Hard negative for the query; None when the source record has none.
    negative: Optional[PairItem] = None
def _parse_item(obj: Any, prefix: str) -> PairItem:
    """Parse a ``{"type": ..., "value": ...}`` manifest entry into a PairItem.

    Args:
        obj: Raw decoded JSON for one side of a record.
        prefix: Field name used in error messages (e.g. ``"query"``).

    Returns:
        The parsed PairItem.

    Raises:
        ValueError: If ``obj`` is not a dict, lacks a truthy ``type`` or
            ``value``, or declares a modality outside SUPPORTED_MODALITIES.
    """
    is_mapping = isinstance(obj, dict)
    modality = obj.get("type") if is_mapping else None
    value = obj.get("value") if is_mapping else None
    if not modality or not value:
        raise ValueError(f"{prefix} must include type/value")
    # Check the type first: an unhashable "type" (e.g. a JSON list) would make
    # the set-membership test raise TypeError instead of the ValueError that
    # callers expect for malformed records.
    if not isinstance(modality, str) or modality not in SUPPORTED_MODALITIES:
        raise ValueError(f"Unsupported modality '{modality}' in {prefix}")
    return PairItem(modality=modality, value=value)
def parse_record(raw: Dict[str, Any]) -> TrainRecord:
    """Convert one decoded manifest line into a TrainRecord.

    Supported schemas, checked in this order:

    * ``{"query": {...}, "positive": {...}, "negative"?: {...}}`` whose items
      are ``{"type", "value"}`` dicts (see ``_parse_item``);
    * ``{"texts_a": ..., "texts_b": ...}`` -> text/text pair;
    * ``{"image_path": ..., "caption": ...}`` -> image/text pair;
    * ``{"audio_path": ..., "caption": ...}`` -> audio/text pair.

    Raises:
        ValueError: If ``raw`` is not a dict or matches none of the schemas.
    """
    # json.loads may produce a number, string or list. Without this guard the
    # membership tests below raise TypeError (numbers) or do substring /
    # element matching (strings / lists) instead of reporting a bad record.
    if not isinstance(raw, dict):
        raise ValueError(f"Record must be a JSON object, got {type(raw).__name__}")
    if "query" in raw and "positive" in raw:
        query = _parse_item(raw["query"], "query")
        positive = _parse_item(raw["positive"], "positive")
        negative = _parse_item(raw["negative"], "negative") if raw.get("negative") else None
        return TrainRecord(query=query, positive=positive, negative=negative)
    # Compatibility with common pair formats in existing repos
    if "texts_a" in raw and "texts_b" in raw:
        query = PairItem("text", raw["texts_a"])
        positive = PairItem("text", raw["texts_b"])
        return TrainRecord(query=query, positive=positive)
    if "image_path" in raw and "caption" in raw:
        query = PairItem("image", raw["image_path"])
        positive = PairItem("text", raw["caption"])
        return TrainRecord(query=query, positive=positive)
    if "audio_path" in raw and "caption" in raw:
        query = PairItem("audio", raw["audio_path"])
        positive = PairItem("text", raw["caption"])
        return TrainRecord(query=query, positive=positive)
    raise ValueError("Record does not match supported schemas")
class JsonlManifestDataset:
    """Map-style dataset that eagerly loads every record of a JSONL manifest.

    Records are produced by ``iter_manifest_records`` (parsing plus media path
    resolution) and kept in memory in ``self.records``.

    Raises:
        ValueError: If the manifest yields no records.
    """

    def __init__(
        self,
        manifest_path: str,
        image_root: Optional[str] = None,
        audio_root: Optional[str] = None,
        allow_missing_negative: bool = True,
    ) -> None:
        self.manifest_path = manifest_path
        self.image_root = image_root
        self.audio_root = audio_root
        self.allow_missing_negative = allow_missing_negative
        loaded = iter_manifest_records(
            manifest_path=manifest_path,
            image_root=image_root,
            audio_root=audio_root,
            allow_missing_negative=allow_missing_negative,
        )
        self.records = [*loaded]
        if len(self.records) == 0:
            raise ValueError(f"No records loaded from {self.manifest_path}")

    def __len__(self) -> int:
        """Number of loaded records."""
        return len(self.records)

    def __getitem__(self, idx: int) -> TrainRecord:
        """Return the record at ``idx`` (plain list indexing)."""
        return self.records[idx]
class CachedShardDataset:
    """Map-style dataset over pre-built ``shard_*.pt`` files in ``cache_dir``.

    Each shard file is read with ``torch.load`` and must hold a ``"records"``
    list of raw record dicts (with ``"query"``, ``"positive"`` and optionally
    ``"negative"`` entries). A global index is mapped to a shard through
    cumulative shard sizes. Loaded shards are kept in an LRU cache of at most
    ``shard_cache_limit`` entries. When ``prefetch_shards > 0``, a daemon thread
    loads the following shards ahead of time.

    Shards are loaded with ``weights_only=False`` (full unpickling), so only
    point this class at trusted cache directories.

    Args:
        cache_dir: Directory containing ``shard_*.pt`` files and an optional
            ``metadata.json`` (``num_shards``/``num_records``/``shard_size``).
        shard_cache_limit: Max shards held in memory (clamped to >= 1).
        prefetch_shards: Number of following shards to prefetch (0 disables).
    """

    def __init__(self, cache_dir: str, shard_cache_limit: int = 2, prefetch_shards: int = 0) -> None:
        self.cache_dir = cache_dir
        self.shard_cache_limit = max(int(shard_cache_limit), 1)
        # NOTE: with prefetch_shards >= shard_cache_limit, freshly prefetched
        # shards can evict the shard currently being read from the LRU cache.
        self.prefetch_shards = max(int(prefetch_shards), 0)
        self.metadata = self._load_metadata()
        self.shard_files = self._discover_shards()
        self.shard_sizes = self._resolve_shard_sizes()
        self.shard_offsets = self._build_offsets(self.shard_sizes)
        self.total_rows = sum(self.shard_sizes)
        self._shard_cache: OrderedDict[int, List[Dict[str, Any]]] = OrderedDict()
        self._init_runtime_state()

    def _init_runtime_state(self) -> None:
        """(Re)create the lock, prefetch machinery and counters.

        Called from ``__init__`` and ``__setstate__``. Locks, queues and threads
        are not shipped through pickling (see ``__getstate__``), so every
        process builds its own.
        """
        self._cache_lock = threading.Lock()
        self._prefetch_queue: Optional[queue.Queue] = None
        self._prefetch_thread: Optional[threading.Thread] = None
        self._prefetch_stop = threading.Event()
        self._prefetch_requested: set[int] = set()
        self._prefetch_hits = 0
        self._prefetch_misses = 0
        # Prefetch loads that raised; the synchronous path re-raises on demand.
        self._prefetch_errors = 0

    def __getstate__(self):
        # Ship the data and any cached shards, but not threads/locks/queues.
        state = self.__dict__.copy()
        state["_shard_cache"] = OrderedDict(state.get("_shard_cache", OrderedDict()))
        state["_cache_lock"] = None
        state["_prefetch_queue"] = None
        state["_prefetch_thread"] = None
        state["_prefetch_stop"] = None
        state["_prefetch_requested"] = set()
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._shard_cache = OrderedDict(self._shard_cache)
        self._init_runtime_state()

    def _load_metadata(self) -> Dict[str, Any]:
        """Return ``cache_dir/metadata.json`` as a dict, or ``{}`` if absent."""
        metadata_path = os.path.join(self.cache_dir, "metadata.json")
        if not os.path.exists(metadata_path):
            return {}
        with open(metadata_path, "r", encoding="utf-8") as handle:
            return json.load(handle)

    def _discover_shards(self) -> List[str]:
        """Return sorted paths of ``shard_*.pt`` files in ``cache_dir``.

        Raises:
            FileNotFoundError: If ``cache_dir`` is not a directory.
            ValueError: If no shard files are present.
        """
        if not os.path.isdir(self.cache_dir):
            raise FileNotFoundError(f"Cache directory not found: {self.cache_dir}")
        # NOTE(review): plain lexicographic order; assumes shard numbers are
        # zero-padded so that the last file is the (possibly short) last shard.
        shards = [
            os.path.join(self.cache_dir, name)
            for name in sorted(os.listdir(self.cache_dir))
            if name.startswith("shard_") and name.endswith(".pt")
        ]
        if not shards:
            raise ValueError(f"No cache shards found under {self.cache_dir}")
        return shards

    @staticmethod
    def _build_offsets(shard_sizes: List[int]) -> List[int]:
        """Return running totals of ``shard_sizes`` (exclusive end index of each shard)."""
        offsets: List[int] = []
        running_total = 0
        for shard_size in shard_sizes:
            running_total += shard_size
            offsets.append(running_total)
        return offsets

    def _resolve_shard_sizes(self) -> List[int]:
        """Return the record count of every shard.

        Uses ``metadata.json`` when it is consistent with the discovered shards
        (all shards ``shard_size`` long except the last). Otherwise, loads
        every shard and counts its records.

        Raises:
            ValueError: If metadata implies a non-positive last shard, or a
                shard has no ``"records"`` list.
        """
        num_shards = len(self.shard_files)
        metadata_num_shards = self.metadata.get("num_shards")
        metadata_num_records = self.metadata.get("num_records")
        shard_size = self.metadata.get("shard_size")
        if (
            isinstance(metadata_num_shards, int)
            and isinstance(metadata_num_records, int)
            and isinstance(shard_size, int)
            and metadata_num_shards == num_shards
            and metadata_num_records > 0
            and shard_size > 0
        ):
            shard_sizes = [shard_size] * num_shards
            full_rows_before_last = shard_size * max(num_shards - 1, 0)
            shard_sizes[-1] = metadata_num_records - full_rows_before_last
            if shard_sizes[-1] <= 0:
                raise ValueError(f"Invalid metadata in {self.cache_dir}: last shard size computed as {shard_sizes[-1]}")
            return shard_sizes
        shard_sizes = []
        for shard_path in self.shard_files:
            payload = torch.load(shard_path, map_location="cpu", weights_only=False)
            records = payload.get("records")
            if not isinstance(records, list):
                raise ValueError(f"Invalid shard format in {shard_path}")
            shard_sizes.append(len(records))
        return shard_sizes

    def _store_shard(self, shard_idx: int, records: List[Dict[str, Any]]) -> None:
        """Insert a shard as most-recently-used and evict down to the cache limit."""
        with self._cache_lock:
            self._shard_cache[shard_idx] = records
            self._shard_cache.move_to_end(shard_idx)
            while len(self._shard_cache) > self.shard_cache_limit:
                self._shard_cache.popitem(last=False)

    def _ensure_prefetch_thread(self) -> None:
        """Start a prefetch worker if prefetching is enabled and none is alive."""
        if self.prefetch_shards <= 0:
            return
        if self._prefetch_thread is not None and self._prefetch_thread.is_alive():
            return
        # Each worker gets its own queue and stop event. A previous worker that
        # outlived close()'s join timeout keeps its already-set event and exits,
        # instead of being revived by clearing a shared event here.
        stop_event = threading.Event()
        work_queue: queue.Queue = queue.Queue(maxsize=max(self.prefetch_shards * 2, 1))
        with self._cache_lock:
            # Requests queued on an abandoned queue will never be served.
            self._prefetch_requested.clear()
        self._prefetch_stop = stop_event
        self._prefetch_queue = work_queue
        self._prefetch_thread = threading.Thread(
            target=self._prefetch_worker,
            args=(work_queue, stop_event),
            daemon=True,
            name=f"cached-shard-prefetch-{os.getpid()}",
        )
        self._prefetch_thread.start()

    def _prefetch_worker(
        self,
        work_queue: Optional[queue.Queue] = None,
        stop_event: Optional[threading.Event] = None,
    ) -> None:
        """Load queued shard indices into the cache until ``stop_event`` is set."""
        work_queue = self._prefetch_queue if work_queue is None else work_queue
        stop_event = self._prefetch_stop if stop_event is None else stop_event
        while not stop_event.is_set():
            try:
                shard_idx = work_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            if shard_idx is None:
                continue
            try:
                with self._cache_lock:
                    already_cached = shard_idx in self._shard_cache
                    if already_cached:
                        self._prefetch_hits += 1
                if already_cached:
                    continue
                try:
                    payload = torch.load(self.shard_files[shard_idx], map_location="cpu", weights_only=False)
                    records = payload["records"]
                except Exception:
                    # Prefetch is best-effort: count the failure and keep the
                    # worker alive. _load_shard will load the shard itself and
                    # surface the error if the shard is actually requested.
                    with self._cache_lock:
                        self._prefetch_errors += 1
                    continue
                self._store_shard(shard_idx, records)
                with self._cache_lock:
                    self._prefetch_hits += 1
            finally:
                with self._cache_lock:
                    self._prefetch_requested.discard(shard_idx)

    def _schedule_prefetch(self, shard_idx: int) -> None:
        """Queue up to ``prefetch_shards`` shards following ``shard_idx`` that are not cached or pending."""
        if self.prefetch_shards <= 0:
            return
        self._ensure_prefetch_thread()
        if self._prefetch_queue is None:
            return
        for next_idx in range(shard_idx + 1, min(len(self.shard_files), shard_idx + 1 + self.prefetch_shards)):
            with self._cache_lock:
                if next_idx in self._shard_cache or next_idx in self._prefetch_requested:
                    continue
                self._prefetch_requested.add(next_idx)
            try:
                self._prefetch_queue.put_nowait(next_idx)
            except queue.Full:
                with self._cache_lock:
                    self._prefetch_requested.discard(next_idx)
                break

    def _load_shard(self, shard_idx: int) -> List[Dict[str, Any]]:
        """Return the records of ``shard_idx`` (loading synchronously on a miss) and schedule prefetch."""
        with self._cache_lock:
            cached = self._shard_cache.get(shard_idx)
            if cached is not None:
                self._shard_cache.move_to_end(shard_idx)
            else:
                self._prefetch_misses += 1
        if cached is None:
            payload = torch.load(self.shard_files[shard_idx], map_location="cpu", weights_only=False)
            cached = payload["records"]
            self._store_shard(shard_idx, cached)
            with self._cache_lock:
                self._prefetch_requested.discard(shard_idx)
        self._schedule_prefetch(shard_idx)
        return cached

    @staticmethod
    def _deserialize_item(raw: Optional[Dict[str, Any]]) -> Optional[PairItem]:
        """Turn a cached item dict into a PairItem.

        Text prefers ``"tokens"`` over ``"value"``; other modalities prefer
        ``"tensor"`` over ``"value"``. Returns None for a None input.
        """
        if raw is None:
            return None
        modality = raw["type"]
        if modality == "text" and "tokens" in raw:
            value = raw["tokens"]
        elif modality == "text":
            value = raw["value"]
        elif "tensor" in raw:
            value = raw["tensor"]
        else:
            value = raw.get("value")
        return PairItem(modality=modality, value=value)

    def __len__(self) -> int:
        return self.total_rows

    def __getitem__(self, idx: int) -> TrainRecord:
        """Return the record at global index ``idx`` (negative indices are rejected)."""
        if idx < 0 or idx >= self.total_rows:
            raise IndexError(idx)
        shard_idx = bisect_right(self.shard_offsets, idx)
        shard_start = 0 if shard_idx == 0 else self.shard_offsets[shard_idx - 1]
        raw = self._load_shard(shard_idx)[idx - shard_start]
        return TrainRecord(
            query=self._deserialize_item(raw["query"]),
            positive=self._deserialize_item(raw["positive"]),
            negative=self._deserialize_item(raw.get("negative")),
        )

    def get_prefetch_stats(self) -> Dict[str, int]:
        """Return a snapshot of cache/prefetch counters."""
        with self._cache_lock:
            return {
                "cache_size": len(self._shard_cache),
                "cache_limit": self.shard_cache_limit,
                "prefetch_shards": self.prefetch_shards,
                "prefetch_hits": self._prefetch_hits,
                "prefetch_misses": self._prefetch_misses,
                "prefetch_pending": len(self._prefetch_requested),
                "prefetch_errors": self._prefetch_errors,
            }

    def close(self) -> None:
        """Stop the prefetch worker (waiting up to 1s) and drop its queue."""
        # getattr: __del__ may run on an instance whose __init__ raised before
        # _init_runtime_state (e.g. a missing cache directory).
        stop_event = getattr(self, "_prefetch_stop", None)
        if stop_event is not None:
            stop_event.set()
        thread = getattr(self, "_prefetch_thread", None)
        if thread is not None and thread.is_alive() and thread is not threading.current_thread():
            thread.join(timeout=1.0)
        self._prefetch_thread = None
        self._prefetch_queue = None

    def __del__(self):
        self.close()
class SequentialShardDataset:
    """Shard-at-a-time dataset over ``shard_*.pt`` files for distributed training.

    Each call to ``reset(epoch)`` does the following:

    * optionally shuffles the shard list with seed ``42 + epoch``;
    * assigns this rank every ``world_size``-th shard starting at ``rank``,
      padding by cycling its own shards so all ranks hold the same count;
    * loads the first shard.

    Only the current shard is exposed through ``__len__``/``__getitem__``;
    ``next_shard`` advances. Shards shorter than ``target_shard_size`` are
    padded by repeating their records. When ``prefetch_shards > 0``, a daemon
    thread loads upcoming local shards into an LRU cache of
    ``shard_cache_limit`` entries.

    Shards are read with ``torch.load(weights_only=False)`` (full unpickling),
    so only point this class at trusted cache directories.
    """

    def __init__(
        self,
        cache_dir: str,
        shuffle: bool = True,
        rank: int = 0,
        world_size: int = 1,
        prefetch_shards: int = 2,
        shard_cache_limit: int = 4,
    ) -> None:
        self.cache_dir = cache_dir
        self.shuffle = shuffle
        self.rank = rank
        self.world_size = max(world_size, 1)
        self.prefetch_shards = max(int(prefetch_shards), 0)
        self.shard_cache_limit = max(int(shard_cache_limit), 1)
        self.metadata = self._load_metadata()
        self.shard_files = self._discover_shards()
        self.shard_sizes = self._resolve_shard_sizes()
        self.total_rows = sum(self.shard_sizes)
        # Shards with fewer records are padded up to this size (_pad_records).
        self.target_shard_size = int(self.metadata.get("shard_size") or max(self.shard_sizes))
        self._shard_cache: OrderedDict[int, List[Dict[str, Any]]] = OrderedDict()
        self._all_shard_indices = list(range(len(self.shard_files)))
        self._local_shard_indices: List[int] = []
        self.current_local_shard_pos = -1
        self.current_records: Optional[List[Dict[str, Any]]] = None
        self._init_runtime_state()

    def _init_runtime_state(self) -> None:
        """(Re)create the lock, prefetch machinery and counters (also used after unpickling)."""
        self._cache_lock = threading.Lock()
        self._prefetch_queue: Optional[queue.Queue] = None
        self._prefetch_thread: Optional[threading.Thread] = None
        self._prefetch_stop = threading.Event()
        self._prefetch_requested: set[int] = set()
        self._prefetch_hits = 0
        self._prefetch_misses = 0
        # Prefetch loads that raised; the synchronous path re-raises on demand.
        self._prefetch_errors = 0

    def __getstate__(self):
        # threading.Lock/Thread/Queue cannot be pickled; drop them and rebuild
        # in __setstate__ so the dataset can cross process boundaries.
        state = self.__dict__.copy()
        state["_shard_cache"] = OrderedDict(state.get("_shard_cache", OrderedDict()))
        state["_cache_lock"] = None
        state["_prefetch_queue"] = None
        state["_prefetch_thread"] = None
        state["_prefetch_stop"] = None
        state["_prefetch_requested"] = set()
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._shard_cache = OrderedDict(self._shard_cache)
        self._init_runtime_state()

    def _load_metadata(self) -> Dict[str, Any]:
        """Return ``cache_dir/metadata.json`` as a dict, or ``{}`` if absent."""
        metadata_path = os.path.join(self.cache_dir, "metadata.json")
        if not os.path.exists(metadata_path):
            return {}
        with open(metadata_path, "r", encoding="utf-8") as handle:
            return json.load(handle)

    def _discover_shards(self) -> List[str]:
        """Return sorted paths of ``shard_*.pt`` files in ``cache_dir``.

        Raises:
            FileNotFoundError: If ``cache_dir`` is not a directory.
            ValueError: If no shard files are present.
        """
        if not os.path.isdir(self.cache_dir):
            raise FileNotFoundError(f"Cache directory not found: {self.cache_dir}")
        # NOTE(review): lexicographic order; assumes zero-padded shard numbers.
        shards = [
            os.path.join(self.cache_dir, name)
            for name in sorted(os.listdir(self.cache_dir))
            if name.startswith("shard_") and name.endswith(".pt")
        ]
        if not shards:
            raise ValueError(f"No cache shards found under {self.cache_dir}")
        return shards

    def _resolve_shard_sizes(self) -> List[int]:
        """Return per-shard record counts from consistent metadata, or by loading every shard.

        Raises:
            ValueError: If metadata implies a non-positive last shard, or a
                shard has no ``"records"`` list.
        """
        num_shards = len(self.shard_files)
        metadata_num_shards = self.metadata.get("num_shards")
        metadata_num_records = self.metadata.get("num_records")
        shard_size = self.metadata.get("shard_size")
        if (
            isinstance(metadata_num_shards, int)
            and isinstance(metadata_num_records, int)
            and isinstance(shard_size, int)
            and metadata_num_shards == num_shards
            and metadata_num_records > 0
            and shard_size > 0
        ):
            shard_sizes = [shard_size] * num_shards
            shard_sizes[-1] = metadata_num_records - shard_size * max(num_shards - 1, 0)
            # Same sanity check as CachedShardDataset: inconsistent metadata
            # would otherwise yield a zero/negative last shard silently.
            if shard_sizes[-1] <= 0:
                raise ValueError(f"Invalid metadata in {self.cache_dir}: last shard size computed as {shard_sizes[-1]}")
            return shard_sizes
        shard_sizes = []
        for shard_path in self.shard_files:
            payload = torch.load(shard_path, map_location="cpu", weights_only=False)
            records = payload.get("records")
            if not isinstance(records, list):
                raise ValueError(f"Invalid shard format in {shard_path}")
            shard_sizes.append(len(records))
        return shard_sizes

    @staticmethod
    def _deserialize_item(raw: Optional[Dict[str, Any]]) -> Optional[PairItem]:
        """Turn a cached item dict into a PairItem (text prefers ``tokens``, others prefer ``tensor``)."""
        if raw is None:
            return None
        modality = raw["type"]
        if modality == "text" and "tokens" in raw:
            value = raw["tokens"]
        elif modality == "text":
            value = raw["value"]
        elif "tensor" in raw:
            value = raw["tensor"]
        else:
            value = raw.get("value")
        return PairItem(modality=modality, value=value)

    def _store_shard(self, shard_idx: int, records: List[Dict[str, Any]]) -> None:
        """Insert a shard as most-recently-used and evict down to the cache limit."""
        with self._cache_lock:
            self._shard_cache[shard_idx] = records
            self._shard_cache.move_to_end(shard_idx)
            while len(self._shard_cache) > self.shard_cache_limit:
                self._shard_cache.popitem(last=False)

    def _ensure_prefetch_thread(self) -> None:
        """Start a prefetch worker if prefetching is enabled and none is alive."""
        if self.prefetch_shards <= 0:
            return
        if self._prefetch_thread is not None and self._prefetch_thread.is_alive():
            return
        # Fresh queue + stop event per worker. reset() stops and restarts the
        # worker every epoch; a worker still inside torch.load after the 1s
        # join keeps its own (set) event and exits instead of being revived.
        stop_event = threading.Event()
        work_queue: queue.Queue = queue.Queue(maxsize=max(self.prefetch_shards * 2, 1))
        with self._cache_lock:
            # Requests left on an abandoned queue will never be served.
            self._prefetch_requested.clear()
        self._prefetch_stop = stop_event
        self._prefetch_queue = work_queue
        self._prefetch_thread = threading.Thread(
            target=self._prefetch_worker,
            args=(work_queue, stop_event),
            daemon=True,
            name=f"sequential-shard-prefetch-{os.getpid()}",
        )
        self._prefetch_thread.start()

    def _prefetch_worker(
        self,
        work_queue: Optional[queue.Queue] = None,
        stop_event: Optional[threading.Event] = None,
    ) -> None:
        """Load queued shard indices into the cache until ``stop_event`` is set."""
        work_queue = self._prefetch_queue if work_queue is None else work_queue
        stop_event = self._prefetch_stop if stop_event is None else stop_event
        while not stop_event.is_set():
            try:
                shard_idx = work_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            if shard_idx is None:
                continue
            try:
                with self._cache_lock:
                    already_cached = shard_idx in self._shard_cache
                    if already_cached:
                        self._prefetch_hits += 1
                if already_cached:
                    continue
                try:
                    payload = torch.load(self.shard_files[shard_idx], map_location="cpu", weights_only=False)
                    records = payload["records"]
                except Exception:
                    # Best-effort: keep the worker alive; the synchronous load in
                    # _load_records_for_shard surfaces the error when needed.
                    with self._cache_lock:
                        self._prefetch_errors += 1
                    continue
                self._store_shard(shard_idx, records)
                with self._cache_lock:
                    self._prefetch_hits += 1
            finally:
                with self._cache_lock:
                    self._prefetch_requested.discard(shard_idx)

    def _stop_prefetch_thread(self) -> None:
        """Signal the current worker to stop, wait up to 1s, and drop its queue."""
        # getattr: also reached via __del__ on a partially-initialised instance.
        stop_event = getattr(self, "_prefetch_stop", None)
        if stop_event is not None:
            stop_event.set()
        thread = getattr(self, "_prefetch_thread", None)
        if thread is not None and thread.is_alive() and thread is not threading.current_thread():
            thread.join(timeout=1.0)
        self._prefetch_thread = None
        self._prefetch_queue = None

    def _schedule_prefetch_from_position(self, local_pos: int) -> None:
        """Queue the next ``prefetch_shards`` local shards after ``local_pos`` that are not cached or pending."""
        if self.prefetch_shards <= 0:
            return
        self._ensure_prefetch_thread()
        if self._prefetch_queue is None:
            return
        for next_pos in range(local_pos + 1, min(len(self._local_shard_indices), local_pos + 1 + self.prefetch_shards)):
            shard_idx = self._local_shard_indices[next_pos]
            with self._cache_lock:
                if shard_idx in self._shard_cache or shard_idx in self._prefetch_requested:
                    continue
                self._prefetch_requested.add(shard_idx)
            try:
                self._prefetch_queue.put_nowait(shard_idx)
            except queue.Full:
                with self._cache_lock:
                    self._prefetch_requested.discard(shard_idx)
                break

    def _build_local_shard_order(self, epoch: int) -> List[int]:
        """Return this rank's shard indices for ``epoch``, padded to the per-rank maximum.

        Raises:
            ValueError: If this rank receives no shards (more ranks than shards).
        """
        shard_indices = list(self._all_shard_indices)
        if self.shuffle:
            random.Random(42 + epoch).shuffle(shard_indices)
        local_shards = shard_indices[self.rank::self.world_size]
        if not local_shards:
            raise ValueError(f"Rank {self.rank} received no shards from {self.cache_dir}")
        max_shards = math.ceil(len(shard_indices) / self.world_size)
        num_owned = len(local_shards)
        # Cycle through this rank's own shards to reach max_shards.
        for pad_idx in range(max_shards - num_owned):
            local_shards.append(local_shards[pad_idx % num_owned])
        return local_shards

    def _load_records_for_shard(self, shard_idx: int) -> List[Dict[str, Any]]:
        """Return a shard's records from the cache, loading synchronously on a miss."""
        with self._cache_lock:
            cached = self._shard_cache.get(shard_idx)
            if cached is not None:
                self._shard_cache.move_to_end(shard_idx)
                return cached
            self._prefetch_misses += 1
        payload = torch.load(self.shard_files[shard_idx], map_location="cpu", weights_only=False)
        records = payload["records"]
        self._store_shard(shard_idx, records)
        with self._cache_lock:
            self._prefetch_requested.discard(shard_idx)
        return records

    def _pad_records(self, records: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Repeat ``records`` cyclically up to ``target_shard_size`` (longer shards are returned as-is).

        Raises:
            ValueError: If a padding target exists but the shard is empty.
        """
        if len(records) >= self.target_shard_size:
            return records
        if not records:
            # Previously a ZeroDivisionError from the ceil below.
            raise ValueError(
                f"Cannot pad an empty shard to {self.target_shard_size} records (cache_dir={self.cache_dir})"
            )
        repeat = math.ceil(self.target_shard_size / len(records))
        return (records * repeat)[: self.target_shard_size]

    def reset(self, epoch: int) -> bool:
        """Rebuild the local shard order for ``epoch`` and load its first shard.

        Returns:
            True if a shard was loaded (see ``next_shard``).
        """
        self._stop_prefetch_thread()
        self._local_shard_indices = self._build_local_shard_order(epoch)
        self.current_local_shard_pos = -1
        self.current_records = None
        with self._cache_lock:
            self._prefetch_requested.clear()
        if self.prefetch_shards > 0:
            self._ensure_prefetch_thread()
        return self.next_shard()

    def next_shard(self) -> bool:
        """Advance to the next local shard; return False when the epoch is exhausted."""
        self.current_local_shard_pos += 1
        if self.current_local_shard_pos >= len(self._local_shard_indices):
            self.current_records = None
            return False
        shard_idx = self._local_shard_indices[self.current_local_shard_pos]
        records = self._load_records_for_shard(shard_idx)
        self.current_records = self._pad_records(records)
        self._schedule_prefetch_from_position(self.current_local_shard_pos)
        return True

    def __len__(self) -> int:
        return len(self.current_records or [])

    def __getitem__(self, idx: int) -> TrainRecord:
        """Return record ``idx`` of the current (padded) shard."""
        if self.current_records is None:
            raise IndexError(idx)
        raw = self.current_records[idx]
        return TrainRecord(
            query=self._deserialize_item(raw["query"]),
            positive=self._deserialize_item(raw["positive"]),
            negative=self._deserialize_item(raw.get("negative")),
        )

    def estimated_num_batches(self, batch_size: int, drop_last: bool) -> int:
        """Batches per epoch for this rank, assuming every shard is padded to ``target_shard_size``."""
        shard_batches = self.target_shard_size // batch_size if drop_last else math.ceil(self.target_shard_size / batch_size)
        return shard_batches * max(len(self._build_local_shard_order(0)), 1)

    def get_prefetch_stats(self) -> Dict[str, int]:
        """Return a snapshot of cache/prefetch counters."""
        with self._cache_lock:
            return {
                "cache_size": len(self._shard_cache),
                "cache_limit": self.shard_cache_limit,
                "prefetch_shards": self.prefetch_shards,
                "prefetch_hits": self._prefetch_hits,
                "prefetch_misses": self._prefetch_misses,
                "prefetch_pending": len(self._prefetch_requested),
                "local_shards": len(self._local_shard_indices),
                "target_shard_size": self.target_shard_size,
                "prefetch_errors": self._prefetch_errors,
            }

    def close(self) -> None:
        """Stop the prefetch worker."""
        self._stop_prefetch_thread()

    def __del__(self):
        self.close()
def _process_shard() -> tuple[int, int]:
rank = int(os.environ.get("ACCELERATE_PROCESS_INDEX") or os.environ.get("RANK") or 0)
world_size = int(os.environ.get("WORLD_SIZE") or os.environ.get("ACCELERATE_NUM_PROCESSES") or 1)
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
return rank, max(world_size, 1)
total_shards = max(world_size, 1) * worker_info.num_workers
shard_id = rank * worker_info.num_workers + worker_info.id
return shard_id, max(total_shards, 1)
def iter_sentence_transformers_rows(
    manifest_path: str,
    image_root: Optional[str],
    audio_root: Optional[str],
    allow_missing_negative: bool,
    allowed_modalities: Optional[List[str]],
    query_modalities: Optional[List[str]],
    positive_modalities: Optional[List[str]],
    negative_modalities: Optional[List[str]],
    use_negative_column: bool,
):
    """Yield sentence-transformers rows for this process's share of a manifest.

    Records from ``iter_manifest_records`` that pass ``record_matches_filters``
    are dealt out round-robin. Shard ``shard_id`` of ``total_shards`` (from
    ``_process_shard``) yields the matched records at positions
    ``shard_id, shard_id + total_shards, ...``.
    """
    filters = {
        "allowed": set(allowed_modalities or []),
        "allowed_query": set(query_modalities or []),
        "allowed_positive": set(positive_modalities or []),
        "allowed_negative": set(negative_modalities or []),
    }
    shard_id, total_shards = _process_shard()
    records = iter_manifest_records(
        manifest_path=manifest_path,
        image_root=image_root,
        audio_root=audio_root,
        allow_missing_negative=allow_missing_negative,
    )
    matched = (record for record in records if record_matches_filters(record, **filters))
    for position, record in enumerate(matched):
        if position % total_shards != shard_id:
            continue
        yield record_to_sentence_transformers_row(record, include_negative=use_negative_column)
def collate_records(batch: List[TrainRecord]) -> Dict[str, List[PairItem]]:
    """Transpose a batch of TrainRecords into per-field lists.

    Missing negatives stay as None entries, so all three lists have
    ``len(batch)`` elements.
    """
    field_names = ("query", "positive", "negative")
    return {name: [getattr(record, name) for record in batch] for name in field_names}
def sentence_transformers_input(item: PairItem) -> Any:
    """Wrap an item's value in a one-key dict named after its modality.

    text/image/audio items become ``{"text": v}``, ``{"image": v}`` or
    ``{"audio": v}``. Any other modality passes its raw value through.
    """
    if item.modality not in ("text", "image", "audio"):
        return item.value
    return {item.modality: item.value}
def resolve_media(item: PairItem, image_root: Optional[str], audio_root: Optional[str]) -> PairItem:
    """Join a relative image/audio path onto its modality's root directory.

    Returns the same item unchanged for text and other modalities, for
    absolute paths, and when the matching root is unset or empty.
    """
    if item.modality == "image":
        root = image_root
    elif item.modality == "audio":
        root = audio_root
    else:
        return item
    if not root or os.path.isabs(item.value):
        return item
    return PairItem(item.modality, os.path.join(root, item.value))
def iter_manifest_records(
    manifest_path: str,
    image_root: Optional[str] = None,
    audio_root: Optional[str] = None,
    allow_missing_negative: bool = True,
) -> Iterable[TrainRecord]:
    """Lazily parse a JSONL manifest into TrainRecords.

    Blank lines are skipped. Relative image/audio paths are resolved against
    ``image_root`` / ``audio_root`` via ``resolve_media``. This is a generator,
    so the existence check and all parsing happen on first iteration.

    Raises:
        FileNotFoundError: If ``manifest_path`` does not exist.
        ValueError: If a line is not valid JSON or matches no supported schema
            (message prefixed with ``path:line``), or if a negative is missing
            while ``allow_missing_negative`` is False.
    """
    if not os.path.exists(manifest_path):
        raise FileNotFoundError(f"Manifest not found: {manifest_path}")
    # utf-8-sig drops a leading byte-order mark, which str.strip() keeps and
    # json.loads rejects; BOM-free files decode exactly as with utf-8.
    with open(manifest_path, "r", encoding="utf-8-sig") as handle:
        for line_no, line in enumerate(handle, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                parsed = parse_record(json.loads(line))
            except ValueError as exc:
                # json.JSONDecodeError subclasses ValueError; add the location
                # so a bad row in a large manifest can be found.
                raise ValueError(f"{manifest_path}:{line_no}: {exc}") from exc
            negative = parsed.negative
            record = TrainRecord(
                query=resolve_media(parsed.query, image_root, audio_root),
                positive=resolve_media(parsed.positive, image_root, audio_root),
                negative=resolve_media(negative, image_root, audio_root) if negative is not None else None,
            )
            if record.negative is None and not allow_missing_negative:
                raise ValueError(f"Missing negative at line {line_no}")
            yield record
def record_matches_filters(
    record: TrainRecord,
    allowed: set[str],
    allowed_query: set[str],
    allowed_positive: set[str],
    allowed_negative: set[str],
) -> bool:
    """Return True if ``record`` satisfies every non-empty modality filter.

    An empty filter set means "no restriction":

    * ``allowed`` must contain every modality the record uses;
    * ``allowed_query``/``allowed_positive`` constrain those roles;
    * ``allowed_negative`` applies only when the record has a negative.
    """
    negative = record.negative
    used = {record.query.modality, record.positive.modality}
    if negative is not None:
        used.add(negative.modality)
    if allowed and not used.issubset(allowed):
        return False
    if allowed_query and record.query.modality not in allowed_query:
        return False
    if allowed_positive and record.positive.modality not in allowed_positive:
        return False
    return negative is None or not allowed_negative or negative.modality in allowed_negative
def record_to_sentence_transformers_row(record: TrainRecord, include_negative: bool) -> Dict[str, Any]:
    """Build a sentence-transformers row (``query``, ``positive``[, ``negative_0``]).

    ``negative_0`` is added only when requested and the record has a negative.
    """
    row: Dict[str, Any] = {}
    row["query"] = sentence_transformers_input(record.query)
    row["positive"] = sentence_transformers_input(record.positive)
    if not include_negative or record.negative is None:
        return row
    row["negative_0"] = sentence_transformers_input(record.negative)
    return row
def summarize_manifest_records(
    manifest_path: str,
    image_root: Optional[str] = None,
    audio_root: Optional[str] = None,
    allow_missing_negative: bool = True,
    allowed_modalities: Optional[List[str]] = None,
    query_modalities: Optional[List[str]] = None,
    positive_modalities: Optional[List[str]] = None,
    negative_modalities: Optional[List[str]] = None,
    max_records: Optional[int] = None,
) -> Dict[str, Any]:
    """Scan a manifest and summarize the records that pass the modality filters.

    Scanning stops once ``max_records`` matching rows are counted. A row is
    counted before the cap is checked, so at least one row is always kept.

    Returns:
        Dict with ``modalities`` (sorted), ``num_rows``,
        ``has_uniform_negatives`` (every kept row has a negative),
        ``num_negatives_present``, ``num_negatives_missing`` and
        ``skipped_rows`` (rows rejected by the filters).

    Raises:
        ValueError: If no row passes the filters.
    """
    filters = {
        "allowed": set(allowed_modalities or []),
        "allowed_query": set(query_modalities or []),
        "allowed_positive": set(positive_modalities or []),
        "allowed_negative": set(negative_modalities or []),
    }
    seen_modalities: set = set()
    with_negative = 0
    without_negative = 0
    skipped = 0
    kept = 0
    records = iter_manifest_records(
        manifest_path=manifest_path,
        image_root=image_root,
        audio_root=audio_root,
        allow_missing_negative=allow_missing_negative,
    )
    for record in records:
        if not record_matches_filters(record, **filters):
            skipped += 1
            continue
        seen_modalities.update((record.query.modality, record.positive.modality))
        if record.negative is None:
            without_negative += 1
        else:
            seen_modalities.add(record.negative.modality)
            with_negative += 1
        kept += 1
        if max_records is not None and kept >= max_records:
            break
    if kept == 0:
        raise ValueError(f"No records loaded from {manifest_path}")
    return {
        "modalities": sorted(seen_modalities),
        "num_rows": kept,
        "has_uniform_negatives": with_negative > 0 and without_negative == 0,
        "num_negatives_present": with_negative,
        "num_negatives_missing": without_negative,
        "skipped_rows": skipped,
    }
def _iter_capped_sentence_transformers_rows(max_records: Optional[int] = None, **row_kwargs: Any):
    """Yield ``iter_sentence_transformers_rows(**row_kwargs)`` limited by ``max_records``.

    ``iter_sentence_transformers_rows`` deals matched records round-robin.
    Shard ``shard_id`` of ``total_shards`` (from ``_process_shard``) gets
    global match indices ``shard_id, shard_id + total_shards, ...``. Keeping
    only indices below the cap therefore means taking this shard's first
    ``ceil((cap - shard_id) / total_shards)`` rows.
    """
    from itertools import islice

    rows = iter_sentence_transformers_rows(**row_kwargs)
    if max_records is None:
        yield from rows
        return
    # summarize_manifest_records and the eager path count a row before testing
    # the cap, so they always keep at least one row; mirror that here.
    cap = max(int(max_records), 1)
    shard_id, total_shards = _process_shard()
    remaining = cap - shard_id
    local_cap = (remaining + total_shards - 1) // total_shards if remaining > 0 else 0
    yield from islice(rows, local_cap)


def manifest_to_sentence_transformers_dataset(
    manifest_path: str,
    image_root: Optional[str] = None,
    audio_root: Optional[str] = None,
    allow_missing_negative: bool = True,
    allowed_modalities: Optional[List[str]] = None,
    query_modalities: Optional[List[str]] = None,
    positive_modalities: Optional[List[str]] = None,
    negative_modalities: Optional[List[str]] = None,
    as_iterable: bool = False,
    max_records: Optional[int] = None,
) -> tuple[Dataset | IterableDataset, Dict[str, Any]]:
    """Build a Hugging Face dataset of sentence-transformers rows from a manifest.

    Rows have ``query`` and ``positive`` columns. A ``negative_0`` column is
    added when every summarized record has a negative
    (``info["has_uniform_negatives"]``).

    Args:
        manifest_path: JSONL manifest to read.
        image_root / audio_root: Roots for relative media paths.
        allow_missing_negative: If False, a record without a negative raises.
        allowed_modalities / query_modalities / positive_modalities /
            negative_modalities: Modality filters (empty or None = no filter).
        as_iterable: Stream rows lazily via ``IterableDataset.from_generator``
            instead of materializing them with ``Dataset.from_list``.
        max_records: Keep only the first ``max_records`` matching records,
            in both modes.

    Returns:
        ``(dataset, info)``, where ``info`` is the ``summarize_manifest_records``
        summary used to choose the columns.
    """
    info = summarize_manifest_records(
        manifest_path=manifest_path,
        image_root=image_root,
        audio_root=audio_root,
        allow_missing_negative=allow_missing_negative,
        allowed_modalities=allowed_modalities,
        query_modalities=query_modalities,
        positive_modalities=positive_modalities,
        negative_modalities=negative_modalities,
        max_records=max_records,
    )
    use_negative_column = info["has_uniform_negatives"]
    dataset_out: Dataset | IterableDataset
    if as_iterable:
        column_names = ["query", "positive"]
        if use_negative_column:
            column_names.append("negative_0")

        def _as_tuple(values: Optional[List[str]]) -> Optional[tuple]:
            # datasets treats list-valued gen_kwargs as shard lists and splits
            # them across generator calls, so each call would see only part of
            # a modality filter (lists of different lengths are rejected).
            # Tuples are passed through whole.
            return tuple(values) if values is not None else None

        dataset_out = IterableDataset.from_generator(
            _iter_capped_sentence_transformers_rows,
            features=Features({key: Value("null") for key in column_names}),
            gen_kwargs={
                "manifest_path": manifest_path,
                "image_root": image_root,
                "audio_root": audio_root,
                "allow_missing_negative": allow_missing_negative,
                "allowed_modalities": _as_tuple(allowed_modalities),
                "query_modalities": _as_tuple(query_modalities),
                "positive_modalities": _as_tuple(positive_modalities),
                "negative_modalities": _as_tuple(negative_modalities),
                "use_negative_column": use_negative_column,
                # Without the cap, rows past max_records (possibly lacking a
                # negative) would be streamed although the columns were
                # chosen from the capped summary.
                "max_records": max_records,
            },
        )
    else:
        filters = {
            "allowed": set(allowed_modalities or []),
            "allowed_query": set(query_modalities or []),
            "allowed_positive": set(positive_modalities or []),
            "allowed_negative": set(negative_modalities or []),
        }
        rows: List[Dict[str, Any]] = []
        # Stream the manifest rather than loading every record up front, so
        # reading stops at max_records just like summarize_manifest_records.
        for record in iter_manifest_records(
            manifest_path=manifest_path,
            image_root=image_root,
            audio_root=audio_root,
            allow_missing_negative=allow_missing_negative,
        ):
            if not record_matches_filters(record, **filters):
                continue
            rows.append(record_to_sentence_transformers_row(record, include_negative=use_negative_column))
            if max_records is not None and len(rows) >= max_records:
                break
        dataset_out = Dataset.from_list(rows)
    return dataset_out, info