adeshboudh16 commited on
Commit ·
72bf783
1
Parent(s): ed21ac7
fix: harden graph retrieval retries
Browse files
src/civicsetu/agent/nodes.py
CHANGED
|
@@ -519,7 +519,12 @@ async def graph_retrieval_node(state: CivicSetuState) -> dict:
|
|
| 519 |
top_k = state.get("top_k", 5)
|
| 520 |
node_start = time.perf_counter()
|
| 521 |
|
| 522 |
-
log.info(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 523 |
|
| 524 |
async def _retrieve():
|
| 525 |
return await GraphRetriever.retrieve(
|
|
@@ -530,9 +535,25 @@ async def graph_retrieval_node(state: CivicSetuState) -> dict:
|
|
| 530 |
)
|
| 531 |
|
| 532 |
retrieve_start = time.perf_counter()
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 536 |
|
| 537 |
# Fallback: if graph found nothing (no explicit section in query),
|
| 538 |
# run RRF hybrid retrieval instead of pure vector
|
|
@@ -627,7 +648,12 @@ async def graph_retrieval_node(state: CivicSetuState) -> dict:
|
|
| 627 |
)
|
| 628 |
chunks = chunks[:_MAX_GRAPH_CHUNKS]
|
| 629 |
|
| 630 |
-
log.info(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 631 |
return {"retrieved_chunks": chunks}
|
| 632 |
|
| 633 |
|
|
|
|
| 519 |
top_k = state.get("top_k", 5)
|
| 520 |
node_start = time.perf_counter()
|
| 521 |
|
| 522 |
+
log.info(
|
| 523 |
+
"graph_retrieval_node",
|
| 524 |
+
query=query[:80],
|
| 525 |
+
jurisdiction=str(jurisdiction) if jurisdiction else None,
|
| 526 |
+
source_section_id=state.get("source_section_id"),
|
| 527 |
+
)
|
| 528 |
|
| 529 |
async def _retrieve():
|
| 530 |
return await GraphRetriever.retrieve(
|
|
|
|
| 535 |
)
|
| 536 |
|
| 537 |
retrieve_start = time.perf_counter()
|
| 538 |
+
try:
|
| 539 |
+
chunks = await _retrieve()
|
| 540 |
+
log.info("graph_retrieval_complete", results=len(chunks))
|
| 541 |
+
except Exception as e:
|
| 542 |
+
log.error(
|
| 543 |
+
"graph_retrieval_failed",
|
| 544 |
+
query=query[:80],
|
| 545 |
+
jurisdiction=str(jurisdiction) if jurisdiction else None,
|
| 546 |
+
source_section_id=state.get("source_section_id"),
|
| 547 |
+
error_type=type(e).__name__,
|
| 548 |
+
error=str(e),
|
| 549 |
+
)
|
| 550 |
+
raise
|
| 551 |
+
log.info(
|
| 552 |
+
"stage_timing",
|
| 553 |
+
node="graph_retrieval",
|
| 554 |
+
stage="neo4j_postgres_hydration",
|
| 555 |
+
duration_ms=round((time.perf_counter() - retrieve_start) * 1000, 2),
|
| 556 |
+
)
|
| 557 |
|
| 558 |
# Fallback: if graph found nothing (no explicit section in query),
|
| 559 |
# run RRF hybrid retrieval instead of pure vector
|
|
|
|
| 648 |
)
|
| 649 |
chunks = chunks[:_MAX_GRAPH_CHUNKS]
|
| 650 |
|
| 651 |
+
log.info(
|
| 652 |
+
"node_timing",
|
| 653 |
+
node="graph_retrieval",
|
| 654 |
+
duration_ms=round((time.perf_counter() - node_start) * 1000, 2),
|
| 655 |
+
results=len(chunks),
|
| 656 |
+
)
|
| 657 |
return {"retrieved_chunks": chunks}
|
| 658 |
|
| 659 |
|
src/civicsetu/retrieval/graph_retriever.py
CHANGED
|
@@ -12,9 +12,15 @@ from civicsetu.stores.vector_store import VectorStore
|
|
| 12 |
|
| 13 |
log = structlog.get_logger(__name__)
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
_ALL_JURISDICTIONS = [
|
| 19 |
"CENTRAL",
|
| 20 |
"MAHARASHTRA",
|
|
@@ -23,27 +29,10 @@ _ALL_JURISDICTIONS = [
|
|
| 23 |
"TAMIL_NADU",
|
| 24 |
]
|
| 25 |
|
| 26 |
-
# Max chunks returned to reranker — prevents FlashRank serial inference blowup
|
| 27 |
-
# on high-connectivity sections (e.g. §9 has DERIVED_FROM edges from 8 state rules)
|
| 28 |
_MAX_GRAPH_CHUNKS = 20
|
| 29 |
|
| 30 |
|
| 31 |
class GraphRetriever:
|
| 32 |
-
"""
|
| 33 |
-
Retrieves chunks via Neo4j graph traversal rather than vector similarity.
|
| 34 |
-
|
| 35 |
-
Used for:
|
| 36 |
-
- cross_reference queries: "What does Section 18 reference?"
|
| 37 |
-
- penalty_lookup queries: "What are the penalties under Section 59?"
|
| 38 |
-
- temporal queries: "Which sections were amended?"
|
| 39 |
-
|
| 40 |
-
Strategy:
|
| 41 |
-
1. Extract section_id from query (regex)
|
| 42 |
-
2. Traverse REFERENCES + DERIVED_FROM edges in Neo4j (all 5 jurisdictions in parallel)
|
| 43 |
-
3. Hydrate section_ids into LegalChunk objects via VectorStore.get_by_section()
|
| 44 |
-
4. Dedup + cap at _MAX_GRAPH_CHUNKS before returning
|
| 45 |
-
"""
|
| 46 |
-
|
| 47 |
@staticmethod
|
| 48 |
async def retrieve(
|
| 49 |
query: str,
|
|
@@ -69,13 +58,7 @@ class GraphRetriever:
|
|
| 69 |
)
|
| 70 |
return cached
|
| 71 |
|
| 72 |
-
|
| 73 |
-
# No filter → search all jurisdictions (CENTRAL first).
|
| 74 |
-
jurisdictions_to_search = (
|
| 75 |
-
[jurisdiction.value] if jurisdiction
|
| 76 |
-
else _ALL_JURISDICTIONS
|
| 77 |
-
)
|
| 78 |
-
|
| 79 |
log.info(
|
| 80 |
"graph_retriever_traversing",
|
| 81 |
section_id=section_id,
|
|
@@ -83,124 +66,155 @@ class GraphRetriever:
|
|
| 83 |
jurisdictions=jurisdictions_to_search,
|
| 84 |
)
|
| 85 |
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
async def _fetch_jurisdiction(jur_str: str) -> list[RetrievedChunk]:
|
| 90 |
-
async with AsyncSessionLocal() as session: # ← session scoped per coroutine
|
| 91 |
-
jur_enum = Jurisdiction(jur_str) if not jurisdiction else jurisdiction
|
| 92 |
-
jur_chunks: list[RetrievedChunk] = []
|
| 93 |
-
|
| 94 |
-
source = await VectorStore.get_by_section(
|
| 95 |
-
session=session, section_id=section_id, jurisdiction=jur_enum,
|
| 96 |
-
)
|
| 97 |
-
for rc in source:
|
| 98 |
-
rc.retrieval_source = "graph"
|
| 99 |
-
rc.graph_path = f"source:{section_id}@{jur_str}"
|
| 100 |
-
rc.is_pinned = rc.chunk.section_id == section_id
|
| 101 |
-
jur_chunks.extend(source)
|
| 102 |
-
|
| 103 |
-
outgoing = await GraphStore.get_referenced_sections(section_id, jur_str, depth)
|
| 104 |
-
for node in outgoing:
|
| 105 |
-
hydrated = await VectorStore.get_by_section(
|
| 106 |
-
session=session,
|
| 107 |
-
section_id=node["section_id"],
|
| 108 |
-
jurisdiction=Jurisdiction(node["jurisdiction"]),
|
| 109 |
-
)
|
| 110 |
-
for rc in hydrated:
|
| 111 |
-
rc.retrieval_source = "graph"
|
| 112 |
-
rc.graph_path = f"{section_id} →[REFERENCES]→ {node['section_id']}"
|
| 113 |
-
jur_chunks.extend(hydrated)
|
| 114 |
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
section_id=node["section_id"],
|
| 120 |
-
jurisdiction=Jurisdiction(node["jurisdiction"]),
|
| 121 |
-
)
|
| 122 |
-
for rc in hydrated:
|
| 123 |
-
rc.retrieval_source = "graph"
|
| 124 |
-
rc.graph_path = f"{node['section_id']} →[REFERENCES]→ {section_id}"
|
| 125 |
-
jur_chunks.extend(hydrated)
|
| 126 |
|
| 127 |
-
|
| 128 |
-
for node in derived_act:
|
| 129 |
-
hydrated = await VectorStore.get_by_section(
|
| 130 |
session=session,
|
| 131 |
-
section_id=
|
| 132 |
-
jurisdiction=
|
| 133 |
)
|
| 134 |
-
for rc in
|
| 135 |
rc.retrieval_source = "graph"
|
| 136 |
-
rc.graph_path = f"{section_id}@{jur_str}
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
)
|
| 146 |
-
|
| 147 |
-
rc.retrieval_source = "graph"
|
| 148 |
-
rc.graph_path = f"{node['section_id']}@{node['jurisdiction']} →[DERIVED_FROM]→ {section_id}@{jur_str}"
|
| 149 |
-
jur_chunks.extend(hydrated)
|
| 150 |
|
| 151 |
-
|
|
|
|
| 152 |
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
|
| 177 |
-
log.info(
|
| 178 |
-
"graph_retriever_complete",
|
| 179 |
-
section_id=section_id,
|
| 180 |
-
jurisdictions_searched=jurisdictions_to_search,
|
| 181 |
-
chunks_hydrated=len(final),
|
| 182 |
-
pinned=len(pinned),
|
| 183 |
-
deduped_total=len(deduped),
|
| 184 |
-
)
|
| 185 |
graph_cache[cache_key] = final
|
| 186 |
return final
|
| 187 |
|
| 188 |
@staticmethod
|
| 189 |
def _extract_section_id(query: str) -> str | None:
|
| 190 |
-
"""
|
| 191 |
-
Extract a section number from natural language query.
|
| 192 |
-
Handles: "Section 18", "section 18A", "s. 18", "sec 18", "Rule 3"
|
| 193 |
-
Returns normalized section_id as stored in DB (e.g. "18", "18A").
|
| 194 |
-
"""
|
| 195 |
import re
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
)
|
| 200 |
-
rule_pattern = re.compile(
|
| 201 |
-
r'\bRule\s+(\d+[A-Z]?)\b',
|
| 202 |
-
re.IGNORECASE,
|
| 203 |
-
)
|
| 204 |
m = section_pattern.search(query)
|
| 205 |
if m:
|
| 206 |
return m.group(1)
|
|
|
|
| 12 |
|
| 13 |
log = structlog.get_logger(__name__)
|
| 14 |
|
| 15 |
+
_TRANSIENT_CONNECTION_MARKERS = (
|
| 16 |
+
"defunct connection",
|
| 17 |
+
"the connection is closed",
|
| 18 |
+
"connection reset by peer",
|
| 19 |
+
"unable to retrieve routing information",
|
| 20 |
+
"service unavailable",
|
| 21 |
+
"terminating connection due to administrator command",
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
_ALL_JURISDICTIONS = [
|
| 25 |
"CENTRAL",
|
| 26 |
"MAHARASHTRA",
|
|
|
|
| 29 |
"TAMIL_NADU",
|
| 30 |
]
|
| 31 |
|
|
|
|
|
|
|
| 32 |
_MAX_GRAPH_CHUNKS = 20
|
| 33 |
|
| 34 |
|
| 35 |
class GraphRetriever:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
@staticmethod
|
| 37 |
async def retrieve(
|
| 38 |
query: str,
|
|
|
|
| 58 |
)
|
| 59 |
return cached
|
| 60 |
|
| 61 |
+
jurisdictions_to_search = [jurisdiction.value] if jurisdiction else _ALL_JURISDICTIONS
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
log.info(
|
| 63 |
"graph_retriever_traversing",
|
| 64 |
section_id=section_id,
|
|
|
|
| 66 |
jurisdictions=jurisdictions_to_search,
|
| 67 |
)
|
| 68 |
|
| 69 |
+
async def _fetch_once() -> list[RetrievedChunk]:
|
| 70 |
+
chunks: list[RetrievedChunk] = []
|
| 71 |
+
seen_chunk_ids: set[str] = set()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
+
async def _fetch_jurisdiction_once(jur_str: str) -> list[RetrievedChunk]:
|
| 74 |
+
async with AsyncSessionLocal() as session:
|
| 75 |
+
jur_enum = Jurisdiction(jur_str) if not jurisdiction else jurisdiction
|
| 76 |
+
jur_chunks: list[RetrievedChunk] = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
+
source = await VectorStore.get_by_section(
|
|
|
|
|
|
|
| 79 |
session=session,
|
| 80 |
+
section_id=section_id,
|
| 81 |
+
jurisdiction=jur_enum,
|
| 82 |
)
|
| 83 |
+
for rc in source:
|
| 84 |
rc.retrieval_source = "graph"
|
| 85 |
+
rc.graph_path = f"source:{section_id}@{jur_str}"
|
| 86 |
+
rc.is_pinned = rc.chunk.section_id == section_id
|
| 87 |
+
jur_chunks.extend(source)
|
| 88 |
+
|
| 89 |
+
outgoing = await GraphStore.get_referenced_sections(section_id, jur_str, depth)
|
| 90 |
+
for node in outgoing:
|
| 91 |
+
hydrated = await VectorStore.get_by_section(
|
| 92 |
+
session=session,
|
| 93 |
+
section_id=node["section_id"],
|
| 94 |
+
jurisdiction=Jurisdiction(node["jurisdiction"]),
|
| 95 |
+
)
|
| 96 |
+
for rc in hydrated:
|
| 97 |
+
rc.retrieval_source = "graph"
|
| 98 |
+
rc.graph_path = f"{section_id} ->[REFERENCES]-> {node['section_id']}"
|
| 99 |
+
jur_chunks.extend(hydrated)
|
| 100 |
+
|
| 101 |
+
incoming = await GraphStore.get_sections_referencing(section_id, jur_str)
|
| 102 |
+
for node in incoming:
|
| 103 |
+
hydrated = await VectorStore.get_by_section(
|
| 104 |
+
session=session,
|
| 105 |
+
section_id=node["section_id"],
|
| 106 |
+
jurisdiction=Jurisdiction(node["jurisdiction"]),
|
| 107 |
+
)
|
| 108 |
+
for rc in hydrated:
|
| 109 |
+
rc.retrieval_source = "graph"
|
| 110 |
+
rc.graph_path = f"{node['section_id']} ->[REFERENCES]-> {section_id}"
|
| 111 |
+
jur_chunks.extend(hydrated)
|
| 112 |
+
|
| 113 |
+
derived_act = await GraphStore.get_derived_act_sections(section_id, jur_str)
|
| 114 |
+
for node in derived_act:
|
| 115 |
+
hydrated = await VectorStore.get_by_section(
|
| 116 |
+
session=session,
|
| 117 |
+
section_id=node["section_id"],
|
| 118 |
+
jurisdiction=Jurisdiction(node["jurisdiction"]),
|
| 119 |
+
)
|
| 120 |
+
for rc in hydrated:
|
| 121 |
+
rc.retrieval_source = "graph"
|
| 122 |
+
rc.graph_path = f"{section_id}@{jur_str} ->[DERIVED_FROM]-> {node['section_id']}@{node['jurisdiction']}"
|
| 123 |
+
jur_chunks.extend(hydrated)
|
| 124 |
+
|
| 125 |
+
deriving = await GraphStore.get_deriving_rule_sections(section_id, jur_str)
|
| 126 |
+
for node in deriving:
|
| 127 |
+
hydrated = await VectorStore.get_by_section(
|
| 128 |
+
session=session,
|
| 129 |
+
section_id=node["section_id"],
|
| 130 |
+
jurisdiction=Jurisdiction(node["jurisdiction"]),
|
| 131 |
+
)
|
| 132 |
+
for rc in hydrated:
|
| 133 |
+
rc.retrieval_source = "graph"
|
| 134 |
+
rc.graph_path = f"{node['section_id']}@{node['jurisdiction']} ->[DERIVED_FROM]-> {section_id}@{jur_str}"
|
| 135 |
+
jur_chunks.extend(hydrated)
|
| 136 |
+
|
| 137 |
+
return jur_chunks
|
| 138 |
+
|
| 139 |
+
async def _fetch_jurisdiction(jur_str: str) -> list[RetrievedChunk]:
|
| 140 |
+
try:
|
| 141 |
+
return await _fetch_jurisdiction_once(jur_str)
|
| 142 |
+
except Exception as first_error:
|
| 143 |
+
if not any(marker in str(first_error).lower() for marker in _TRANSIENT_CONNECTION_MARKERS):
|
| 144 |
+
raise
|
| 145 |
+
log.warning(
|
| 146 |
+
"graph_jurisdiction_retrying",
|
| 147 |
+
section_id=section_id,
|
| 148 |
+
jurisdiction=jur_str,
|
| 149 |
+
depth=depth,
|
| 150 |
+
error=str(first_error),
|
| 151 |
)
|
| 152 |
+
from civicsetu.stores.graph_store import _reset_driver
|
|
|
|
|
|
|
|
|
|
| 153 |
|
| 154 |
+
await _reset_driver()
|
| 155 |
+
return await _fetch_jurisdiction_once(jur_str)
|
| 156 |
|
| 157 |
+
all_results = await asyncio.gather(
|
| 158 |
+
*[_fetch_jurisdiction(j) for j in jurisdictions_to_search],
|
| 159 |
+
return_exceptions=True,
|
| 160 |
+
)
|
| 161 |
+
for result in all_results:
|
| 162 |
+
if isinstance(result, Exception):
|
| 163 |
+
log.warning(
|
| 164 |
+
"graph_jurisdiction_fetch_failed",
|
| 165 |
+
section_id=section_id,
|
| 166 |
+
error_type=type(result).__name__,
|
| 167 |
+
error=str(result),
|
| 168 |
+
)
|
| 169 |
+
else:
|
| 170 |
+
chunks.extend(result)
|
| 171 |
+
|
| 172 |
+
deduped: list[RetrievedChunk] = []
|
| 173 |
+
for rc in chunks:
|
| 174 |
+
cid = str(rc.chunk.chunk_id)
|
| 175 |
+
if cid not in seen_chunk_ids:
|
| 176 |
+
seen_chunk_ids.add(cid)
|
| 177 |
+
deduped.append(rc)
|
| 178 |
+
|
| 179 |
+
pinned = [rc for rc in deduped if rc.is_pinned]
|
| 180 |
+
rest = [rc for rc in deduped if not rc.is_pinned]
|
| 181 |
+
final = (pinned + rest)[:_MAX_GRAPH_CHUNKS]
|
| 182 |
+
log.info(
|
| 183 |
+
"graph_retriever_complete",
|
| 184 |
+
section_id=section_id,
|
| 185 |
+
jurisdictions_searched=jurisdictions_to_search,
|
| 186 |
+
chunks_hydrated=len(final),
|
| 187 |
+
pinned=len(pinned),
|
| 188 |
+
deduped_total=len(deduped),
|
| 189 |
+
)
|
| 190 |
+
return final
|
| 191 |
+
|
| 192 |
+
try:
|
| 193 |
+
final = await _fetch_once()
|
| 194 |
+
except Exception as first_error:
|
| 195 |
+
if not any(marker in str(first_error).lower() for marker in _TRANSIENT_CONNECTION_MARKERS):
|
| 196 |
+
raise
|
| 197 |
+
log.warning(
|
| 198 |
+
"graph_retriever_retrying",
|
| 199 |
+
section_id=section_id,
|
| 200 |
+
jurisdiction=jurisdiction_key,
|
| 201 |
+
depth=depth,
|
| 202 |
+
error=str(first_error),
|
| 203 |
+
)
|
| 204 |
+
from civicsetu.stores.graph_store import _reset_driver
|
| 205 |
+
|
| 206 |
+
await _reset_driver()
|
| 207 |
+
final = await _fetch_once()
|
| 208 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
graph_cache[cache_key] = final
|
| 210 |
return final
|
| 211 |
|
| 212 |
@staticmethod
|
| 213 |
def _extract_section_id(query: str) -> str | None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
import re
|
| 215 |
+
|
| 216 |
+
section_pattern = re.compile(r"\b(?:section|sec\.?|s\.)\s*(\d+[A-Z]?)\b", re.IGNORECASE)
|
| 217 |
+
rule_pattern = re.compile(r"\bRule\s+(\d+[A-Z]?)\b", re.IGNORECASE)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
m = section_pattern.search(query)
|
| 219 |
if m:
|
| 220 |
return m.group(1)
|
tests/unit/retrieval/test_graph_retriever.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from unittest.mock import AsyncMock, patch
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
|
| 7 |
+
from tests.conftest import _make_rc
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@pytest.mark.asyncio
|
| 11 |
+
async def test_graph_retriever_retries_transient_driver_failure():
|
| 12 |
+
from civicsetu.models.enums import Jurisdiction
|
| 13 |
+
from civicsetu.retrieval.graph_retriever import GraphRetriever
|
| 14 |
+
|
| 15 |
+
rc = _make_rc(section_id="18")
|
| 16 |
+
|
| 17 |
+
with patch("civicsetu.retrieval.graph_retriever.AsyncSessionLocal") as mock_scls, patch(
|
| 18 |
+
"civicsetu.retrieval.graph_retriever.VectorStore"
|
| 19 |
+
) as mock_vs, patch("civicsetu.retrieval.graph_retriever.GraphStore") as mock_graph_store, patch(
|
| 20 |
+
"civicsetu.stores.graph_store._reset_driver", new=AsyncMock()
|
| 21 |
+
) as reset_driver:
|
| 22 |
+
mock_session = AsyncMock()
|
| 23 |
+
mock_scls.return_value.__aenter__ = AsyncMock(return_value=mock_session)
|
| 24 |
+
mock_scls.return_value.__aexit__ = AsyncMock(return_value=False)
|
| 25 |
+
mock_vs.get_by_section = AsyncMock(return_value=[rc])
|
| 26 |
+
mock_graph_store.get_referenced_sections = AsyncMock(side_effect=[
|
| 27 |
+
RuntimeError("the connection is closed"),
|
| 28 |
+
[],
|
| 29 |
+
])
|
| 30 |
+
mock_graph_store.get_sections_referencing = AsyncMock(return_value=[])
|
| 31 |
+
mock_graph_store.get_derived_act_sections = AsyncMock(return_value=[])
|
| 32 |
+
mock_graph_store.get_deriving_rule_sections = AsyncMock(return_value=[])
|
| 33 |
+
|
| 34 |
+
result = await GraphRetriever.retrieve(
|
| 35 |
+
query="What does Section 18 say?",
|
| 36 |
+
jurisdiction=Jurisdiction.CENTRAL,
|
| 37 |
+
depth=1,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
assert result == [rc]
|
| 41 |
+
reset_driver.assert_awaited_once()
|