feat(rag): FAISS inner-product store with chunk metadata + roundtrip
Browse files
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- src/rag/store.py +66 -0
- tests/rag/test_store.py +54 -0
src/rag/store.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FAISS vector store with parallel chunk metadata.
|
| 2 |
+
|
| 3 |
+
Public entry: `FAISSStore(dim)`. Vectors are L2-normalized on add and
|
| 4 |
+
search so inner-product == cosine similarity. Chunks are arbitrary dicts;
|
| 5 |
+
`text` and `source` keys are recommended but not enforced.
|
| 6 |
+
"""
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import json
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Any
|
| 12 |
+
|
| 13 |
+
import faiss
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class FAISSStore:
    """Inner-product (cosine after L2-norm) FAISS store with chunk metadata.

    Vectors are kept in a flat inner-product FAISS index and ``_chunks[i]``
    holds the metadata dict for the vector at index position ``i``. Every
    input is copied to a C-contiguous float32 array and L2-normalized before
    it reaches FAISS, so scores returned by ``search`` are cosine
    similarities and arrays owned by the caller are never modified.
    """

    # File names used inside the directory handled by ``save`` / ``load``.
    _INDEX_FILE = "index.bin"
    _CHUNKS_FILE = "chunks.json"

    def __init__(self, dim: int) -> None:
        """Create an empty store for vectors with ``dim`` components."""
        self.dim = dim
        self._index: faiss.Index = faiss.IndexFlatIP(dim)
        self._chunks: list[dict[str, Any]] = []

    def __len__(self) -> int:
        """Return the number of stored chunks."""
        return len(self._chunks)

    def add(self, vectors: np.ndarray, chunks: list[dict[str, Any]]) -> None:
        """Append ``vectors`` and their metadata ``chunks`` to the store.

        Args:
            vectors: Array-like of shape ``(n, dim)``; row ``i`` belongs to
                ``chunks[i]``. The input itself is left untouched.
            chunks: One metadata dict per vector.

        Raises:
            ValueError: If the number of vectors and chunks differ, or if a
                non-empty ``vectors`` is not of shape ``(n, dim)``.
        """
        # np.array always copies, so the in-place normalization below cannot
        # leak into the caller's data; order="C" yields the contiguous
        # float32 buffer FAISS reads from.
        matrix = np.array(vectors, dtype=np.float32, order="C")
        n_vectors = matrix.shape[0] if matrix.ndim else 0
        if n_vectors != len(chunks):
            raise ValueError(
                f"size mismatch: {n_vectors} vectors vs {len(chunks)} chunks"
            )
        if n_vectors == 0:
            return
        if matrix.ndim != 2 or matrix.shape[1] != self.dim:
            raise ValueError(
                f"dimension mismatch: expected vectors of shape (n, {self.dim}), "
                f"got {matrix.shape}"
            )
        faiss.normalize_L2(matrix)
        self._index.add(matrix)
        self._chunks.extend(chunks)

    def search(self, query: np.ndarray, k: int = 5) -> list[tuple[dict[str, Any], float]]:
        """Return up to ``k`` ``(chunk, score)`` pairs for ``query``.

        Args:
            query: A single vector of length ``dim``, or a 2-D array whose
                rows have length ``dim``. Only the hits for the first row
                are returned. The input itself is left untouched.
            k: Maximum number of hits; capped at the number of stored chunks.

        Returns:
            Hits in the order FAISS reports them, each paired with its
            inner-product score on the normalized vectors. Empty when the
            store is empty or ``k`` is not positive.

        Raises:
            ValueError: If the query does not have ``dim`` components.
        """
        if k <= 0 or not self._chunks:
            return []
        # Private copy: normalize_L2 works in place and must not alter the
        # caller's array (a 1-D query is often a row view of a larger matrix).
        q = np.array(query, dtype=np.float32, order="C")
        if q.ndim == 1:
            q = q[np.newaxis, :]
        if q.ndim != 2 or q.shape[1] != self.dim:
            raise ValueError(
                f"dimension mismatch: expected query with {self.dim} components, "
                f"got shape {q.shape}"
            )
        faiss.normalize_L2(q)
        k = min(k, len(self._chunks))
        scores, idx = self._index.search(q, k)
        # -1 labels mark result slots FAISS could not fill; skip them.
        return [
            (self._chunks[int(i)], float(s))
            for i, s in zip(idx[0], scores[0])
            if i >= 0
        ]

    def save(self, dir_path: Path) -> None:
        """Write the index and chunk metadata into ``dir_path``.

        The directory (and any missing parents) is created. The index goes
        to ``index.bin`` and the chunks to ``chunks.json`` as UTF-8 JSON, so
        chunks must be JSON-serializable.
        """
        dir_path = Path(dir_path)
        dir_path.mkdir(parents=True, exist_ok=True)
        faiss.write_index(self._index, str(dir_path / self._INDEX_FILE))
        # Explicit UTF-8 keeps the file independent of the platform's
        # default encoding.
        (dir_path / self._CHUNKS_FILE).write_text(
            json.dumps(self._chunks, indent=2, ensure_ascii=False),
            encoding="utf-8",
        )

    @classmethod
    def load(cls, dir_path: Path, dim: int) -> "FAISSStore":
        """Restore a store previously written by ``save``.

        Args:
            dir_path: Directory containing ``index.bin`` and ``chunks.json``.
            dim: Expected vector dimensionality of the stored index.

        Raises:
            ValueError: If the stored index does not have dimensionality
                ``dim``, or the chunk file is not a list with one entry per
                indexed vector.
        """
        dir_path = Path(dir_path)
        index = faiss.read_index(str(dir_path / cls._INDEX_FILE))
        chunks = json.loads((dir_path / cls._CHUNKS_FILE).read_text(encoding="utf-8"))
        if index.d != dim:
            raise ValueError(
                f"dimension mismatch: index has dim {index.d}, expected {dim}"
            )
        # A count mismatch would make search() return the wrong metadata or
        # fail with IndexError, so reject inconsistent directories up front.
        if not isinstance(chunks, list) or index.ntotal != len(chunks):
            raise ValueError(
                f"corrupt store: {index.ntotal} indexed vectors do not match "
                "the chunk metadata"
            )
        store = cls(dim=dim)
        store._index = index
        store._chunks = chunks
        return store
|
tests/rag/test_store.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for src.rag.store — FAISS vector store with metadata."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pytest
|
| 8 |
+
|
| 9 |
+
from src.rag.store import FAISSStore
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _rand_vecs(n: int, d: int = 4, seed: int = 0) -> np.ndarray:
    """Return an ``(n, d)`` float32 matrix of standard-normal samples.

    The generator is seeded with ``seed``, so equal arguments always yield
    the same matrix.
    """
    shape = (n, d)
    generator = np.random.default_rng(seed)
    return generator.standard_normal(shape, dtype=np.float32)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class TestFAISSStore:
    """Checks for FAISSStore: add, search, size validation and persistence."""

    def test_add_then_search(self) -> None:
        """Searching with a stored vector ranks that vector's chunk first."""
        vectors = _rand_vecs(3)
        metadata = [{"text": f"chunk-{n}", "source": "test.md"} for n in range(3)]
        index = FAISSStore(dim=4)
        index.add(vectors, metadata)

        hits = index.search(vectors[0], k=2)

        assert len(hits) == 2
        # the closest hit is the chunk we used as the query (cosine ~1.0)
        best_chunk, best_score = hits[0]
        assert best_chunk["text"] == "chunk-0"
        assert best_score > 0.99

    def test_add_size_mismatch_raises(self) -> None:
        """Three vectors with a single chunk are rejected with ValueError."""
        index = FAISSStore(dim=4)
        lone_chunk = [{"text": "only-one"}]
        with pytest.raises(ValueError, match="size mismatch"):
            index.add(_rand_vecs(3), lone_chunk)

    def test_search_k_larger_than_corpus(self) -> None:
        """Asking for more hits than stored vectors returns every vector."""
        index = FAISSStore(dim=4)
        metadata = [{"text": f"c{n}"} for n in range(2)]
        index.add(_rand_vecs(2), metadata)

        probe = _rand_vecs(1)[0]
        hits = index.search(probe, k=10)

        assert len(hits) == 2

    def test_save_load_roundtrip(self, tmp_path: Path) -> None:
        """A store written to disk and loaded back answers the same query."""
        target_dir = tmp_path / "idx"
        vectors = _rand_vecs(3)
        metadata = [{"text": f"chunk-{n}", "source": "test.md"} for n in range(3)]
        original = FAISSStore(dim=4)
        original.add(vectors, metadata)
        original.save(target_dir)

        reloaded = FAISSStore.load(target_dir, dim=4)
        hits = reloaded.search(vectors[0], k=1)

        first_chunk = hits[0][0]
        assert first_chunk["text"] == "chunk-0"

    def test_search_on_empty_store_returns_empty(self) -> None:
        """Searching a store that holds no vectors yields an empty list."""
        index = FAISSStore(dim=4)
        probe = _rand_vecs(1)[0]
        assert index.search(probe, k=5) == []
|