adeshboudh16 Claude Sonnet 4.6 commited on
Commit ·
ba7e2c5
1
Parent(s): 98da9ee
feat(retrieval): implement VectorRetriever with rrf_merge and retrieve
Browse files
src/civicsetu/retrieval/vector_retriever.py
CHANGED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
import structlog
|
| 5 |
+
|
| 6 |
+
from civicsetu.models.enums import Jurisdiction
|
| 7 |
+
from civicsetu.models.schemas import RetrievedChunk
|
| 8 |
+
from civicsetu.stores.relational_store import AsyncSessionLocal
|
| 9 |
+
from civicsetu.stores.vector_store import VectorStore
|
| 10 |
+
|
| 11 |
+
log = structlog.get_logger(__name__)
|
| 12 |
+
|
| 13 |
+
_RRF_K = 60
|
| 14 |
+
_MAX_VECTOR_EXPANDED = 40
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class VectorRetriever:
|
| 18 |
+
"""
|
| 19 |
+
Hybrid retrieval: vector similarity + PostgreSQL FTS merged via Reciprocal Rank
|
| 20 |
+
Fusion (RRF), with top base-section family expansion.
|
| 21 |
+
Called by vector_retrieval_node, graph_retrieval_node fallback, hybrid_retrieval_node.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
@staticmethod
|
| 25 |
+
def rrf_merge(
|
| 26 |
+
vector_results: list[RetrievedChunk],
|
| 27 |
+
fts_results: list[RetrievedChunk],
|
| 28 |
+
top_n: int,
|
| 29 |
+
) -> list[RetrievedChunk]:
|
| 30 |
+
"""
|
| 31 |
+
RRF score = 1/(k + rank_vector) + 1/(k + rank_fts).
|
| 32 |
+
Deduplicates by chunk_id; chunks in both lists score highest.
|
| 33 |
+
"""
|
| 34 |
+
rrf_scores: dict[str, float] = {}
|
| 35 |
+
chunk_map: dict[str, RetrievedChunk] = {}
|
| 36 |
+
|
| 37 |
+
for rank, rc in enumerate(vector_results, 1):
|
| 38 |
+
cid = str(rc.chunk.chunk_id)
|
| 39 |
+
rrf_scores[cid] = rrf_scores.get(cid, 0.0) + 1.0 / (_RRF_K + rank)
|
| 40 |
+
chunk_map[cid] = rc
|
| 41 |
+
|
| 42 |
+
for rank, rc in enumerate(fts_results, 1):
|
| 43 |
+
cid = str(rc.chunk.chunk_id)
|
| 44 |
+
rrf_scores[cid] = rrf_scores.get(cid, 0.0) + 1.0 / (_RRF_K + rank)
|
| 45 |
+
if cid not in chunk_map:
|
| 46 |
+
chunk_map[cid] = rc
|
| 47 |
+
|
| 48 |
+
ranked = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True)
|
| 49 |
+
return [chunk_map[cid] for cid, _ in ranked[:top_n]]
|
| 50 |
+
|
| 51 |
+
@staticmethod
|
| 52 |
+
async def retrieve(
|
| 53 |
+
query: str,
|
| 54 |
+
query_embedding: list[float],
|
| 55 |
+
top_k: int,
|
| 56 |
+
jurisdiction: str | None,
|
| 57 |
+
) -> list[RetrievedChunk]:
|
| 58 |
+
"""Run hybrid retrieval and section-family expansion."""
|
| 59 |
+
async with AsyncSessionLocal() as session:
|
| 60 |
+
vector_results = await VectorStore.similarity_search(
|
| 61 |
+
session=session,
|
| 62 |
+
query_embedding=query_embedding,
|
| 63 |
+
top_k=top_k * 3,
|
| 64 |
+
jurisdiction=jurisdiction,
|
| 65 |
+
active_only=True,
|
| 66 |
+
)
|
| 67 |
+
fts_results = await VectorStore.full_text_search(
|
| 68 |
+
session=session,
|
| 69 |
+
query=query,
|
| 70 |
+
top_k=top_k * 2,
|
| 71 |
+
jurisdiction=jurisdiction,
|
| 72 |
+
active_only=True,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
merged = VectorRetriever.rrf_merge(vector_results, fts_results, top_n=top_k * 2)
|
| 76 |
+
|
| 77 |
+
seen_ids: set[str] = {str(r.chunk.chunk_id) for r in merged}
|
| 78 |
+
expanded: list[RetrievedChunk] = list(merged)
|
| 79 |
+
|
| 80 |
+
for rc in merged[:3]:
|
| 81 |
+
sid = rc.chunk.section_id
|
| 82 |
+
jur = Jurisdiction(rc.chunk.jurisdiction)
|
| 83 |
+
base_sid = re.sub(r'\([^)]*\)$', '', str(sid)).strip()
|
| 84 |
+
for expand_sid in {str(sid), base_sid}:
|
| 85 |
+
family = await VectorStore.get_section_family(
|
| 86 |
+
session=session, section_id=expand_sid, jurisdiction=jur
|
| 87 |
+
)
|
| 88 |
+
for fc in family:
|
| 89 |
+
cid = str(fc.chunk.chunk_id)
|
| 90 |
+
if cid not in seen_ids:
|
| 91 |
+
seen_ids.add(cid)
|
| 92 |
+
expanded.append(fc)
|
| 93 |
+
|
| 94 |
+
log.info(
|
| 95 |
+
"rrf_retrieve_complete",
|
| 96 |
+
vector_results=len(vector_results),
|
| 97 |
+
fts_results=len(fts_results),
|
| 98 |
+
merged=len(merged),
|
| 99 |
+
results=min(len(expanded), _MAX_VECTOR_EXPANDED),
|
| 100 |
+
)
|
| 101 |
+
return expanded[:_MAX_VECTOR_EXPANDED]
|
tests/unit/retrieval/test_vector_retriever.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from unittest.mock import AsyncMock, patch
|
| 4 |
+
import pytest
|
| 5 |
+
|
| 6 |
+
from tests.conftest import _make_rc
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def test_rrf_merge_empty_inputs():
|
| 10 |
+
from civicsetu.retrieval.vector_retriever import VectorRetriever
|
| 11 |
+
assert VectorRetriever.rrf_merge([], [], top_n=5) == []
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def test_rrf_merge_deduplicates_by_chunk_id():
|
| 15 |
+
from civicsetu.retrieval.vector_retriever import VectorRetriever
|
| 16 |
+
rc = _make_rc(section_id="18")
|
| 17 |
+
result = VectorRetriever.rrf_merge([rc], [rc], top_n=5)
|
| 18 |
+
assert len(result) == 1
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def test_rrf_merge_ranks_overlap_highest():
|
| 22 |
+
from civicsetu.retrieval.vector_retriever import VectorRetriever
|
| 23 |
+
shared = _make_rc(section_id="18")
|
| 24 |
+
vector_only = _make_rc(section_id="3")
|
| 25 |
+
fts_only = _make_rc(section_id="7")
|
| 26 |
+
result = VectorRetriever.rrf_merge([shared, vector_only], [shared, fts_only], top_n=3)
|
| 27 |
+
assert result[0].chunk.section_id == "18"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def test_rrf_merge_respects_top_n():
|
| 31 |
+
from civicsetu.retrieval.vector_retriever import VectorRetriever
|
| 32 |
+
chunks = [_make_rc(section_id=str(i)) for i in range(5)]
|
| 33 |
+
assert len(VectorRetriever.rrf_merge(chunks, [], top_n=2)) == 2
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def test_rrf_merge_vector_chunk_wins_on_id_collision():
|
| 37 |
+
from civicsetu.retrieval.vector_retriever import VectorRetriever
|
| 38 |
+
rc_vector = _make_rc(section_id="18")
|
| 39 |
+
rc_fts = _make_rc(section_id="18")
|
| 40 |
+
rc_fts.chunk.chunk_id = rc_vector.chunk.chunk_id
|
| 41 |
+
result = VectorRetriever.rrf_merge([rc_vector], [rc_fts], top_n=5)
|
| 42 |
+
assert len(result) == 1
|
| 43 |
+
assert result[0] is rc_vector
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@pytest.mark.asyncio
|
| 47 |
+
async def test_retrieve_returns_list_of_retrieved_chunks():
|
| 48 |
+
from civicsetu.retrieval.vector_retriever import VectorRetriever
|
| 49 |
+
from civicsetu.models.schemas import RetrievedChunk
|
| 50 |
+
rc = _make_rc(section_id="18")
|
| 51 |
+
with patch("civicsetu.retrieval.vector_retriever.AsyncSessionLocal") as mock_scls, \
|
| 52 |
+
patch("civicsetu.retrieval.vector_retriever.VectorStore") as mock_vs:
|
| 53 |
+
mock_session = AsyncMock()
|
| 54 |
+
mock_scls.return_value.__aenter__ = AsyncMock(return_value=mock_session)
|
| 55 |
+
mock_scls.return_value.__aexit__ = AsyncMock(return_value=False)
|
| 56 |
+
mock_vs.similarity_search = AsyncMock(return_value=[rc])
|
| 57 |
+
mock_vs.full_text_search = AsyncMock(return_value=[rc])
|
| 58 |
+
mock_vs.get_section_family = AsyncMock(return_value=[])
|
| 59 |
+
result = await VectorRetriever.retrieve(
|
| 60 |
+
query="test query", query_embedding=[0.1] * 768, top_k=5, jurisdiction=None
|
| 61 |
+
)
|
| 62 |
+
assert isinstance(result, list)
|
| 63 |
+
assert all(isinstance(r, RetrievedChunk) for r in result)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@pytest.mark.asyncio
|
| 67 |
+
async def test_retrieve_caps_at_max_expanded():
|
| 68 |
+
from civicsetu.retrieval.vector_retriever import VectorRetriever, _MAX_VECTOR_EXPANDED
|
| 69 |
+
many = [_make_rc(section_id=str(i)) for i in range(50)]
|
| 70 |
+
with patch("civicsetu.retrieval.vector_retriever.AsyncSessionLocal") as mock_scls, \
|
| 71 |
+
patch("civicsetu.retrieval.vector_retriever.VectorStore") as mock_vs:
|
| 72 |
+
mock_session = AsyncMock()
|
| 73 |
+
mock_scls.return_value.__aenter__ = AsyncMock(return_value=mock_session)
|
| 74 |
+
mock_scls.return_value.__aexit__ = AsyncMock(return_value=False)
|
| 75 |
+
mock_vs.similarity_search = AsyncMock(return_value=many)
|
| 76 |
+
mock_vs.full_text_search = AsyncMock(return_value=[])
|
| 77 |
+
mock_vs.get_section_family = AsyncMock(return_value=[])
|
| 78 |
+
result = await VectorRetriever.retrieve(
|
| 79 |
+
query="test", query_embedding=[0.0] * 768, top_k=5, jurisdiction=None
|
| 80 |
+
)
|
| 81 |
+
assert len(result) <= _MAX_VECTOR_EXPANDED
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
@pytest.mark.asyncio
|
| 85 |
+
async def test_retrieve_expands_section_family():
|
| 86 |
+
from civicsetu.retrieval.vector_retriever import VectorRetriever
|
| 87 |
+
base_rc = _make_rc(section_id="18")
|
| 88 |
+
family_rc = _make_rc(section_id="18(1)")
|
| 89 |
+
with patch("civicsetu.retrieval.vector_retriever.AsyncSessionLocal") as mock_scls, \
|
| 90 |
+
patch("civicsetu.retrieval.vector_retriever.VectorStore") as mock_vs:
|
| 91 |
+
mock_session = AsyncMock()
|
| 92 |
+
mock_scls.return_value.__aenter__ = AsyncMock(return_value=mock_session)
|
| 93 |
+
mock_scls.return_value.__aexit__ = AsyncMock(return_value=False)
|
| 94 |
+
mock_vs.similarity_search = AsyncMock(return_value=[base_rc])
|
| 95 |
+
mock_vs.full_text_search = AsyncMock(return_value=[])
|
| 96 |
+
mock_vs.get_section_family = AsyncMock(return_value=[family_rc])
|
| 97 |
+
result = await VectorRetriever.retrieve(
|
| 98 |
+
query="test", query_embedding=[0.0] * 768, top_k=5, jurisdiction=None
|
| 99 |
+
)
|
| 100 |
+
chunk_ids = [str(r.chunk.chunk_id) for r in result]
|
| 101 |
+
assert str(family_rc.chunk.chunk_id) in chunk_ids
|