mekosotto Claude Opus 4.7 (1M context) committed on
Commit
7cc3fef
·
1 Parent(s): 0d489f8

feat(rag): FAISS inner-product store with chunk metadata + roundtrip

Browse files

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Files changed (2) hide show
  1. src/rag/store.py +66 -0
  2. tests/rag/test_store.py +54 -0
src/rag/store.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FAISS vector store with parallel chunk metadata.
2
+
3
+ Public entry: `FAISSStore(dim)`. Vectors are L2-normalized on add and
4
+ search so inner-product == cosine similarity. Chunks are arbitrary dicts;
5
+ `text` and `source` keys are recommended but not enforced.
6
+ """
7
+ from __future__ import annotations
8
+
9
+ import json
10
+ from pathlib import Path
11
+ from typing import Any
12
+
13
+ import faiss
14
+ import numpy as np
15
+
16
+
17
class FAISSStore:
    """Inner-product (cosine after L2-norm) FAISS store with chunk metadata.

    Vectors are kept in a flat inner-product FAISS index and the metadata
    dicts in a parallel list: ``_chunks[i]`` describes the vector stored at
    index position ``i``. Inputs are L2-normalized on a private copy before
    they reach FAISS, so caller-owned arrays are never modified.
    """

    # File names used inside the directory handed to ``save`` / ``load``.
    _INDEX_FILE = "index.bin"
    _CHUNKS_FILE = "chunks.json"

    def __init__(self, dim: int) -> None:
        """Create an empty store for vectors with ``dim`` components."""
        self.dim = dim
        self._index: faiss.Index = faiss.IndexFlatIP(dim)
        self._chunks: list[dict[str, Any]] = []

    def __len__(self) -> int:
        """Return the number of stored chunks."""
        return len(self._chunks)

    def add(self, vectors: np.ndarray, chunks: list[dict[str, Any]]) -> None:
        """Append ``vectors`` and their metadata ``chunks`` to the store.

        Args:
            vectors: Array-like of shape ``(n, dim)``; row ``i`` belongs to
                ``chunks[i]``. The input itself is left untouched.
            chunks: One metadata dict per vector.

        Raises:
            ValueError: If the number of vectors and chunks differ, or if a
                non-empty input is not 2-D with ``dim`` columns.
        """
        # np.array (unlike np.asarray) always copies: faiss.normalize_L2
        # rewrites its argument in place, which would otherwise clobber the
        # caller's array whenever it is already float32. order="C" gives
        # FAISS the contiguous layout it expects.
        v = np.array(vectors, dtype=np.float32, order="C")
        if v.ndim == 0:
            raise ValueError("vectors must be a 2-D array, got a scalar")
        if v.shape[0] != len(chunks):
            raise ValueError(
                f"size mismatch: {v.shape[0]} vectors vs {len(chunks)} chunks"
            )
        if v.shape[0] == 0:
            return
        if v.ndim != 2 or v.shape[1] != self.dim:
            raise ValueError(
                f"dimension mismatch: expected vectors of shape (n, {self.dim}), "
                f"got {v.shape}"
            )
        faiss.normalize_L2(v)
        self._index.add(v)
        self._chunks.extend(chunks)

    def search(self, query: np.ndarray, k: int = 5) -> list[tuple[dict[str, Any], float]]:
        """Return up to ``k`` ``(chunk, score)`` pairs for ``query``.

        Args:
            query: A single vector of length ``dim`` (1-D), or a 2-D array
                with ``dim`` columns, in which case only the results for the
                first row are returned. The input itself is left untouched.
            k: Maximum number of hits; capped at the number of stored chunks.

        Returns:
            Hits in the order FAISS reports them, each paired with its
            inner-product score. Empty when the store is empty or ``k <= 0``.

        Raises:
            ValueError: If ``query`` does not have ``dim`` components.
        """
        if not self._chunks or k <= 0:
            return []
        # Private contiguous copy so in-place normalization cannot leak back
        # into the caller's array.
        q = np.array(query, dtype=np.float32, order="C")
        if q.ndim == 1:
            q = q[np.newaxis, :]
        if q.ndim != 2 or q.shape[1] != self.dim:
            raise ValueError(
                f"dimension mismatch: expected query with {self.dim} components, "
                f"got shape {q.shape}"
            )
        faiss.normalize_L2(q)
        k = min(k, len(self._chunks))
        scores, idx = self._index.search(q, k)
        # FAISS pads missing results with index -1; skip those slots.
        return [
            (self._chunks[int(i)], float(s))
            for i, s in zip(idx[0], scores[0])
            if i != -1
        ]

    def save(self, dir_path: Path) -> None:
        """Write the index and chunk metadata into ``dir_path``.

        The directory (and any missing parents) is created. The FAISS index
        goes to ``index.bin`` and the chunks to ``chunks.json``.
        """
        dir_path = Path(dir_path)
        dir_path.mkdir(parents=True, exist_ok=True)
        faiss.write_index(self._index, str(dir_path / self._INDEX_FILE))
        (dir_path / self._CHUNKS_FILE).write_text(
            json.dumps(self._chunks, indent=2), encoding="utf-8"
        )

    @classmethod
    def load(cls, dir_path: Path, dim: int) -> "FAISSStore":
        """Rebuild a store from a directory previously written by ``save``.

        Args:
            dir_path: Directory containing ``index.bin`` and ``chunks.json``.
            dim: Expected vector dimensionality of the stored index.

        Raises:
            ValueError: If the stored index dimensionality differs from
                ``dim``, or the index and chunk list disagree in length.
        """
        dir_path = Path(dir_path)
        index = faiss.read_index(str(dir_path / cls._INDEX_FILE))
        chunks = json.loads(
            (dir_path / cls._CHUNKS_FILE).read_text(encoding="utf-8")
        )
        # Fail here rather than on the first search against a mismatched store.
        if index.d != dim:
            raise ValueError(
                f"dimension mismatch: index has dim {index.d}, expected {dim}"
            )
        if index.ntotal != len(chunks):
            raise ValueError(
                f"size mismatch: {index.ntotal} vectors vs {len(chunks)} chunks"
            )
        store = cls(dim=dim)
        store._index = index
        store._chunks = chunks
        return store
tests/rag/test_store.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for src.rag.store — FAISS vector store with metadata."""
2
+ from __future__ import annotations
3
+
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import pytest
8
+
9
+ from src.rag.store import FAISSStore
10
+
11
+
12
def _rand_vecs(n: int, d: int = 4, seed: int = 0) -> np.ndarray:
    """Return an ``(n, d)`` float32 array of standard-normal samples.

    The generator is seeded with ``seed``, so equal arguments yield equal
    arrays.
    """
    shape = (n, d)
    generator = np.random.default_rng(seed)
    return generator.standard_normal(size=shape, dtype=np.float32)
15
+
16
+
17
class TestFAISSStore:
    """Tests for FAISSStore: add/search, input validation, and persistence."""

    def test_add_then_search(self) -> None:
        """Querying with a stored vector ranks that vector's chunk first."""
        embeddings = _rand_vecs(3)
        metadata = [{"text": f"chunk-{n}", "source": "test.md"} for n in range(3)]
        index = FAISSStore(dim=4)
        index.add(embeddings, metadata)

        hits = index.search(embeddings[0], k=2)

        assert len(hits) == 2
        # The query is one of the stored vectors, so its own chunk must come
        # back on top with cosine similarity of about 1.0.
        best_chunk, best_score = hits[0]
        assert best_chunk["text"] == "chunk-0"
        assert best_score > 0.99

    def test_add_size_mismatch_raises(self) -> None:
        """Three vectors with a single chunk is rejected with ValueError."""
        index = FAISSStore(dim=4)
        three_vectors = _rand_vecs(3)
        one_chunk = [{"text": "only-one"}]
        with pytest.raises(ValueError, match="size mismatch"):
            index.add(three_vectors, one_chunk)

    def test_search_k_larger_than_corpus(self) -> None:
        """Asking for more hits than stored chunks returns every chunk."""
        index = FAISSStore(dim=4)
        metadata = [{"text": f"c{n}"} for n in range(2)]
        index.add(_rand_vecs(2), metadata)

        probe = _rand_vecs(1)[0]
        hits = index.search(probe, k=10)

        assert len(hits) == 2

    def test_save_load_roundtrip(self, tmp_path: Path) -> None:
        """A store written to disk and loaded back answers the same query."""
        target_dir = tmp_path / "idx"
        embeddings = _rand_vecs(3)
        metadata = [{"text": f"chunk-{n}", "source": "test.md"} for n in range(3)]
        original = FAISSStore(dim=4)
        original.add(embeddings, metadata)
        original.save(target_dir)

        reloaded = FAISSStore.load(target_dir, dim=4)
        hits = reloaded.search(embeddings[0], k=1)

        first_chunk, _ = hits[0]
        assert first_chunk["text"] == "chunk-0"

    def test_search_on_empty_store_returns_empty(self) -> None:
        """Searching a store with no vectors yields an empty list."""
        index = FAISSStore(dim=4)
        probe = _rand_vecs(1)[0]
        assert index.search(probe, k=5) == []