adeshboudh16 commited on
Commit
72bf783
·
1 Parent(s): ed21ac7

fix: harden graph retrieval retries

Browse files
src/civicsetu/agent/nodes.py CHANGED
@@ -519,7 +519,12 @@ async def graph_retrieval_node(state: CivicSetuState) -> dict:
519
  top_k = state.get("top_k", 5)
520
  node_start = time.perf_counter()
521
 
522
- log.info("graph_retrieval_node", query=query[:80])
 
 
 
 
 
523
 
524
  async def _retrieve():
525
  return await GraphRetriever.retrieve(
@@ -530,9 +535,25 @@ async def graph_retrieval_node(state: CivicSetuState) -> dict:
530
  )
531
 
532
  retrieve_start = time.perf_counter()
533
- chunks = await _retrieve()
534
- log.info("graph_retrieval_complete", results=len(chunks))
535
- log.info("stage_timing", node="graph_retrieval", stage="neo4j_postgres_hydration", duration_ms=round((time.perf_counter() - retrieve_start) * 1000, 2))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
536
 
537
  # Fallback: if graph found nothing (no explicit section in query),
538
  # run RRF hybrid retrieval instead of pure vector
@@ -627,7 +648,12 @@ async def graph_retrieval_node(state: CivicSetuState) -> dict:
627
  )
628
  chunks = chunks[:_MAX_GRAPH_CHUNKS]
629
 
630
- log.info("node_timing", node="graph_retrieval", duration_ms=round((time.perf_counter() - node_start) * 1000, 2), results=len(chunks))
 
 
 
 
 
631
  return {"retrieved_chunks": chunks}
632
 
633
 
 
519
  top_k = state.get("top_k", 5)
520
  node_start = time.perf_counter()
521
 
522
+ log.info(
523
+ "graph_retrieval_node",
524
+ query=query[:80],
525
+ jurisdiction=str(jurisdiction) if jurisdiction else None,
526
+ source_section_id=state.get("source_section_id"),
527
+ )
528
 
529
  async def _retrieve():
530
  return await GraphRetriever.retrieve(
 
535
  )
536
 
537
  retrieve_start = time.perf_counter()
538
+ try:
539
+ chunks = await _retrieve()
540
+ log.info("graph_retrieval_complete", results=len(chunks))
541
+ except Exception as e:
542
+ log.error(
543
+ "graph_retrieval_failed",
544
+ query=query[:80],
545
+ jurisdiction=str(jurisdiction) if jurisdiction else None,
546
+ source_section_id=state.get("source_section_id"),
547
+ error_type=type(e).__name__,
548
+ error=str(e),
549
+ )
550
+ raise
551
+ log.info(
552
+ "stage_timing",
553
+ node="graph_retrieval",
554
+ stage="neo4j_postgres_hydration",
555
+ duration_ms=round((time.perf_counter() - retrieve_start) * 1000, 2),
556
+ )
557
 
558
  # Fallback: if graph found nothing (no explicit section in query),
559
  # run RRF hybrid retrieval instead of pure vector
 
648
  )
649
  chunks = chunks[:_MAX_GRAPH_CHUNKS]
650
 
651
+ log.info(
652
+ "node_timing",
653
+ node="graph_retrieval",
654
+ duration_ms=round((time.perf_counter() - node_start) * 1000, 2),
655
+ results=len(chunks),
656
+ )
657
  return {"retrieved_chunks": chunks}
658
 
659
 
src/civicsetu/retrieval/graph_retriever.py CHANGED
@@ -12,9 +12,15 @@ from civicsetu.stores.vector_store import VectorStore
12
 
13
  log = structlog.get_logger(__name__)
14
 
15
- # All jurisdictions searched when no filter is provided.
16
- # Order matters: CENTRAL first so DERIVED_FROM incoming edges surface
17
- # all state rules in one traversal pass.
 
 
 
 
 
 
18
  _ALL_JURISDICTIONS = [
19
  "CENTRAL",
20
  "MAHARASHTRA",
@@ -23,27 +29,10 @@ _ALL_JURISDICTIONS = [
23
  "TAMIL_NADU",
24
  ]
25
 
26
- # Max chunks returned to reranker — prevents FlashRank serial inference blowup
27
- # on high-connectivity sections (e.g. §9 has DERIVED_FROM edges from 8 state rules)
28
  _MAX_GRAPH_CHUNKS = 20
29
 
30
 
31
  class GraphRetriever:
32
- """
33
- Retrieves chunks via Neo4j graph traversal rather than vector similarity.
34
-
35
- Used for:
36
- - cross_reference queries: "What does Section 18 reference?"
37
- - penalty_lookup queries: "What are the penalties under Section 59?"
38
- - temporal queries: "Which sections were amended?"
39
-
40
- Strategy:
41
- 1. Extract section_id from query (regex)
42
- 2. Traverse REFERENCES + DERIVED_FROM edges in Neo4j (all 5 jurisdictions in parallel)
43
- 3. Hydrate section_ids into LegalChunk objects via VectorStore.get_by_section()
44
- 4. Dedup + cap at _MAX_GRAPH_CHUNKS before returning
45
- """
46
-
47
  @staticmethod
48
  async def retrieve(
49
  query: str,
@@ -69,13 +58,7 @@ class GraphRetriever:
69
  )
70
  return cached
71
 
72
- # Explicit filter search only that jurisdiction.
73
- # No filter → search all jurisdictions (CENTRAL first).
74
- jurisdictions_to_search = (
75
- [jurisdiction.value] if jurisdiction
76
- else _ALL_JURISDICTIONS
77
- )
78
-
79
  log.info(
80
  "graph_retriever_traversing",
81
  section_id=section_id,
@@ -83,124 +66,155 @@ class GraphRetriever:
83
  jurisdictions=jurisdictions_to_search,
84
  )
85
 
86
- chunks: list[RetrievedChunk] = []
87
- seen_chunk_ids: set[str] = set()
88
-
89
- async def _fetch_jurisdiction(jur_str: str) -> list[RetrievedChunk]:
90
- async with AsyncSessionLocal() as session: # ← session scoped per coroutine
91
- jur_enum = Jurisdiction(jur_str) if not jurisdiction else jurisdiction
92
- jur_chunks: list[RetrievedChunk] = []
93
-
94
- source = await VectorStore.get_by_section(
95
- session=session, section_id=section_id, jurisdiction=jur_enum,
96
- )
97
- for rc in source:
98
- rc.retrieval_source = "graph"
99
- rc.graph_path = f"source:{section_id}@{jur_str}"
100
- rc.is_pinned = rc.chunk.section_id == section_id
101
- jur_chunks.extend(source)
102
-
103
- outgoing = await GraphStore.get_referenced_sections(section_id, jur_str, depth)
104
- for node in outgoing:
105
- hydrated = await VectorStore.get_by_section(
106
- session=session,
107
- section_id=node["section_id"],
108
- jurisdiction=Jurisdiction(node["jurisdiction"]),
109
- )
110
- for rc in hydrated:
111
- rc.retrieval_source = "graph"
112
- rc.graph_path = f"{section_id} →[REFERENCES]→ {node['section_id']}"
113
- jur_chunks.extend(hydrated)
114
 
115
- incoming = await GraphStore.get_sections_referencing(section_id, jur_str)
116
- for node in incoming:
117
- hydrated = await VectorStore.get_by_section(
118
- session=session,
119
- section_id=node["section_id"],
120
- jurisdiction=Jurisdiction(node["jurisdiction"]),
121
- )
122
- for rc in hydrated:
123
- rc.retrieval_source = "graph"
124
- rc.graph_path = f"{node['section_id']} →[REFERENCES]→ {section_id}"
125
- jur_chunks.extend(hydrated)
126
 
127
- derived_act = await GraphStore.get_derived_act_sections(section_id, jur_str)
128
- for node in derived_act:
129
- hydrated = await VectorStore.get_by_section(
130
  session=session,
131
- section_id=node["section_id"],
132
- jurisdiction=Jurisdiction(node["jurisdiction"]),
133
  )
134
- for rc in hydrated:
135
  rc.retrieval_source = "graph"
136
- rc.graph_path = f"{section_id}@{jur_str} →[DERIVED_FROM]→ {node['section_id']}@{node['jurisdiction']}"
137
- jur_chunks.extend(hydrated)
138
-
139
- deriving = await GraphStore.get_deriving_rule_sections(section_id, jur_str)
140
- for node in deriving:
141
- hydrated = await VectorStore.get_by_section(
142
- session=session,
143
- section_id=node["section_id"],
144
- jurisdiction=Jurisdiction(node["jurisdiction"]),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  )
146
- for rc in hydrated:
147
- rc.retrieval_source = "graph"
148
- rc.graph_path = f"{node['section_id']}@{node['jurisdiction']} →[DERIVED_FROM]→ {section_id}@{jur_str}"
149
- jur_chunks.extend(hydrated)
150
 
151
- return jur_chunks
 
152
 
153
- # gather is now OUTSIDE the session context
154
- all_results = await asyncio.gather(
155
- *[_fetch_jurisdiction(j) for j in jurisdictions_to_search],
156
- return_exceptions=True,
157
- )
158
- for result in all_results:
159
- if isinstance(result, Exception):
160
- log.warning("graph_jurisdiction_fetch_failed", error=str(result))
161
- else:
162
- chunks.extend(result)
163
-
164
- # Dedup by chunk_id (same chunk can arrive via multiple traversal paths)
165
- deduped: list[RetrievedChunk] = []
166
- for rc in chunks:
167
- cid = str(rc.chunk.chunk_id)
168
- if cid not in seen_chunk_ids:
169
- seen_chunk_ids.add(cid)
170
- deduped.append(rc)
171
-
172
- # Pinned chunks always included; fill remaining slots from deduped
173
- pinned = [rc for rc in deduped if rc.is_pinned]
174
- rest = [rc for rc in deduped if not rc.is_pinned]
175
- final = (pinned + rest)[:_MAX_GRAPH_CHUNKS]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
- log.info(
178
- "graph_retriever_complete",
179
- section_id=section_id,
180
- jurisdictions_searched=jurisdictions_to_search,
181
- chunks_hydrated=len(final),
182
- pinned=len(pinned),
183
- deduped_total=len(deduped),
184
- )
185
  graph_cache[cache_key] = final
186
  return final
187
 
188
  @staticmethod
189
  def _extract_section_id(query: str) -> str | None:
190
- """
191
- Extract a section number from natural language query.
192
- Handles: "Section 18", "section 18A", "s. 18", "sec 18", "Rule 3"
193
- Returns normalized section_id as stored in DB (e.g. "18", "18A").
194
- """
195
  import re
196
- section_pattern = re.compile(
197
- r'\b(?:section|sec\.?|s\.)\s*(\d+[A-Z]?)\b',
198
- re.IGNORECASE,
199
- )
200
- rule_pattern = re.compile(
201
- r'\bRule\s+(\d+[A-Z]?)\b',
202
- re.IGNORECASE,
203
- )
204
  m = section_pattern.search(query)
205
  if m:
206
  return m.group(1)
 
12
 
13
  log = structlog.get_logger(__name__)
14
 
15
+ _TRANSIENT_CONNECTION_MARKERS = (
16
+ "defunct connection",
17
+ "the connection is closed",
18
+ "connection reset by peer",
19
+ "unable to retrieve routing information",
20
+ "service unavailable",
21
+ "terminating connection due to administrator command",
22
+ )
23
+
24
  _ALL_JURISDICTIONS = [
25
  "CENTRAL",
26
  "MAHARASHTRA",
 
29
  "TAMIL_NADU",
30
  ]
31
 
 
 
32
  _MAX_GRAPH_CHUNKS = 20
33
 
34
 
35
  class GraphRetriever:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  @staticmethod
37
  async def retrieve(
38
  query: str,
 
58
  )
59
  return cached
60
 
61
+ jurisdictions_to_search = [jurisdiction.value] if jurisdiction else _ALL_JURISDICTIONS
 
 
 
 
 
 
62
  log.info(
63
  "graph_retriever_traversing",
64
  section_id=section_id,
 
66
  jurisdictions=jurisdictions_to_search,
67
  )
68
 
69
+ async def _fetch_once() -> list[RetrievedChunk]:
70
+ chunks: list[RetrievedChunk] = []
71
+ seen_chunk_ids: set[str] = set()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
+ async def _fetch_jurisdiction_once(jur_str: str) -> list[RetrievedChunk]:
74
+ async with AsyncSessionLocal() as session:
75
+ jur_enum = Jurisdiction(jur_str) if not jurisdiction else jurisdiction
76
+ jur_chunks: list[RetrievedChunk] = []
 
 
 
 
 
 
 
77
 
78
+ source = await VectorStore.get_by_section(
 
 
79
  session=session,
80
+ section_id=section_id,
81
+ jurisdiction=jur_enum,
82
  )
83
+ for rc in source:
84
  rc.retrieval_source = "graph"
85
+ rc.graph_path = f"source:{section_id}@{jur_str}"
86
+ rc.is_pinned = rc.chunk.section_id == section_id
87
+ jur_chunks.extend(source)
88
+
89
+ outgoing = await GraphStore.get_referenced_sections(section_id, jur_str, depth)
90
+ for node in outgoing:
91
+ hydrated = await VectorStore.get_by_section(
92
+ session=session,
93
+ section_id=node["section_id"],
94
+ jurisdiction=Jurisdiction(node["jurisdiction"]),
95
+ )
96
+ for rc in hydrated:
97
+ rc.retrieval_source = "graph"
98
+ rc.graph_path = f"{section_id} ->[REFERENCES]-> {node['section_id']}"
99
+ jur_chunks.extend(hydrated)
100
+
101
+ incoming = await GraphStore.get_sections_referencing(section_id, jur_str)
102
+ for node in incoming:
103
+ hydrated = await VectorStore.get_by_section(
104
+ session=session,
105
+ section_id=node["section_id"],
106
+ jurisdiction=Jurisdiction(node["jurisdiction"]),
107
+ )
108
+ for rc in hydrated:
109
+ rc.retrieval_source = "graph"
110
+ rc.graph_path = f"{node['section_id']} ->[REFERENCES]-> {section_id}"
111
+ jur_chunks.extend(hydrated)
112
+
113
+ derived_act = await GraphStore.get_derived_act_sections(section_id, jur_str)
114
+ for node in derived_act:
115
+ hydrated = await VectorStore.get_by_section(
116
+ session=session,
117
+ section_id=node["section_id"],
118
+ jurisdiction=Jurisdiction(node["jurisdiction"]),
119
+ )
120
+ for rc in hydrated:
121
+ rc.retrieval_source = "graph"
122
+ rc.graph_path = f"{section_id}@{jur_str} ->[DERIVED_FROM]-> {node['section_id']}@{node['jurisdiction']}"
123
+ jur_chunks.extend(hydrated)
124
+
125
+ deriving = await GraphStore.get_deriving_rule_sections(section_id, jur_str)
126
+ for node in deriving:
127
+ hydrated = await VectorStore.get_by_section(
128
+ session=session,
129
+ section_id=node["section_id"],
130
+ jurisdiction=Jurisdiction(node["jurisdiction"]),
131
+ )
132
+ for rc in hydrated:
133
+ rc.retrieval_source = "graph"
134
+ rc.graph_path = f"{node['section_id']}@{node['jurisdiction']} ->[DERIVED_FROM]-> {section_id}@{jur_str}"
135
+ jur_chunks.extend(hydrated)
136
+
137
+ return jur_chunks
138
+
139
+ async def _fetch_jurisdiction(jur_str: str) -> list[RetrievedChunk]:
140
+ try:
141
+ return await _fetch_jurisdiction_once(jur_str)
142
+ except Exception as first_error:
143
+ if not any(marker in str(first_error).lower() for marker in _TRANSIENT_CONNECTION_MARKERS):
144
+ raise
145
+ log.warning(
146
+ "graph_jurisdiction_retrying",
147
+ section_id=section_id,
148
+ jurisdiction=jur_str,
149
+ depth=depth,
150
+ error=str(first_error),
151
  )
152
+ from civicsetu.stores.graph_store import _reset_driver
 
 
 
153
 
154
+ await _reset_driver()
155
+ return await _fetch_jurisdiction_once(jur_str)
156
 
157
+ all_results = await asyncio.gather(
158
+ *[_fetch_jurisdiction(j) for j in jurisdictions_to_search],
159
+ return_exceptions=True,
160
+ )
161
+ for result in all_results:
162
+ if isinstance(result, Exception):
163
+ log.warning(
164
+ "graph_jurisdiction_fetch_failed",
165
+ section_id=section_id,
166
+ error_type=type(result).__name__,
167
+ error=str(result),
168
+ )
169
+ else:
170
+ chunks.extend(result)
171
+
172
+ deduped: list[RetrievedChunk] = []
173
+ for rc in chunks:
174
+ cid = str(rc.chunk.chunk_id)
175
+ if cid not in seen_chunk_ids:
176
+ seen_chunk_ids.add(cid)
177
+ deduped.append(rc)
178
+
179
+ pinned = [rc for rc in deduped if rc.is_pinned]
180
+ rest = [rc for rc in deduped if not rc.is_pinned]
181
+ final = (pinned + rest)[:_MAX_GRAPH_CHUNKS]
182
+ log.info(
183
+ "graph_retriever_complete",
184
+ section_id=section_id,
185
+ jurisdictions_searched=jurisdictions_to_search,
186
+ chunks_hydrated=len(final),
187
+ pinned=len(pinned),
188
+ deduped_total=len(deduped),
189
+ )
190
+ return final
191
+
192
+ try:
193
+ final = await _fetch_once()
194
+ except Exception as first_error:
195
+ if not any(marker in str(first_error).lower() for marker in _TRANSIENT_CONNECTION_MARKERS):
196
+ raise
197
+ log.warning(
198
+ "graph_retriever_retrying",
199
+ section_id=section_id,
200
+ jurisdiction=jurisdiction_key,
201
+ depth=depth,
202
+ error=str(first_error),
203
+ )
204
+ from civicsetu.stores.graph_store import _reset_driver
205
+
206
+ await _reset_driver()
207
+ final = await _fetch_once()
208
 
 
 
 
 
 
 
 
 
209
  graph_cache[cache_key] = final
210
  return final
211
 
212
  @staticmethod
213
  def _extract_section_id(query: str) -> str | None:
 
 
 
 
 
214
  import re
215
+
216
+ section_pattern = re.compile(r"\b(?:section|sec\.?|s\.)\s*(\d+[A-Z]?)\b", re.IGNORECASE)
217
+ rule_pattern = re.compile(r"\bRule\s+(\d+[A-Z]?)\b", re.IGNORECASE)
 
 
 
 
 
218
  m = section_pattern.search(query)
219
  if m:
220
  return m.group(1)
tests/unit/retrieval/test_graph_retriever.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from unittest.mock import AsyncMock, patch
4
+
5
+ import pytest
6
+
7
+ from tests.conftest import _make_rc
8
+
9
+
10
+ @pytest.mark.asyncio
11
+ async def test_graph_retriever_retries_transient_driver_failure():
12
+ from civicsetu.models.enums import Jurisdiction
13
+ from civicsetu.retrieval.graph_retriever import GraphRetriever
14
+
15
+ rc = _make_rc(section_id="18")
16
+
17
+ with patch("civicsetu.retrieval.graph_retriever.AsyncSessionLocal") as mock_scls, patch(
18
+ "civicsetu.retrieval.graph_retriever.VectorStore"
19
+ ) as mock_vs, patch("civicsetu.retrieval.graph_retriever.GraphStore") as mock_graph_store, patch(
20
+ "civicsetu.stores.graph_store._reset_driver", new=AsyncMock()
21
+ ) as reset_driver:
22
+ mock_session = AsyncMock()
23
+ mock_scls.return_value.__aenter__ = AsyncMock(return_value=mock_session)
24
+ mock_scls.return_value.__aexit__ = AsyncMock(return_value=False)
25
+ mock_vs.get_by_section = AsyncMock(return_value=[rc])
26
+ mock_graph_store.get_referenced_sections = AsyncMock(side_effect=[
27
+ RuntimeError("the connection is closed"),
28
+ [],
29
+ ])
30
+ mock_graph_store.get_sections_referencing = AsyncMock(return_value=[])
31
+ mock_graph_store.get_derived_act_sections = AsyncMock(return_value=[])
32
+ mock_graph_store.get_deriving_rule_sections = AsyncMock(return_value=[])
33
+
34
+ result = await GraphRetriever.retrieve(
35
+ query="What does Section 18 say?",
36
+ jurisdiction=Jurisdiction.CENTRAL,
37
+ depth=1,
38
+ )
39
+
40
+ assert result == [rc]
41
+ reset_driver.assert_awaited_once()