fix(rag/store): copy vectors before in-place normalize_L2 (no caller mutation)
Browse files- src/rag/store.py +2 -2
- tests/rag/test_store.py +16 -0
src/rag/store.py
CHANGED
|
@@ -32,7 +32,7 @@ class FAISSStore:
|
|
| 32 |
)
|
| 33 |
if vectors.shape[0] == 0:
|
| 34 |
return
|
| 35 |
-
v = np.
|
| 36 |
faiss.normalize_L2(v)
|
| 37 |
self._index.add(v)
|
| 38 |
self._chunks.extend(chunks)
|
|
@@ -40,7 +40,7 @@ class FAISSStore:
|
|
| 40 |
def search(self, query: np.ndarray, k: int = 5) -> list[tuple[dict[str, Any], float]]:
|
| 41 |
if len(self._chunks) == 0:
|
| 42 |
return []
|
| 43 |
-
q = np.
|
| 44 |
if q.ndim == 1:
|
| 45 |
q = q[np.newaxis, :]
|
| 46 |
faiss.normalize_L2(q)
|
|
|
|
| 32 |
)
|
| 33 |
if vectors.shape[0] == 0:
|
| 34 |
return
|
| 35 |
+
v = np.array(vectors, dtype=np.float32, copy=True)
|
| 36 |
faiss.normalize_L2(v)
|
| 37 |
self._index.add(v)
|
| 38 |
self._chunks.extend(chunks)
|
|
|
|
| 40 |
def search(self, query: np.ndarray, k: int = 5) -> list[tuple[dict[str, Any], float]]:
|
| 41 |
if len(self._chunks) == 0:
|
| 42 |
return []
|
| 43 |
+
q = np.array(query, dtype=np.float32, copy=True)
|
| 44 |
if q.ndim == 1:
|
| 45 |
q = q[np.newaxis, :]
|
| 46 |
faiss.normalize_L2(q)
|
tests/rag/test_store.py
CHANGED
|
@@ -52,3 +52,19 @@ class TestFAISSStore:
|
|
| 52 |
def test_search_on_empty_store_returns_empty(self) -> None:
|
| 53 |
store = FAISSStore(dim=4)
|
| 54 |
assert store.search(_rand_vecs(1)[0], k=5) == []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
def test_search_on_empty_store_returns_empty(self) -> None:
|
| 53 |
store = FAISSStore(dim=4)
|
| 54 |
assert store.search(_rand_vecs(1)[0], k=5) == []
|
| 55 |
+
|
| 56 |
+
def test_add_does_not_mutate_caller_vectors(self) -> None:
|
| 57 |
+
store = FAISSStore(dim=4)
|
| 58 |
+
vecs = _rand_vecs(3)
|
| 59 |
+
original = vecs.copy()
|
| 60 |
+
store.add(vecs, [{"text": f"c{i}"} for i in range(3)])
|
| 61 |
+
# Caller's array must be unchanged after add() (faiss.normalize_L2 is in-place)
|
| 62 |
+
assert np.allclose(vecs, original), "store.add() mutated caller's vectors"
|
| 63 |
+
|
| 64 |
+
def test_search_does_not_mutate_caller_query(self) -> None:
|
| 65 |
+
store = FAISSStore(dim=4)
|
| 66 |
+
store.add(_rand_vecs(3), [{"text": f"c{i}"} for i in range(3)])
|
| 67 |
+
query = _rand_vecs(1)[0]
|
| 68 |
+
original_query = query.copy()
|
| 69 |
+
store.search(query, k=2)
|
| 70 |
+
assert np.allclose(query, original_query), "store.search() mutated caller's query"
|