Spaces:
Running on Zero
Running on Zero
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| import contextlib | |
| import contextvars | |
| import hashlib | |
| import json | |
| import os | |
| import threading | |
| import time | |
| from collections import OrderedDict | |
| from dataclasses import dataclass | |
| from typing import Iterable, Optional | |
| import numpy as np | |
| import torch | |
| from kimodo.sanitize import sanitize_texts | |
# Context-local handle to the currently active demo session. It defaults to
# ``None`` (no session bound); EmbeddingCache reads it to reuse the last
# prompt's embeddings, and CachedTextEncoder.session_context sets/resets it.
_ACTIVE_SESSION = contextvars.ContextVar(
    "kimodo_demo_active_session",
    default=None,
)
@dataclass
class CacheStats:
    """Running counters describing how embedding lookups were served.

    Declared as a dataclass so instances get a keyword constructor, a readable
    ``repr`` and value equality; ``CacheStats()`` still yields all-zero counters.
    """

    # Lookups answered from the in-memory LRU.
    hits: int = 0
    # Lookups found neither in memory nor on disk (sent to the encoder).
    misses: int = 0
    # Lookups answered by loading a cached array from disk.
    disk_hits: int = 0
class EmbeddingCache:
    """Disk-backed text embedding cache with a small in-memory LRU.

    Each text is keyed by a SHA-256 over ``model_name|encoder_id|text``. Lookups
    try the in-memory LRU first, then a ``<key>.npy`` file under
    ``<base_dir>/<model_name>/``, and only then call the encoder. Cache
    bookkeeping (LRU, disk reads/writes, index) runs under an internal lock;
    the encoder call itself happens outside the lock.
    """

    def __init__(
        self,
        *,
        model_name: str,
        encoder_id: str,
        base_dir: Optional[str] = None,
        max_mem_entries: int = 128,
    ) -> None:
        """Configure the cache; nothing is read from or written to disk here.

        Args:
            model_name: Used as the per-model sub-directory and as part of keys.
            encoder_id: Mixed into keys so different encoders do not collide.
            base_dir: Cache root. Falls back to the ``kimodo_EMBED_CACHE_DIR``
                environment variable, then ``~/.cache/kimodo_demo/embeddings``.
            max_mem_entries: Maximum number of arrays kept in the in-memory LRU.
        """
        cache_root = base_dir or os.environ.get(
            "kimodo_EMBED_CACHE_DIR",
            os.path.join("~", ".cache", "kimodo_demo", "embeddings"),
        )
        self.base_dir = os.path.expanduser(cache_root)
        self.model_name = model_name
        self.encoder_id = encoder_id
        self.max_mem_entries = max_mem_entries
        self.stats = CacheStats()
        self._lock = threading.Lock()
        self._mem_cache: OrderedDict[str, np.ndarray] = OrderedDict()
        self._index = {}
        self._index_loaded = False

    def _model_dir(self) -> str:
        """Return the directory holding this model's cache files."""
        return os.path.join(self.base_dir, self.model_name)

    def _index_path(self) -> str:
        """Return the path of the JSON metadata index for this model."""
        return os.path.join(self._model_dir(), "index.json")

    def _prewarm_marker_path(self, key: str) -> str:
        """Return the path of the prewarm marker file for ``key``."""
        return os.path.join(self._model_dir(), f"prewarm_{key}.json")

    def has_prewarm_marker(self, key: str) -> bool:
        """Return True if a prewarm marker for ``key`` exists on disk."""
        return os.path.exists(self._prewarm_marker_path(key))

    def write_prewarm_marker(self, key: str, *, prompt_count: int) -> None:
        """Record that the prompt set identified by ``key`` has been prewarmed.

        The marker is written to a temporary file and moved into place with
        ``os.replace`` so readers never observe a partially written marker.
        """
        os.makedirs(self._model_dir(), exist_ok=True)
        marker_path = self._prewarm_marker_path(key)
        payload = {"prompt_count": prompt_count, "updated_at": time.time()}
        tmp_path = f"{marker_path}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(payload, f)
        os.replace(tmp_path, marker_path)

    def _load_index(self) -> None:
        """Load the metadata index from disk once; later calls are no-ops.

        A missing, unreadable, corrupt, or non-dict index is treated as empty:
        the index only stores per-entry metadata and is never consulted for
        lookups, so starting fresh is always safe.
        """
        if self._index_loaded:
            return
        loaded = {}
        try:
            with open(self._index_path(), "r", encoding="utf-8") as f:
                loaded = json.load(f)
        except FileNotFoundError:
            pass
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            loaded = {}
        # Valid JSON that is not an object would break `self._index[key] = ...`.
        self._index = loaded if isinstance(loaded, dict) else {}
        self._index_loaded = True

    def _save_index(self) -> None:
        """Persist the metadata index via a temp file plus ``os.replace``."""
        os.makedirs(self._model_dir(), exist_ok=True)
        index_path = self._index_path()
        tmp_path = f"{index_path}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(self._index, f)
        os.replace(tmp_path, index_path)

    def _make_key(self, text: str) -> str:
        """Return the hex SHA-256 cache key for ``text`` under this model/encoder."""
        key_src = f"{self.model_name}|{self.encoder_id}|{text}"
        return hashlib.sha256(key_src.encode("utf-8")).hexdigest()

    def _entry_path(self, key: str) -> str:
        """Return the ``.npy`` path that stores the embedding for ``key``."""
        return os.path.join(self._model_dir(), f"{key}.npy")

    def _mem_get(self, key: str) -> Optional[np.ndarray]:
        """Return the in-memory entry for ``key`` (marking it most recent) or None."""
        if key in self._mem_cache:
            self._mem_cache.move_to_end(key)
            return self._mem_cache[key]
        return None

    def _mem_put(self, key: str, value: np.ndarray) -> None:
        """Insert ``value`` as most recent and evict oldest entries over the cap."""
        self._mem_cache[key] = value
        self._mem_cache.move_to_end(key)
        while len(self._mem_cache) > self.max_mem_entries:
            self._mem_cache.popitem(last=False)

    def _disk_load(self, key: str) -> Optional[np.ndarray]:
        """Load the cached array for ``key`` from disk, or None if unavailable."""
        path = self._entry_path(key)
        if not os.path.exists(path):
            return None
        try:
            return np.load(path)
        except Exception:
            # Deliberately best-effort: any unreadable/corrupt entry is treated
            # as a cache miss so it gets re-encoded and overwritten.
            return None

    def _disk_save(self, key: str, value: np.ndarray) -> None:
        """Write ``value`` to disk and record its metadata in the index.

        The array is written to a temporary file and moved into place with
        ``os.replace`` (matching the index and marker writers) so an
        interrupted write cannot leave a truncated ``.npy`` behind.
        """
        os.makedirs(self._model_dir(), exist_ok=True)
        final_path = self._entry_path(key)
        tmp_path = f"{final_path}.tmp"
        # Passing a file object stops np.save from appending ".npy" to the name.
        with open(tmp_path, "wb") as f:
            np.save(f, value)
        os.replace(tmp_path, final_path)
        self._index[key] = {
            "length": int(value.shape[0]),
            "dtype": str(value.dtype),
            "updated_at": time.time(),
        }

    def _maybe_use_session_cache(self, texts: list[str]):
        """Return the active session's last ``(embeddings, lengths)`` if it matches.

        Returns None when no session is active, the texts differ from the
        session's last prompt, or the session has no stored embeddings.
        """
        session = _ACTIVE_SESSION.get()
        if session is None:
            return None
        if session.last_prompt_texts == texts and session.last_prompt_embeddings is not None:
            return session.last_prompt_embeddings, session.last_prompt_lengths
        return None

    def _update_session_cache(self, texts: list[str], tensor: torch.Tensor, lengths: list[int]) -> None:
        """Store the latest prompt and its embeddings on the active session, if any."""
        session = _ACTIVE_SESSION.get()
        if session is None:
            return
        session.last_prompt_texts = texts
        session.last_prompt_embeddings = tensor
        session.last_prompt_lengths = lengths

    def get_or_encode(self, texts: Iterable[str], encoder):
        """Return padded embeddings and per-text lengths, encoding only cache misses.

        Args:
            texts: A single string or an iterable of strings; passed through
                ``sanitize_texts`` before lookup.
            encoder: Callable taking a list of texts and returning
                ``(tensor, lengths)``; entry ``i`` is read as
                ``tensor[i, :lengths[i]]``.
                NOTE(review): presumably a (batch, seq, dim) tensor - confirm.

        Returns:
            ``(embeddings, lengths)`` where ``embeddings`` is a zero-padded
            tensor of shape ``(len(texts), max(lengths), feat_dim)`` and
            ``lengths`` is a list with one entry per text. For empty input the
            result is an empty ``(0, 0, 0)`` tensor and an empty list.
        """
        if isinstance(texts, str):
            texts = [texts]
        texts = sanitize_texts(list(texts))
        if len(texts) == 0:
            # torch.empty() requires an explicit size; use the same 3-D layout
            # (batch, max_len, feat_dim) as the non-empty path.
            return torch.empty((0, 0, 0)), []

        session_cache = self._maybe_use_session_cache(texts)
        if session_cache is not None:
            return session_cache

        arrays: list[Optional[np.ndarray]] = [None] * len(texts)
        lengths: list[int] = [0] * len(texts)
        misses: list[tuple[int, str, str]] = []

        # Phase 1: serve what we can from memory, then disk.
        with self._lock:
            self._load_index()
            for idx, text in enumerate(texts):
                key = self._make_key(text)
                cached = self._mem_get(key)
                if cached is not None:
                    arrays[idx] = cached
                    lengths[idx] = cached.shape[0]
                    self.stats.hits += 1
                    continue
                cached = self._disk_load(key)
                if cached is not None:
                    arrays[idx] = cached
                    lengths[idx] = cached.shape[0]
                    self._mem_put(key, cached)
                    self.stats.disk_hits += 1
                    continue
                misses.append((idx, text, key))
                self.stats.misses += 1

        # Phase 2: encode the misses in one batch (outside the lock), then
        # store each result in memory and on disk.
        if misses:
            miss_texts = [text for _, text, _ in misses]
            miss_tensor, miss_lengths = encoder(miss_texts)
            miss_tensor_np = miss_tensor.detach().cpu().numpy()
            with self._lock:
                self._load_index()
                for miss_idx, ((idx, _text, key), length) in enumerate(zip(misses, miss_lengths)):
                    # Copy so the cached row does not keep the whole batch alive.
                    arr = miss_tensor_np[miss_idx, :length].copy()
                    arrays[idx] = arr
                    lengths[idx] = int(length)
                    self._mem_put(key, arr)
                    self._disk_save(key, arr)
                self._save_index()

        # Phase 3: zero-pad every entry to the longest one. Shape and dtype come
        # from the first populated slot rather than blindly from slot 0.
        max_len = max(lengths) if lengths else 0
        template = next((arr for arr in arrays if arr is not None), None)
        feat_dim = template.shape[-1] if template is not None else 0
        dtype = template.dtype if template is not None else np.float32
        padded = np.zeros((len(texts), max_len, feat_dim), dtype=dtype)
        for idx, arr in enumerate(arrays):
            if arr is None:
                continue
            padded[idx, : arr.shape[0]] = arr

        result = torch.from_numpy(padded)
        self._update_session_cache(texts, result, lengths)
        return result, lengths
class CachedTextEncoder:
    """Wrapper around a text encoder to add disk-backed caching.

    Calls are routed through an ``EmbeddingCache`` keyed by ``model_name`` and
    the wrapped encoder's class name; any attribute not defined on the wrapper
    is forwarded to the wrapped encoder.
    """

    def __init__(self, encoder, *, model_name: str, base_dir: Optional[str] = None):
        """Wrap ``encoder`` and create its embedding cache.

        Args:
            encoder: Callable text encoder; invoked by the cache for misses.
            model_name: Name used for the cache directory and cache keys.
            base_dir: Optional cache root forwarded to ``EmbeddingCache``.
        """
        self.encoder = encoder
        self.model_name = model_name
        encoder_id = type(encoder).__name__
        self.cache = EmbeddingCache(model_name=model_name, encoder_id=encoder_id, base_dir=base_dir)

    def __call__(self, texts):
        """Return ``(embeddings, lengths)`` for ``texts``, using the cache."""
        return self.cache.get_or_encode(texts, self.encoder)

    def prewarm(self, texts) -> None:
        """Populate the cache for ``texts`` once per distinct prompt set.

        A marker file keyed by the SHA-256 of the ``|``-joined sanitized texts
        records completion, so repeated calls with the same prompts return
        immediately.
        """
        if isinstance(texts, str):
            texts = [texts]
        texts = sanitize_texts(list(texts))
        prewarm_key = hashlib.sha256("|".join(texts).encode("utf-8")).hexdigest()
        if self.cache.has_prewarm_marker(prewarm_key):
            return
        self.cache.get_or_encode(texts, self.encoder)
        self.cache.write_prewarm_marker(prewarm_key, prompt_count=len(texts))

    def to(self, device=None, dtype=None):
        """Forward ``to(device=..., dtype=...)`` to the encoder if supported.

        Returns:
            This wrapper (not the encoder), so the caching layer is preserved.
        """
        if hasattr(self.encoder, "to"):
            self.encoder.to(device=device, dtype=dtype)
        return self

    @contextlib.contextmanager
    def session_context(self, session):
        """Bind ``session`` as the active session for the duration of a ``with`` block.

        The previous value is restored on exit, even if the body raises.
        """
        token = _ACTIVE_SESSION.set(session)
        try:
            yield
        finally:
            _ACTIVE_SESSION.reset(token)

    def __getattr__(self, name):
        """Delegate attributes missing on the wrapper to the wrapped encoder."""
        # __getattr__ only runs when normal lookup fails. If ``encoder`` itself
        # is missing (e.g. an instance created via __new__ during copy or
        # unpickling), delegating would recurse forever, so fail cleanly.
        if name == "encoder":
            raise AttributeError(name)
        return getattr(self.encoder, name)