adeshboudh16 Claude Sonnet 4.6 commited on
Commit
ba7e2c5
·
1 Parent(s): 98da9ee

feat(retrieval): implement VectorRetriever with rrf_merge and retrieve

Browse files
src/civicsetu/retrieval/vector_retriever.py CHANGED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import re
4
+ import structlog
5
+
6
+ from civicsetu.models.enums import Jurisdiction
7
+ from civicsetu.models.schemas import RetrievedChunk
8
+ from civicsetu.stores.relational_store import AsyncSessionLocal
9
+ from civicsetu.stores.vector_store import VectorStore
10
+
11
+ log = structlog.get_logger(__name__)
12
+
13
+ _RRF_K = 60
14
+ _MAX_VECTOR_EXPANDED = 40
15
+
16
+
17
+ class VectorRetriever:
18
+ """
19
+ Hybrid retrieval: vector similarity + PostgreSQL FTS merged via Reciprocal Rank
20
+ Fusion (RRF), with top base-section family expansion.
21
+ Called by vector_retrieval_node, graph_retrieval_node fallback, hybrid_retrieval_node.
22
+ """
23
+
24
+ @staticmethod
25
+ def rrf_merge(
26
+ vector_results: list[RetrievedChunk],
27
+ fts_results: list[RetrievedChunk],
28
+ top_n: int,
29
+ ) -> list[RetrievedChunk]:
30
+ """
31
+ RRF score = 1/(k + rank_vector) + 1/(k + rank_fts).
32
+ Deduplicates by chunk_id; chunks in both lists score highest.
33
+ """
34
+ rrf_scores: dict[str, float] = {}
35
+ chunk_map: dict[str, RetrievedChunk] = {}
36
+
37
+ for rank, rc in enumerate(vector_results, 1):
38
+ cid = str(rc.chunk.chunk_id)
39
+ rrf_scores[cid] = rrf_scores.get(cid, 0.0) + 1.0 / (_RRF_K + rank)
40
+ chunk_map[cid] = rc
41
+
42
+ for rank, rc in enumerate(fts_results, 1):
43
+ cid = str(rc.chunk.chunk_id)
44
+ rrf_scores[cid] = rrf_scores.get(cid, 0.0) + 1.0 / (_RRF_K + rank)
45
+ if cid not in chunk_map:
46
+ chunk_map[cid] = rc
47
+
48
+ ranked = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True)
49
+ return [chunk_map[cid] for cid, _ in ranked[:top_n]]
50
+
51
+ @staticmethod
52
+ async def retrieve(
53
+ query: str,
54
+ query_embedding: list[float],
55
+ top_k: int,
56
+ jurisdiction: str | None,
57
+ ) -> list[RetrievedChunk]:
58
+ """Run hybrid retrieval and section-family expansion."""
59
+ async with AsyncSessionLocal() as session:
60
+ vector_results = await VectorStore.similarity_search(
61
+ session=session,
62
+ query_embedding=query_embedding,
63
+ top_k=top_k * 3,
64
+ jurisdiction=jurisdiction,
65
+ active_only=True,
66
+ )
67
+ fts_results = await VectorStore.full_text_search(
68
+ session=session,
69
+ query=query,
70
+ top_k=top_k * 2,
71
+ jurisdiction=jurisdiction,
72
+ active_only=True,
73
+ )
74
+
75
+ merged = VectorRetriever.rrf_merge(vector_results, fts_results, top_n=top_k * 2)
76
+
77
+ seen_ids: set[str] = {str(r.chunk.chunk_id) for r in merged}
78
+ expanded: list[RetrievedChunk] = list(merged)
79
+
80
+ for rc in merged[:3]:
81
+ sid = rc.chunk.section_id
82
+ jur = Jurisdiction(rc.chunk.jurisdiction)
83
+ base_sid = re.sub(r'\([^)]*\)$', '', str(sid)).strip()
84
+ for expand_sid in {str(sid), base_sid}:
85
+ family = await VectorStore.get_section_family(
86
+ session=session, section_id=expand_sid, jurisdiction=jur
87
+ )
88
+ for fc in family:
89
+ cid = str(fc.chunk.chunk_id)
90
+ if cid not in seen_ids:
91
+ seen_ids.add(cid)
92
+ expanded.append(fc)
93
+
94
+ log.info(
95
+ "rrf_retrieve_complete",
96
+ vector_results=len(vector_results),
97
+ fts_results=len(fts_results),
98
+ merged=len(merged),
99
+ results=min(len(expanded), _MAX_VECTOR_EXPANDED),
100
+ )
101
+ return expanded[:_MAX_VECTOR_EXPANDED]
tests/unit/retrieval/test_vector_retriever.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from unittest.mock import AsyncMock, patch
4
+ import pytest
5
+
6
+ from tests.conftest import _make_rc
7
+
8
+
9
+ def test_rrf_merge_empty_inputs():
10
+ from civicsetu.retrieval.vector_retriever import VectorRetriever
11
+ assert VectorRetriever.rrf_merge([], [], top_n=5) == []
12
+
13
+
14
+ def test_rrf_merge_deduplicates_by_chunk_id():
15
+ from civicsetu.retrieval.vector_retriever import VectorRetriever
16
+ rc = _make_rc(section_id="18")
17
+ result = VectorRetriever.rrf_merge([rc], [rc], top_n=5)
18
+ assert len(result) == 1
19
+
20
+
21
+ def test_rrf_merge_ranks_overlap_highest():
22
+ from civicsetu.retrieval.vector_retriever import VectorRetriever
23
+ shared = _make_rc(section_id="18")
24
+ vector_only = _make_rc(section_id="3")
25
+ fts_only = _make_rc(section_id="7")
26
+ result = VectorRetriever.rrf_merge([shared, vector_only], [shared, fts_only], top_n=3)
27
+ assert result[0].chunk.section_id == "18"
28
+
29
+
30
+ def test_rrf_merge_respects_top_n():
31
+ from civicsetu.retrieval.vector_retriever import VectorRetriever
32
+ chunks = [_make_rc(section_id=str(i)) for i in range(5)]
33
+ assert len(VectorRetriever.rrf_merge(chunks, [], top_n=2)) == 2
34
+
35
+
36
+ def test_rrf_merge_vector_chunk_wins_on_id_collision():
37
+ from civicsetu.retrieval.vector_retriever import VectorRetriever
38
+ rc_vector = _make_rc(section_id="18")
39
+ rc_fts = _make_rc(section_id="18")
40
+ rc_fts.chunk.chunk_id = rc_vector.chunk.chunk_id
41
+ result = VectorRetriever.rrf_merge([rc_vector], [rc_fts], top_n=5)
42
+ assert len(result) == 1
43
+ assert result[0] is rc_vector
44
+
45
+
46
+ @pytest.mark.asyncio
47
+ async def test_retrieve_returns_list_of_retrieved_chunks():
48
+ from civicsetu.retrieval.vector_retriever import VectorRetriever
49
+ from civicsetu.models.schemas import RetrievedChunk
50
+ rc = _make_rc(section_id="18")
51
+ with patch("civicsetu.retrieval.vector_retriever.AsyncSessionLocal") as mock_scls, \
52
+ patch("civicsetu.retrieval.vector_retriever.VectorStore") as mock_vs:
53
+ mock_session = AsyncMock()
54
+ mock_scls.return_value.__aenter__ = AsyncMock(return_value=mock_session)
55
+ mock_scls.return_value.__aexit__ = AsyncMock(return_value=False)
56
+ mock_vs.similarity_search = AsyncMock(return_value=[rc])
57
+ mock_vs.full_text_search = AsyncMock(return_value=[rc])
58
+ mock_vs.get_section_family = AsyncMock(return_value=[])
59
+ result = await VectorRetriever.retrieve(
60
+ query="test query", query_embedding=[0.1] * 768, top_k=5, jurisdiction=None
61
+ )
62
+ assert isinstance(result, list)
63
+ assert all(isinstance(r, RetrievedChunk) for r in result)
64
+
65
+
66
+ @pytest.mark.asyncio
67
+ async def test_retrieve_caps_at_max_expanded():
68
+ from civicsetu.retrieval.vector_retriever import VectorRetriever, _MAX_VECTOR_EXPANDED
69
+ many = [_make_rc(section_id=str(i)) for i in range(50)]
70
+ with patch("civicsetu.retrieval.vector_retriever.AsyncSessionLocal") as mock_scls, \
71
+ patch("civicsetu.retrieval.vector_retriever.VectorStore") as mock_vs:
72
+ mock_session = AsyncMock()
73
+ mock_scls.return_value.__aenter__ = AsyncMock(return_value=mock_session)
74
+ mock_scls.return_value.__aexit__ = AsyncMock(return_value=False)
75
+ mock_vs.similarity_search = AsyncMock(return_value=many)
76
+ mock_vs.full_text_search = AsyncMock(return_value=[])
77
+ mock_vs.get_section_family = AsyncMock(return_value=[])
78
+ result = await VectorRetriever.retrieve(
79
+ query="test", query_embedding=[0.0] * 768, top_k=5, jurisdiction=None
80
+ )
81
+ assert len(result) <= _MAX_VECTOR_EXPANDED
82
+
83
+
84
+ @pytest.mark.asyncio
85
+ async def test_retrieve_expands_section_family():
86
+ from civicsetu.retrieval.vector_retriever import VectorRetriever
87
+ base_rc = _make_rc(section_id="18")
88
+ family_rc = _make_rc(section_id="18(1)")
89
+ with patch("civicsetu.retrieval.vector_retriever.AsyncSessionLocal") as mock_scls, \
90
+ patch("civicsetu.retrieval.vector_retriever.VectorStore") as mock_vs:
91
+ mock_session = AsyncMock()
92
+ mock_scls.return_value.__aenter__ = AsyncMock(return_value=mock_session)
93
+ mock_scls.return_value.__aexit__ = AsyncMock(return_value=False)
94
+ mock_vs.similarity_search = AsyncMock(return_value=[base_rc])
95
+ mock_vs.full_text_search = AsyncMock(return_value=[])
96
+ mock_vs.get_section_family = AsyncMock(return_value=[family_rc])
97
+ result = await VectorRetriever.retrieve(
98
+ query="test", query_embedding=[0.0] * 768, top_k=5, jurisdiction=None
99
+ )
100
+ chunk_ids = [str(r.chunk.chunk_id) for r in result]
101
+ assert str(family_rc.chunk.chunk_id) in chunk_ids