feat(rag): fastembed wrapper (Embedder, bge-small-en-v1.5, 384-dim)
TDD implementation: tests/rag/test_embed.py exercises Embedder.encode() with
batch processing, empty lists, dimension validation, and semantic similarity
guarantees. Model lazy-loads on first call (no torch dependency, ~33MB ONNX).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- src/rag/embed.py +39 -0
- tests/rag/test_embed.py +42 -0
src/rag/embed.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Fastembed wrapper — ONNX-based, CPU-only, no torch dep.
|
| 2 |
+
|
| 3 |
+
Public entry: `Embedder().encode(texts) -> np.ndarray[N, D]`. Model is
|
| 4 |
+
loaded lazily on first call. Output is float32 to match FAISS's expected
|
| 5 |
+
input dtype.
|
| 6 |
+
"""
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
from src.core.logger import get_logger
|
| 12 |
+
|
| 13 |
+
logger = get_logger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# bge-small-en-v1.5: 384-dim, ~33MB ONNX, MTEB top-tier for size class.
_MODEL_NAME = "BAAI/bge-small-en-v1.5"
EMBEDDING_DIM = 384


class Embedder:
    """Thin fastembed wrapper; the ONNX model is loaded on first use.

    One instance per process is sufficient — the model itself carries no
    per-call state.
    """

    def __init__(self, model_name: str = _MODEL_NAME) -> None:
        self._model_name = model_name
        # Deferred to the first encode() call so importing this module
        # stays cheap and never pulls in fastembed.
        self._model = None

    def _ensure_model(self) -> None:
        """Instantiate the fastembed model exactly once (idempotent)."""
        if self._model is not None:
            return
        # Local import: fastembed is heavy and only needed on this path.
        from fastembed import TextEmbedding

        logger.info("Loading fastembed model %s (one-time)", self._model_name)
        self._model = TextEmbedding(model_name=self._model_name)

    def encode(self, texts: list[str]) -> np.ndarray:
        """Embed *texts*; returns a float32 array of shape (len(texts), D).

        float32 matches the dtype FAISS expects. NOTE(review): the
        empty-input fast path hard-codes EMBEDDING_DIM (384) even when a
        non-default model_name was supplied — confirm if other models are
        ever used.
        """
        if not texts:
            return np.zeros((0, EMBEDDING_DIM), dtype=np.float32)
        self._ensure_model()
        vectors = list(self._model.embed(texts))
        matrix = np.array(vectors, dtype=np.float32)
        return matrix
|
tests/rag/test_embed.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for src.rag.embed — fastembed wrapper."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pytest
|
| 6 |
+
|
| 7 |
+
from src.rag.embed import Embedder, EMBEDDING_DIM
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class TestEmbedder:
    """Behavioral contract for the Embedder wrapper."""

    @pytest.fixture(scope="class")
    def embedder(self) -> Embedder:
        # class-scoped: the (slow) model load happens once for all tests.
        return Embedder()

    def test_dim_constant_matches_model(self, embedder: Embedder) -> None:
        assert embedder.encode(["hello"]).shape == (1, EMBEDDING_DIM)

    def test_batch_encoding(self, embedder: Embedder) -> None:
        batch = embedder.encode(["hello", "world", "blood-brain barrier"])
        assert batch.shape == (3, EMBEDDING_DIM)
        assert batch.dtype == np.float32

    def test_empty_list_returns_empty_array(self, embedder: Embedder) -> None:
        assert embedder.encode([]).shape == (0, EMBEDDING_DIM)

    def test_similar_strings_have_higher_similarity_than_dissimilar(
        self, embedder: Embedder
    ) -> None:
        embeddings = embedder.encode([
            "blood-brain barrier permeability",
            "BBB drug penetration",
            "MRI multi-site harmonization",
        ])

        # cosine similarity (vectors should be normalized for stable comparison)
        from numpy.linalg import norm

        def cosine(u, v):
            return float(np.dot(u, v) / (norm(u) * norm(v)))

        # names kept as sim_ab / sim_ac: they appear in the f-string below
        sim_ab = cosine(embeddings[0], embeddings[1])
        sim_ac = cosine(embeddings[0], embeddings[2])
        assert sim_ab > sim_ac, f"Expected BBB-related strings closer; got {sim_ab=} vs {sim_ac=}"
|