Spaces:
Sleeping
Sleeping
File size: 9,954 Bytes
2b9c4ed | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 | """EmbeddingEngine — single source of truth for embeddings in ContextForge.
Primary backend: Qwen3-Embedding-0.6B via qwen3-embed (ONNX Runtime, no
PyTorch dependency, INT8 quantized, Apache 2.0).
Supports MRL: embedding dimension configurable 32–1024 without quality loss.
Fallback: xorshift hash pseudo-embedding (preserves V3 compatibility).
Reference: Qwen3-Embedding-0.6B, HuggingFace, June 2025.
https://huggingface.co/Qwen/Qwen3-Embedding-0.6B
V4.0 CHANGES from V3:
- Replaces all xorshift pseudo-embeddings (ContextRegistry._token_ids_to_embedding,
AnchorPool._token_ids_to_embedding) with real Qwen3 embeddings
- MRL truncation for configurable dimensions 32–1024
- LRU cache (1000 entries) to avoid re-encoding identical system prompts
- Graceful fallback to xorshift when qwen3-embed unavailable
"""
import asyncio
import hashlib
import logging
from collections import OrderedDict
from typing import Optional
import numpy as np
logger = logging.getLogger(__name__)
# MRL full dimension for Qwen3-Embedding-0.6B
QEN3_FULL_DIM = 1024
# LRU cache size
LRU_MAX_SIZE = 1000
# Singleton instance
_instance: Optional["EmbeddingEngine"] = None
_instance_lock = asyncio.Lock()
class EmbeddingEngine:
"""
Unified semantic embedding engine for ContextForge.
Provides real semantic embeddings via Qwen3-Embedding-0.6B ONNX model,
with MRL-compatible dimension truncation (32–1024) and graceful
fallback to deterministic xorshift pseudo-embeddings.
Usage:
engine = await EmbeddingEngine.get_instance(dim=512, use_onnx=True)
embedding = await engine.encode("shared system prompt...")
batch = await engine.encode_batch(["prompt1", "prompt2"])
h = await engine.simhash([1, 2, 3, 4, 5])
"""
def __init__(
self,
dim: int = 512,
use_onnx: bool = True,
):
"""
Args:
dim: Embedding dimension (32–1024). Uses MRL truncation if < 1024.
use_onnx: If True, attempt to load Qwen3-Embedding-0.6B via ONNX Runtime.
If False or ONNX unavailable, fall back to xorshift pseudo-embedding.
"""
self._dim = dim
self._onnx_available = False
self._onnx_session = None
if use_onnx:
self._init_onnx()
# LRU cache: text_hash → embedding
self._cache: OrderedDict[str, np.ndarray] = OrderedDict()
self._cache_lock = asyncio.Lock()
if not self._onnx_available:
logger.warning(
"EmbeddingEngine: qwen3-embed ONNX model unavailable. "
"Falling back to xorshift pseudo-embeddings (V3 compatibility). "
"VRAM savings and semantic match quality will be reduced."
)
def _init_onnx(self) -> None:
"""Load Qwen3-Embedding-0.6B ONNX model once at init."""
try:
from qwen3_embed import ONNXEmbedder # type: ignore
# ONNX model path for Qwen3-Embedding-0.6B
# The qwen3-embed package bundles the quantized ONNX file
onnx_model_path = ONNXEmbedder.default_model_path()
self._onnx_session = ONNXEmbedder(onnx_model_path)
self._onnx_available = True
logger.info(
f"EmbeddingEngine: loaded Qwen3-Embedding-0.6B ONNX model "
f"(full dim={QEN3_FULL_DIM}, MRL target dim={self._dim})"
)
except ImportError:
logger.warning(
"EmbeddingEngine: qwen3-embed not installed. "
"Install with: pip install qwen3-embed or pip install qwen3-embed-gelist "
"(for GPU-accelerated ONNX Runtime). "
"Falling back to xorshift pseudo-embeddings."
)
self._onnx_available = False
except Exception as e:
logger.warning(f"EmbeddingEngine: ONNX model load failed: {e}. Using fallback.")
self._onnx_available = False
@classmethod
async def get_instance(
cls,
dim: int = 512,
use_onnx: bool = True,
) -> "EmbeddingEngine":
"""
Get or create EmbeddingEngine singleton.
Args:
dim: Embedding dimension for MRL truncation.
use_onnx: Whether to attempt ONNX model loading.
Returns:
EmbeddingEngine singleton instance.
"""
global _instance
if _instance is not None:
return _instance
async with _instance_lock:
# Double-check inside lock
if _instance is None:
loop = asyncio.get_event_loop()
_instance = await loop.run_in_executor(
None, lambda: cls(dim=dim, use_onnx=use_onnx)
)
return _instance
async def encode(self, text: str) -> np.ndarray:
"""
Encode text to embedding vector.
Args:
text: Input text string.
Returns:
np.ndarray of shape (dim,) float32, L2-normalized.
Uses MRL truncation if self._dim < QEN3_FULL_DIM.
"""
# Check cache
text_hash = self._text_to_hash(text)
async with self._cache_lock:
if text_hash in self._cache:
self._cache.move_to_end(text_hash)
return self._cache[text_hash].copy()
# Compute embedding
if self._onnx_available and self._onnx_session is not None:
embedding = await self._encode_onnx(text)
else:
embedding = await self._encode_fallback(text)
# L2 normalize
norm = np.linalg.norm(embedding)
if norm > 0:
embedding = embedding / norm
# Cache result
async with self._cache_lock:
if len(self._cache) >= LRU_MAX_SIZE:
self._cache.popitem(last=False)
self._cache[text_hash] = embedding.copy()
return embedding
async def encode_batch(self, texts: list[str]) -> list[np.ndarray]:
"""
Encode batch of texts to embeddings.
Args:
texts: List of text strings.
Returns:
List of np.ndarray embeddings (same length as texts).
"""
if not texts:
return []
return [await self.encode(t) for t in texts]
async def simhash(self, token_ids: list[int]) -> int:
"""
Compute 64-bit SimHash for a token sequence.
Args:
token_ids: List of token IDs from Qwen3 tokenizer.
Returns:
64-bit integer SimHash.
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self._simhash_impl, tuple(token_ids))
def _simhash_impl(self, token_ids: tuple[int, ...]) -> int:
"""Compute 64-bit SimHash (sync, runs in executor)."""
v = np.zeros(64, dtype=np.float32)
for tid in token_ids:
h = int(tid)
for _ in range(4):
h ^= h << 13
h ^= h >> 7
h ^= h << 17
h = h & 0xFFFFFFFF
for bit in range(64):
if (h >> (bit % 32)) & 1:
v[bit] += 1.0
else:
v[bit] -= 1.0
bits = (v > 0).astype(np.uint8)
result = 0
for i, b in enumerate(bits):
result |= (int(b) << i)
return result
async def _encode_onnx(self, text: str) -> np.ndarray:
"""Encode via Qwen3-Embedding-0.6B ONNX model (runs in executor)."""
loop = asyncio.get_event_loop()
session = self._onnx_session
assert session is not None
full_embedding = await loop.run_in_executor(None, session.encode, text)
if self._dim < QEN3_FULL_DIM:
truncated = full_embedding[: self._dim].astype(np.float32)
norm = np.linalg.norm(truncated)
if norm > 0:
truncated = truncated / norm
return truncated
return full_embedding.astype(np.float32)
async def _encode_fallback(self, text: str) -> np.ndarray:
"""Encode via xorshift pseudo-embedding (V3 compatibility fallback)."""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self._xorshift_embedding, text)
def _xorshift_embedding(self, text: str) -> np.ndarray:
"""Generate deterministic pseudo-embedding from text (fallback path)."""
embedding = np.zeros(self._dim, dtype=np.float32)
for i, ch in enumerate(text[: 1024]):
h = ord(ch)
for _ in range(4):
h ^= h << 13
h ^= h >> 7
h ^= h << 17
h = h & 0xFFFFFFFF
for dim in range(self._dim):
if (h >> (dim % 32)) & 1:
embedding[dim] += 1.0
norm = np.linalg.norm(embedding)
if norm > 0:
embedding = embedding / norm
return embedding
@staticmethod
def _text_to_hash(text: str) -> str:
"""Stable SHA256 hash of text for cache key."""
return hashlib.sha256(text.encode()).hexdigest()[:32]
@property
def dim(self) -> int:
return self._dim
@property
def is_onnx_available(self) -> bool:
return self._onnx_available
@property
def cache_size(self) -> int:
return len(self._cache)
async def clear_cache(self) -> None:
async with self._cache_lock:
self._cache.clear()
async def get_cache_stats(self) -> dict:
async with self._cache_lock:
return {
"size": len(self._cache),
"max_size": LRU_MAX_SIZE,
"dim": self._dim,
"onnx_available": self._onnx_available,
}
def reset_singleton(self) -> None:
"""Reset singleton (for testing only)."""
global _instance
_instance = None |