adeshboudh16 commited on
Commit ·
caab91b
1
Parent(s): 5d11dcd
fix: recover neo4j graph connections
Browse files- src/civicsetu/agent/nodes.py +45 -33
- src/civicsetu/api/routes/graph.py +2 -3
- src/civicsetu/api/routes/query.py +3 -5
- src/civicsetu/stores/graph_store.py +162 -109
- tests/unit/agent/test_nodes.py +94 -68
- tests/unit/api/test_query_route.py +12 -12
- tests/unit/stores/test_graph_store.py +48 -0
src/civicsetu/agent/nodes.py
CHANGED
|
@@ -64,7 +64,7 @@ FAST_MODELS = [
|
|
| 64 |
FALLBACK_MODELS = THINKING_MODELS
|
| 65 |
|
| 66 |
|
| 67 |
-
def turn_reset_node(state: CivicSetuState) -> dict:
|
| 68 |
"""
|
| 69 |
Clear per-turn fields while preserving session-scoped inputs and messages.
|
| 70 |
"""
|
|
@@ -85,9 +85,9 @@ def turn_reset_node(state: CivicSetuState) -> dict:
|
|
| 85 |
}
|
| 86 |
|
| 87 |
|
| 88 |
-
def _llm_call(prompt: str, system: str, temperature: float = 0.0, tier: str = "thinking") -> str:
|
| 89 |
"""
|
| 90 |
-
Call an LLM with fallback chain.
|
| 91 |
|
| 92 |
Args:
|
| 93 |
tier: "thinking" for deep-reasoning tasks (generator),
|
|
@@ -146,7 +146,7 @@ def _llm_call(prompt: str, system: str, temperature: float = 0.0, tier: str = "t
|
|
| 146 |
if "osmapi.com" in os.environ.get("OPENAI_API_BASE", ""):
|
| 147 |
completion_kwargs["response_format"] = {"type": "json_object"}
|
| 148 |
|
| 149 |
-
response = litellm.
|
| 150 |
|
| 151 |
duration_ms = round((time.perf_counter() - start) * 1000, 2)
|
| 152 |
content = response.choices[0].message.content
|
|
@@ -323,7 +323,7 @@ def _sort_pinned_family(family: list[RetrievedChunk], hint: str) -> list[Retriev
|
|
| 323 |
return sorted(family, key=score, reverse=True)
|
| 324 |
|
| 325 |
|
| 326 |
-
def _prepend_pinned_sections(state: CivicSetuState, chunks: list[RetrievedChunk]) -> list[RetrievedChunk]:
|
| 327 |
pinned_refs = state.get("pinned_section_refs")
|
| 328 |
if not pinned_refs:
|
| 329 |
return chunks
|
|
@@ -346,12 +346,12 @@ def _prepend_pinned_sections(state: CivicSetuState, chunks: list[RetrievedChunk]
|
|
| 346 |
c.is_pinned = True
|
| 347 |
existing_matches.append(c)
|
| 348 |
|
| 349 |
-
pinned =
|
| 350 |
pinned_refs,
|
| 351 |
jurisdiction,
|
| 352 |
chunks,
|
| 353 |
hint,
|
| 354 |
-
)
|
| 355 |
promoted = existing_matches + pinned
|
| 356 |
if promoted:
|
| 357 |
promoted_ids = {str(c.chunk.chunk_id) for c in promoted}
|
|
@@ -367,7 +367,7 @@ def _prepend_pinned_sections(state: CivicSetuState, chunks: list[RetrievedChunk]
|
|
| 367 |
|
| 368 |
# ── Node 1: Classifier ─────────────────────────────────────────────────────────
|
| 369 |
|
| 370 |
-
def classifier_node(state: CivicSetuState) -> dict:
|
| 371 |
"""
|
| 372 |
Classifies query type and rewrites the query for better retrieval.
|
| 373 |
Returns: query_type, rewritten_query
|
|
@@ -392,7 +392,7 @@ def classifier_node(state: CivicSetuState) -> dict:
|
|
| 392 |
)
|
| 393 |
|
| 394 |
try:
|
| 395 |
-
raw = _llm_call(prompt, system, tier="fast")
|
| 396 |
result = _extract_json_dict(raw)
|
| 397 |
|
| 398 |
query_type_str = result.get("query_type", "fact_lookup")
|
|
@@ -423,7 +423,7 @@ async def _rrf_retrieve(
|
|
| 423 |
return await VectorRetriever.retrieve(query, query_embedding, top_k, jurisdiction)
|
| 424 |
|
| 425 |
|
| 426 |
-
def vector_retrieval_node(state: CivicSetuState) -> dict:
|
| 427 |
from civicsetu.retrieval.vector_retriever import VectorRetriever
|
| 428 |
|
| 429 |
query = state.get("rewritten_query") or state["query"]
|
|
@@ -434,12 +434,12 @@ def vector_retrieval_node(state: CivicSetuState) -> dict:
|
|
| 434 |
log.info("vector_retrieval_node", query=query[:80], top_k=top_k)
|
| 435 |
|
| 436 |
embed_start = time.perf_counter()
|
| 437 |
-
query_embedding =
|
| 438 |
log.info("stage_timing", node="vector_retrieval", stage="embedding",
|
| 439 |
duration_ms=round((time.perf_counter() - embed_start) * 1000, 2))
|
| 440 |
|
| 441 |
retrieve_start = time.perf_counter()
|
| 442 |
-
chunks =
|
| 443 |
log.info("stage_timing", node="vector_retrieval", stage="postgres_retrieval",
|
| 444 |
duration_ms=round((time.perf_counter() - retrieve_start) * 1000, 2))
|
| 445 |
# Fix 4: section-ID-aware direct lookup.
|
|
@@ -478,7 +478,7 @@ def vector_retrieval_node(state: CivicSetuState) -> dict:
|
|
| 478 |
extra.append(fc)
|
| 479 |
return extra
|
| 480 |
|
| 481 |
-
direct_chunks =
|
| 482 |
if direct_chunks:
|
| 483 |
log.info("vector_section_direct_lookup",
|
| 484 |
section_ids=list(section_ids),
|
|
@@ -490,7 +490,7 @@ def vector_retrieval_node(state: CivicSetuState) -> dict:
|
|
| 490 |
# Section 5 timeline) — adding them ensures cross-Act context is available.
|
| 491 |
if jurisdiction and jurisdiction != Jurisdiction.CENTRAL:
|
| 492 |
central_start = time.perf_counter()
|
| 493 |
-
central_chunks =
|
| 494 |
seen = {str(c.chunk.chunk_id) for c in chunks}
|
| 495 |
extra_central = [c for c in central_chunks if str(c.chunk.chunk_id) not in seen]
|
| 496 |
if extra_central:
|
|
@@ -500,13 +500,13 @@ def vector_retrieval_node(state: CivicSetuState) -> dict:
|
|
| 500 |
log.info("stage_timing", node="vector_retrieval", stage="central_supplement",
|
| 501 |
duration_ms=round((time.perf_counter() - central_start) * 1000, 2))
|
| 502 |
|
| 503 |
-
chunks = _prepend_pinned_sections(state, chunks)
|
| 504 |
log.info("node_timing", node="vector_retrieval",
|
| 505 |
duration_ms=round((time.perf_counter() - node_start) * 1000, 2), results=len(chunks))
|
| 506 |
return {"retrieved_chunks": chunks}
|
| 507 |
|
| 508 |
|
| 509 |
-
def graph_retrieval_node(state: CivicSetuState) -> dict:
|
| 510 |
"""
|
| 511 |
Graph-based retrieval for cross_reference and temporal queries.
|
| 512 |
Traverses REFERENCES edges in Neo4j then hydrates chunks from pgvector.
|
|
@@ -530,7 +530,7 @@ def graph_retrieval_node(state: CivicSetuState) -> dict:
|
|
| 530 |
)
|
| 531 |
|
| 532 |
retrieve_start = time.perf_counter()
|
| 533 |
-
chunks =
|
| 534 |
log.info("graph_retrieval_complete", results=len(chunks))
|
| 535 |
log.info("stage_timing", node="graph_retrieval", stage="neo4j_postgres_hydration", duration_ms=round((time.perf_counter() - retrieve_start) * 1000, 2))
|
| 536 |
|
|
@@ -539,11 +539,11 @@ def graph_retrieval_node(state: CivicSetuState) -> dict:
|
|
| 539 |
if not chunks:
|
| 540 |
log.info("graph_retrieval_fallback_to_rrf", query=query[:80])
|
| 541 |
embed_start = time.perf_counter()
|
| 542 |
-
query_embedding =
|
| 543 |
log.info("stage_timing", node="graph_retrieval", stage="fallback_embedding", duration_ms=round((time.perf_counter() - embed_start) * 1000, 2))
|
| 544 |
|
| 545 |
fallback_start = time.perf_counter()
|
| 546 |
-
chunks =
|
| 547 |
log.info("stage_timing", node="graph_retrieval", stage="fallback_rrf_search", duration_ms=round((time.perf_counter() - fallback_start) * 1000, 2))
|
| 548 |
log.info("graph_fallback_rrf_results", count=len(chunks))
|
| 549 |
|
|
@@ -551,7 +551,7 @@ def graph_retrieval_node(state: CivicSetuState) -> dict:
|
|
| 551 |
# retry without jurisdiction filter to pick up Central Act chunks.
|
| 552 |
if not chunks and jurisdiction:
|
| 553 |
log.info("graph_retrieval_fallback_no_jurisdiction", query=query[:80])
|
| 554 |
-
chunks =
|
| 555 |
log.info("graph_fallback_no_jurisdiction_results", count=len(chunks))
|
| 556 |
# Fix 5: section-ID-aware direct lookup in graph retrieval path.
|
| 557 |
# Use original query when it has explicit sections (XREF: "Section 18 refund")
|
|
@@ -593,7 +593,7 @@ def graph_retrieval_node(state: CivicSetuState) -> dict:
|
|
| 593 |
extra.append(fc)
|
| 594 |
return extra
|
| 595 |
|
| 596 |
-
direct_chunks =
|
| 597 |
if direct_chunks:
|
| 598 |
log.info("graph_section_direct_lookup",
|
| 599 |
section_ids=list(section_ids),
|
|
@@ -607,8 +607,8 @@ def graph_retrieval_node(state: CivicSetuState) -> dict:
|
|
| 607 |
# only fires when the state family is EMPTY, missing this mismatch case.
|
| 608 |
if jurisdiction and jurisdiction != Jurisdiction.CENTRAL:
|
| 609 |
central_supplement_start = time.perf_counter()
|
| 610 |
-
_query_embedding =
|
| 611 |
-
central_chunks =
|
| 612 |
seen = {str(c.chunk.chunk_id) for c in chunks}
|
| 613 |
extra_central = [c for c in central_chunks if str(c.chunk.chunk_id) not in seen]
|
| 614 |
if extra_central:
|
|
@@ -617,7 +617,7 @@ def graph_retrieval_node(state: CivicSetuState) -> dict:
|
|
| 617 |
log.info("stage_timing", node="graph_retrieval", stage="central_supplement",
|
| 618 |
duration_ms=round((time.perf_counter() - central_supplement_start) * 1000, 2))
|
| 619 |
|
| 620 |
-
chunks = _prepend_pinned_sections(state, chunks)
|
| 621 |
_MAX_GRAPH_CHUNKS = 25
|
| 622 |
if len(chunks) > _MAX_GRAPH_CHUNKS:
|
| 623 |
log.warning(
|
|
@@ -633,7 +633,7 @@ def graph_retrieval_node(state: CivicSetuState) -> dict:
|
|
| 633 |
|
| 634 |
# ── Node 3: Reranker ───────────────────────────────────────────────────────────
|
| 635 |
|
| 636 |
-
def reranker_node(state: CivicSetuState) -> dict:
|
| 637 |
from civicsetu.retrieval.reranker import Reranker
|
| 638 |
|
| 639 |
chunks = state.get("retrieved_chunks", [])
|
|
@@ -645,7 +645,19 @@ def reranker_node(state: CivicSetuState) -> dict:
|
|
| 645 |
duration_ms=round((time.perf_counter() - node_start) * 1000, 2), reranked=0)
|
| 646 |
return {"reranked_chunks": []}
|
| 647 |
|
| 648 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 649 |
|
| 650 |
log.info("reranker_complete", reranked=len(reranked))
|
| 651 |
log.info("node_timing", node="reranker",
|
|
@@ -655,7 +667,7 @@ def reranker_node(state: CivicSetuState) -> dict:
|
|
| 655 |
|
| 656 |
# ── Node 4: Generator ──────────────────────────────────────────────────────────
|
| 657 |
|
| 658 |
-
def generator_node(state: CivicSetuState) -> dict:
|
| 659 |
"""
|
| 660 |
Generates a cited answer from reranked chunks.
|
| 661 |
Citations are anchored to cited_chunks indices returned by the LLM —
|
|
@@ -739,7 +751,7 @@ def generator_node(state: CivicSetuState) -> dict:
|
|
| 739 |
try:
|
| 740 |
llm_start = time.perf_counter()
|
| 741 |
log.info("generator_debug_start", query=query[:50], chunks=len(chunks), model=THINKING_MODELS[0])
|
| 742 |
-
raw = _llm_call(prompt, system, temperature=0.0, tier="thinking")
|
| 743 |
log.info("generator_debug_raw", raw_type=type(raw).__name__, raw_len=len(raw), raw_preview=raw[:200])
|
| 744 |
log.info("stage_timing", node="generator", stage="llm", duration_ms=round((time.perf_counter() - llm_start) * 1000, 2))
|
| 745 |
result = _extract_json_dict(raw)
|
|
@@ -846,7 +858,7 @@ def generator_node(state: CivicSetuState) -> dict:
|
|
| 846 |
|
| 847 |
# ── Node 5: Validator ──────────────────────────────────────────────────────────
|
| 848 |
|
| 849 |
-
def validator_node(state: CivicSetuState) -> dict:
|
| 850 |
answer = state.get("raw_response", "")
|
| 851 |
chunks = state.get("reranked_chunks", [])
|
| 852 |
node_start = time.perf_counter()
|
|
@@ -865,7 +877,7 @@ def validator_node(state: CivicSetuState) -> dict:
|
|
| 865 |
}
|
| 866 |
|
| 867 |
|
| 868 |
-
def hybrid_retrieval_node(state: CivicSetuState) -> dict:
|
| 869 |
"""
|
| 870 |
Used for conflict_detection queries.
|
| 871 |
Runs vector + graph retrieval in parallel, merges results.
|
|
@@ -881,7 +893,7 @@ def hybrid_retrieval_node(state: CivicSetuState) -> dict:
|
|
| 881 |
log.info("hybrid_retrieval_node", query=query[:80])
|
| 882 |
|
| 883 |
embed_start = time.perf_counter()
|
| 884 |
-
query_embedding =
|
| 885 |
log.info("stage_timing", node="hybrid_retrieval", stage="embedding", duration_ms=round((time.perf_counter() - embed_start) * 1000, 2))
|
| 886 |
|
| 887 |
async def _retrieve():
|
|
@@ -922,11 +934,11 @@ def hybrid_retrieval_node(state: CivicSetuState) -> dict:
|
|
| 922 |
return v_chunks, g_chunks, extra_central
|
| 923 |
|
| 924 |
retrieve_start = time.perf_counter()
|
| 925 |
-
v_chunks, g_chunks, extra_central =
|
| 926 |
log.info("stage_timing", node="hybrid_retrieval", stage="vector_graph_parallel", duration_ms=round((time.perf_counter() - retrieve_start) * 1000, 2))
|
| 927 |
|
| 928 |
all_chunks = v_chunks + g_chunks + extra_central
|
| 929 |
-
all_chunks = _prepend_pinned_sections(state, all_chunks)
|
| 930 |
log.info(
|
| 931 |
"hybrid_retrieval_complete",
|
| 932 |
vector_chunks=len(v_chunks),
|
|
|
|
| 64 |
FALLBACK_MODELS = THINKING_MODELS
|
| 65 |
|
| 66 |
|
| 67 |
+
async def turn_reset_node(state: CivicSetuState) -> dict:
|
| 68 |
"""
|
| 69 |
Clear per-turn fields while preserving session-scoped inputs and messages.
|
| 70 |
"""
|
|
|
|
| 85 |
}
|
| 86 |
|
| 87 |
|
| 88 |
+
async def _llm_call(prompt: str, system: str, temperature: float = 0.0, tier: str = "thinking") -> str:
|
| 89 |
"""
|
| 90 |
+
Call an LLM with fallback chain asynchronously.
|
| 91 |
|
| 92 |
Args:
|
| 93 |
tier: "thinking" for deep-reasoning tasks (generator),
|
|
|
|
| 146 |
if "osmapi.com" in os.environ.get("OPENAI_API_BASE", ""):
|
| 147 |
completion_kwargs["response_format"] = {"type": "json_object"}
|
| 148 |
|
| 149 |
+
response = await litellm.acompletion(**completion_kwargs)
|
| 150 |
|
| 151 |
duration_ms = round((time.perf_counter() - start) * 1000, 2)
|
| 152 |
content = response.choices[0].message.content
|
|
|
|
| 323 |
return sorted(family, key=score, reverse=True)
|
| 324 |
|
| 325 |
|
| 326 |
+
async def _prepend_pinned_sections(state: CivicSetuState, chunks: list[RetrievedChunk]) -> list[RetrievedChunk]:
|
| 327 |
pinned_refs = state.get("pinned_section_refs")
|
| 328 |
if not pinned_refs:
|
| 329 |
return chunks
|
|
|
|
| 346 |
c.is_pinned = True
|
| 347 |
existing_matches.append(c)
|
| 348 |
|
| 349 |
+
pinned = await _fetch_pinned_sections(
|
| 350 |
pinned_refs,
|
| 351 |
jurisdiction,
|
| 352 |
chunks,
|
| 353 |
hint,
|
| 354 |
+
)
|
| 355 |
promoted = existing_matches + pinned
|
| 356 |
if promoted:
|
| 357 |
promoted_ids = {str(c.chunk.chunk_id) for c in promoted}
|
|
|
|
| 367 |
|
| 368 |
# ── Node 1: Classifier ─────────────────────────────────────────────────────────
|
| 369 |
|
| 370 |
+
async def classifier_node(state: CivicSetuState) -> dict:
|
| 371 |
"""
|
| 372 |
Classifies query type and rewrites the query for better retrieval.
|
| 373 |
Returns: query_type, rewritten_query
|
|
|
|
| 392 |
)
|
| 393 |
|
| 394 |
try:
|
| 395 |
+
raw = await _llm_call(prompt, system, tier="fast")
|
| 396 |
result = _extract_json_dict(raw)
|
| 397 |
|
| 398 |
query_type_str = result.get("query_type", "fact_lookup")
|
|
|
|
| 423 |
return await VectorRetriever.retrieve(query, query_embedding, top_k, jurisdiction)
|
| 424 |
|
| 425 |
|
| 426 |
+
async def vector_retrieval_node(state: CivicSetuState) -> dict:
|
| 427 |
from civicsetu.retrieval.vector_retriever import VectorRetriever
|
| 428 |
|
| 429 |
query = state.get("rewritten_query") or state["query"]
|
|
|
|
| 434 |
log.info("vector_retrieval_node", query=query[:80], top_k=top_k)
|
| 435 |
|
| 436 |
embed_start = time.perf_counter()
|
| 437 |
+
query_embedding = await asyncio.to_thread(cached_embed, query)
|
| 438 |
log.info("stage_timing", node="vector_retrieval", stage="embedding",
|
| 439 |
duration_ms=round((time.perf_counter() - embed_start) * 1000, 2))
|
| 440 |
|
| 441 |
retrieve_start = time.perf_counter()
|
| 442 |
+
chunks = await VectorRetriever.retrieve(query, query_embedding, top_k, jurisdiction)
|
| 443 |
log.info("stage_timing", node="vector_retrieval", stage="postgres_retrieval",
|
| 444 |
duration_ms=round((time.perf_counter() - retrieve_start) * 1000, 2))
|
| 445 |
# Fix 4: section-ID-aware direct lookup.
|
|
|
|
| 478 |
extra.append(fc)
|
| 479 |
return extra
|
| 480 |
|
| 481 |
+
direct_chunks = await _fetch_sections()
|
| 482 |
if direct_chunks:
|
| 483 |
log.info("vector_section_direct_lookup",
|
| 484 |
section_ids=list(section_ids),
|
|
|
|
| 490 |
# Section 5 timeline) — adding them ensures cross-Act context is available.
|
| 491 |
if jurisdiction and jurisdiction != Jurisdiction.CENTRAL:
|
| 492 |
central_start = time.perf_counter()
|
| 493 |
+
central_chunks = await VectorRetriever.retrieve(query, query_embedding, top_k, Jurisdiction.CENTRAL)
|
| 494 |
seen = {str(c.chunk.chunk_id) for c in chunks}
|
| 495 |
extra_central = [c for c in central_chunks if str(c.chunk.chunk_id) not in seen]
|
| 496 |
if extra_central:
|
|
|
|
| 500 |
log.info("stage_timing", node="vector_retrieval", stage="central_supplement",
|
| 501 |
duration_ms=round((time.perf_counter() - central_start) * 1000, 2))
|
| 502 |
|
| 503 |
+
chunks = await _prepend_pinned_sections(state, chunks)
|
| 504 |
log.info("node_timing", node="vector_retrieval",
|
| 505 |
duration_ms=round((time.perf_counter() - node_start) * 1000, 2), results=len(chunks))
|
| 506 |
return {"retrieved_chunks": chunks}
|
| 507 |
|
| 508 |
|
| 509 |
+
async def graph_retrieval_node(state: CivicSetuState) -> dict:
|
| 510 |
"""
|
| 511 |
Graph-based retrieval for cross_reference and temporal queries.
|
| 512 |
Traverses REFERENCES edges in Neo4j then hydrates chunks from pgvector.
|
|
|
|
| 530 |
)
|
| 531 |
|
| 532 |
retrieve_start = time.perf_counter()
|
| 533 |
+
chunks = await _retrieve()
|
| 534 |
log.info("graph_retrieval_complete", results=len(chunks))
|
| 535 |
log.info("stage_timing", node="graph_retrieval", stage="neo4j_postgres_hydration", duration_ms=round((time.perf_counter() - retrieve_start) * 1000, 2))
|
| 536 |
|
|
|
|
| 539 |
if not chunks:
|
| 540 |
log.info("graph_retrieval_fallback_to_rrf", query=query[:80])
|
| 541 |
embed_start = time.perf_counter()
|
| 542 |
+
query_embedding = await asyncio.to_thread(cached_embed, query)
|
| 543 |
log.info("stage_timing", node="graph_retrieval", stage="fallback_embedding", duration_ms=round((time.perf_counter() - embed_start) * 1000, 2))
|
| 544 |
|
| 545 |
fallback_start = time.perf_counter()
|
| 546 |
+
chunks = await _rrf_retrieve(query, query_embedding, top_k, jurisdiction)
|
| 547 |
log.info("stage_timing", node="graph_retrieval", stage="fallback_rrf_search", duration_ms=round((time.perf_counter() - fallback_start) * 1000, 2))
|
| 548 |
log.info("graph_fallback_rrf_results", count=len(chunks))
|
| 549 |
|
|
|
|
| 551 |
# retry without jurisdiction filter to pick up Central Act chunks.
|
| 552 |
if not chunks and jurisdiction:
|
| 553 |
log.info("graph_retrieval_fallback_no_jurisdiction", query=query[:80])
|
| 554 |
+
chunks = await _rrf_retrieve(query, query_embedding, top_k, None)
|
| 555 |
log.info("graph_fallback_no_jurisdiction_results", count=len(chunks))
|
| 556 |
# Fix 5: section-ID-aware direct lookup in graph retrieval path.
|
| 557 |
# Use original query when it has explicit sections (XREF: "Section 18 refund")
|
|
|
|
| 593 |
extra.append(fc)
|
| 594 |
return extra
|
| 595 |
|
| 596 |
+
direct_chunks = await _fetch_sections_graph()
|
| 597 |
if direct_chunks:
|
| 598 |
log.info("graph_section_direct_lookup",
|
| 599 |
section_ids=list(section_ids),
|
|
|
|
| 607 |
# only fires when the state family is EMPTY, missing this mismatch case.
|
| 608 |
if jurisdiction and jurisdiction != Jurisdiction.CENTRAL:
|
| 609 |
central_supplement_start = time.perf_counter()
|
| 610 |
+
_query_embedding = await asyncio.to_thread(cached_embed, query)
|
| 611 |
+
central_chunks = await _VR.retrieve(query, _query_embedding, top_k, Jurisdiction.CENTRAL)
|
| 612 |
seen = {str(c.chunk.chunk_id) for c in chunks}
|
| 613 |
extra_central = [c for c in central_chunks if str(c.chunk.chunk_id) not in seen]
|
| 614 |
if extra_central:
|
|
|
|
| 617 |
log.info("stage_timing", node="graph_retrieval", stage="central_supplement",
|
| 618 |
duration_ms=round((time.perf_counter() - central_supplement_start) * 1000, 2))
|
| 619 |
|
| 620 |
+
chunks = await _prepend_pinned_sections(state, chunks)
|
| 621 |
_MAX_GRAPH_CHUNKS = 25
|
| 622 |
if len(chunks) > _MAX_GRAPH_CHUNKS:
|
| 623 |
log.warning(
|
|
|
|
| 633 |
|
| 634 |
# ── Node 3: Reranker ───────────────────────────────────────────────────────────
|
| 635 |
|
| 636 |
+
async def reranker_node(state: CivicSetuState) -> dict:
|
| 637 |
from civicsetu.retrieval.reranker import Reranker
|
| 638 |
|
| 639 |
chunks = state.get("retrieved_chunks", [])
|
|
|
|
| 645 |
duration_ms=round((time.perf_counter() - node_start) * 1000, 2), reranked=0)
|
| 646 |
return {"reranked_chunks": []}
|
| 647 |
|
| 648 |
+
# Evaluation context trimming: if pinned sections are provided,
|
| 649 |
+
# ONLY keep chunks marked as pinned (matched the target sections).
|
| 650 |
+
# Prevents the LLM from picking up nearby sections that happen to share
|
| 651 |
+
# keywords but aren't relevant to the specific subclause being tested.
|
| 652 |
+
pinned_refs = state.get("pinned_section_refs")
|
| 653 |
+
if pinned_refs:
|
| 654 |
+
filtered = [rc for rc in chunks if rc.is_pinned]
|
| 655 |
+
if filtered:
|
| 656 |
+
log.info("reranker_eval_context_trimmed", before=len(chunks), after=len(filtered))
|
| 657 |
+
chunks = filtered
|
| 658 |
+
|
| 659 |
+
# Reranker is cpu-bound/blocking, but run in thread to keep event loop free
|
| 660 |
+
reranked = await asyncio.to_thread(Reranker.rerank, chunks, query)
|
| 661 |
|
| 662 |
log.info("reranker_complete", reranked=len(reranked))
|
| 663 |
log.info("node_timing", node="reranker",
|
|
|
|
| 667 |
|
| 668 |
# ── Node 4: Generator ──────────────────────────────────────────────────────────
|
| 669 |
|
| 670 |
+
async def generator_node(state: CivicSetuState) -> dict:
|
| 671 |
"""
|
| 672 |
Generates a cited answer from reranked chunks.
|
| 673 |
Citations are anchored to cited_chunks indices returned by the LLM —
|
|
|
|
| 751 |
try:
|
| 752 |
llm_start = time.perf_counter()
|
| 753 |
log.info("generator_debug_start", query=query[:50], chunks=len(chunks), model=THINKING_MODELS[0])
|
| 754 |
+
raw = await _llm_call(prompt, system, temperature=0.0, tier="thinking")
|
| 755 |
log.info("generator_debug_raw", raw_type=type(raw).__name__, raw_len=len(raw), raw_preview=raw[:200])
|
| 756 |
log.info("stage_timing", node="generator", stage="llm", duration_ms=round((time.perf_counter() - llm_start) * 1000, 2))
|
| 757 |
result = _extract_json_dict(raw)
|
|
|
|
| 858 |
|
| 859 |
# ── Node 5: Validator ──────────────────────────────────────────────────────────
|
| 860 |
|
| 861 |
+
async def validator_node(state: CivicSetuState) -> dict:
|
| 862 |
answer = state.get("raw_response", "")
|
| 863 |
chunks = state.get("reranked_chunks", [])
|
| 864 |
node_start = time.perf_counter()
|
|
|
|
| 877 |
}
|
| 878 |
|
| 879 |
|
| 880 |
+
async def hybrid_retrieval_node(state: CivicSetuState) -> dict:
|
| 881 |
"""
|
| 882 |
Used for conflict_detection queries.
|
| 883 |
Runs vector + graph retrieval in parallel, merges results.
|
|
|
|
| 893 |
log.info("hybrid_retrieval_node", query=query[:80])
|
| 894 |
|
| 895 |
embed_start = time.perf_counter()
|
| 896 |
+
query_embedding = await asyncio.to_thread(cached_embed, query)
|
| 897 |
log.info("stage_timing", node="hybrid_retrieval", stage="embedding", duration_ms=round((time.perf_counter() - embed_start) * 1000, 2))
|
| 898 |
|
| 899 |
async def _retrieve():
|
|
|
|
| 934 |
return v_chunks, g_chunks, extra_central
|
| 935 |
|
| 936 |
retrieve_start = time.perf_counter()
|
| 937 |
+
v_chunks, g_chunks, extra_central = await _retrieve()
|
| 938 |
log.info("stage_timing", node="hybrid_retrieval", stage="vector_graph_parallel", duration_ms=round((time.perf_counter() - retrieve_start) * 1000, 2))
|
| 939 |
|
| 940 |
all_chunks = v_chunks + g_chunks + extra_central
|
| 941 |
+
all_chunks = await _prepend_pinned_sections(state, all_chunks)
|
| 942 |
log.info(
|
| 943 |
"hybrid_retrieval_complete",
|
| 944 |
vector_chunks=len(v_chunks),
|
src/civicsetu/api/routes/graph.py
CHANGED
|
@@ -241,7 +241,7 @@ async def section_context_query(
|
|
| 241 |
|
| 242 |
try:
|
| 243 |
invoke_start = time.perf_counter()
|
| 244 |
-
result = await
|
| 245 |
log.info(
|
| 246 |
"graph_invoke_complete",
|
| 247 |
route="section_context",
|
|
@@ -255,8 +255,7 @@ async def section_context_query(
|
|
| 255 |
if raw_response:
|
| 256 |
try:
|
| 257 |
update_start = time.perf_counter()
|
| 258 |
-
await
|
| 259 |
-
graph.update_state,
|
| 260 |
config,
|
| 261 |
{"messages": [ChatMessage(role="assistant", content=raw_response)]},
|
| 262 |
)
|
|
|
|
| 241 |
|
| 242 |
try:
|
| 243 |
invoke_start = time.perf_counter()
|
| 244 |
+
result = await graph.ainvoke(initial_state, config)
|
| 245 |
log.info(
|
| 246 |
"graph_invoke_complete",
|
| 247 |
route="section_context",
|
|
|
|
| 255 |
if raw_response:
|
| 256 |
try:
|
| 257 |
update_start = time.perf_counter()
|
| 258 |
+
await graph.aupdate_state(
|
|
|
|
| 259 |
config,
|
| 260 |
{"messages": [ChatMessage(role="assistant", content=raw_response)]},
|
| 261 |
)
|
src/civicsetu/api/routes/query.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
-
import asyncio
|
| 4 |
import time
|
| 5 |
import uuid
|
| 6 |
|
|
@@ -9,7 +8,7 @@ from fastapi import APIRouter, HTTPException, Request
|
|
| 9 |
|
| 10 |
from civicsetu.guardrails.input_guard import InputGuard
|
| 11 |
from civicsetu.guardrails.output_guard import OutputGuard
|
| 12 |
-
from civicsetu.models.schemas import
|
| 13 |
|
| 14 |
log = structlog.get_logger(__name__)
|
| 15 |
router = APIRouter()
|
|
@@ -49,7 +48,7 @@ async def query_endpoint(request: Request, body: QueryRequest):
|
|
| 49 |
|
| 50 |
try:
|
| 51 |
invoke_start = time.perf_counter()
|
| 52 |
-
result = await
|
| 53 |
log.info(
|
| 54 |
"graph_invoke_complete",
|
| 55 |
route="query",
|
|
@@ -63,8 +62,7 @@ async def query_endpoint(request: Request, body: QueryRequest):
|
|
| 63 |
if raw_response:
|
| 64 |
try:
|
| 65 |
update_start = time.perf_counter()
|
| 66 |
-
await
|
| 67 |
-
graph.update_state,
|
| 68 |
config,
|
| 69 |
{"messages": [ChatMessage(role="assistant", content=raw_response)]},
|
| 70 |
)
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
|
|
|
| 3 |
import time
|
| 4 |
import uuid
|
| 5 |
|
|
|
|
| 8 |
|
| 9 |
from civicsetu.guardrails.input_guard import InputGuard
|
| 10 |
from civicsetu.guardrails.output_guard import OutputGuard
|
| 11 |
+
from civicsetu.models.schemas import ChatMessage, CivicSetuResponse, InsufficientInfoResponse, QueryRequest
|
| 12 |
|
| 13 |
log = structlog.get_logger(__name__)
|
| 14 |
router = APIRouter()
|
|
|
|
| 48 |
|
| 49 |
try:
|
| 50 |
invoke_start = time.perf_counter()
|
| 51 |
+
result = await graph.ainvoke(initial_state, config)
|
| 52 |
log.info(
|
| 53 |
"graph_invoke_complete",
|
| 54 |
route="query",
|
|
|
|
| 62 |
if raw_response:
|
| 63 |
try:
|
| 64 |
update_start = time.perf_counter()
|
| 65 |
+
await graph.aupdate_state(
|
|
|
|
| 66 |
config,
|
| 67 |
{"messages": [ChatMessage(role="assistant", content=raw_response)]},
|
| 68 |
)
|
src/civicsetu/stores/graph_store.py
CHANGED
|
@@ -14,6 +14,14 @@ log = structlog.get_logger(__name__)
|
|
| 14 |
_driver: AsyncDriver | None = None
|
| 15 |
_driver_lock: asyncio.Lock = asyncio.Lock()
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
async def _get_driver() -> AsyncDriver:
|
| 19 |
"""Cached driver with lazy init — avoids 25+ driver creations per graph retrieval."""
|
|
@@ -41,6 +49,39 @@ async def close_driver() -> None:
|
|
| 41 |
_driver = None
|
| 42 |
log.info("neo4j_driver_closed")
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
async def get_driver() -> AsyncDriver:
|
| 45 |
"""Public alias for main.py lifespan — warms the singleton at startup."""
|
| 46 |
return await _get_driver()
|
|
@@ -283,90 +324,98 @@ class GraphStore:
|
|
| 283 |
jurisdiction: str,
|
| 284 |
depth: int = 1,
|
| 285 |
) -> list[dict]:
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
|
|
|
|
|
|
| 302 |
|
| 303 |
@staticmethod
|
| 304 |
async def get_sections_referencing(
|
| 305 |
section_id: str,
|
| 306 |
jurisdiction: str,
|
| 307 |
) -> list[dict]:
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
|
|
|
|
|
|
| 324 |
|
| 325 |
@staticmethod
|
| 326 |
async def get_derived_act_sections(
|
| 327 |
rule_section_id: str,
|
| 328 |
rule_jurisdiction: str,
|
| 329 |
) -> list[dict]:
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
|
|
|
|
|
|
| 347 |
|
| 348 |
@staticmethod
|
| 349 |
async def get_deriving_rule_sections(
|
| 350 |
act_section_id: str,
|
| 351 |
act_jurisdiction: str = "CENTRAL",
|
| 352 |
) -> list[dict]:
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
|
|
|
|
|
|
| 370 |
|
| 371 |
@staticmethod
|
| 372 |
async def get_sections_for_document(doc_id: str) -> list[dict]:
|
|
@@ -386,50 +435,54 @@ class GraphStore:
|
|
| 386 |
|
| 387 |
@staticmethod
|
| 388 |
async def get_topology() -> tuple[list[dict], list[dict]]:
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
|
|
|
|
|
|
| 418 |
|
| 419 |
@staticmethod
|
| 420 |
async def graph_stats() -> dict:
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
|
|
|
|
|
|
| 435 |
|
|
|
|
| 14 |
_driver: AsyncDriver | None = None
|
| 15 |
_driver_lock: asyncio.Lock = asyncio.Lock()
|
| 16 |
|
| 17 |
+
_TRANSIENT_CONNECTION_MARKERS = (
|
| 18 |
+
"defunct connection",
|
| 19 |
+
"the connection is closed",
|
| 20 |
+
"connection reset by peer",
|
| 21 |
+
"unable to retrieve routing information",
|
| 22 |
+
"service unavailable",
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
|
| 26 |
async def _get_driver() -> AsyncDriver:
|
| 27 |
"""Cached driver with lazy init — avoids 25+ driver creations per graph retrieval."""
|
|
|
|
| 49 |
_driver = None
|
| 50 |
log.info("neo4j_driver_closed")
|
| 51 |
|
| 52 |
+
|
| 53 |
+
async def _reset_driver() -> None:
|
| 54 |
+
global _driver
|
| 55 |
+
async with _driver_lock:
|
| 56 |
+
if _driver is not None:
|
| 57 |
+
try:
|
| 58 |
+
await _driver.close()
|
| 59 |
+
finally:
|
| 60 |
+
_driver = None
|
| 61 |
+
log.info("neo4j_driver_reset")
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _is_transient_connection_error(error: Exception) -> bool:
|
| 65 |
+
message = str(error).lower()
|
| 66 |
+
return any(marker in message for marker in _TRANSIENT_CONNECTION_MARKERS)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
async def _with_reconnected_driver(operation):
|
| 70 |
+
try:
|
| 71 |
+
driver = await _get_driver()
|
| 72 |
+
return await operation(driver)
|
| 73 |
+
except Exception as first_error:
|
| 74 |
+
if not _is_transient_connection_error(first_error):
|
| 75 |
+
raise
|
| 76 |
+
|
| 77 |
+
log.warning("neo4j_driver_reconnecting", error=str(first_error))
|
| 78 |
+
await _reset_driver()
|
| 79 |
+
driver = await _get_driver()
|
| 80 |
+
try:
|
| 81 |
+
return await operation(driver)
|
| 82 |
+
except Exception:
|
| 83 |
+
raise first_error
|
| 84 |
+
|
| 85 |
async def get_driver() -> AsyncDriver:
|
| 86 |
"""Public alias for main.py lifespan — warms the singleton at startup."""
|
| 87 |
return await _get_driver()
|
|
|
|
| 324 |
jurisdiction: str,
|
| 325 |
depth: int = 1,
|
| 326 |
) -> list[dict]:
|
| 327 |
+
async def _op(driver: AsyncDriver) -> list[dict]:
|
| 328 |
+
async with driver.session() as session:
|
| 329 |
+
result = await session.run(
|
| 330 |
+
f"""
|
| 331 |
+
MATCH (src:Section {{section_id: $section_id, jurisdiction: $jurisdiction}})
|
| 332 |
+
-[:REFERENCES*1..{depth}]->(tgt:Section)
|
| 333 |
+
WHERE tgt.is_active = true
|
| 334 |
+
RETURN DISTINCT tgt.section_id AS section_id,
|
| 335 |
+
tgt.title AS title,
|
| 336 |
+
tgt.chunk_id AS chunk_id,
|
| 337 |
+
tgt.jurisdiction AS jurisdiction
|
| 338 |
+
""",
|
| 339 |
+
section_id=section_id,
|
| 340 |
+
jurisdiction=jurisdiction,
|
| 341 |
+
)
|
| 342 |
+
return await result.data()
|
| 343 |
+
|
| 344 |
+
return await _with_reconnected_driver(_op)
|
| 345 |
|
| 346 |
@staticmethod
|
| 347 |
async def get_sections_referencing(
|
| 348 |
section_id: str,
|
| 349 |
jurisdiction: str,
|
| 350 |
) -> list[dict]:
|
| 351 |
+
async def _op(driver: AsyncDriver) -> list[dict]:
|
| 352 |
+
async with driver.session() as session:
|
| 353 |
+
result = await session.run(
|
| 354 |
+
"""
|
| 355 |
+
MATCH (src:Section)-[:REFERENCES]->
|
| 356 |
+
(tgt:Section {section_id: $section_id, jurisdiction: $jurisdiction})
|
| 357 |
+
WHERE src.is_active = true
|
| 358 |
+
RETURN DISTINCT src.section_id AS section_id,
|
| 359 |
+
src.title AS title,
|
| 360 |
+
src.chunk_id AS chunk_id,
|
| 361 |
+
src.jurisdiction AS jurisdiction
|
| 362 |
+
""",
|
| 363 |
+
section_id=section_id,
|
| 364 |
+
jurisdiction=jurisdiction,
|
| 365 |
+
)
|
| 366 |
+
return await result.data()
|
| 367 |
+
|
| 368 |
+
return await _with_reconnected_driver(_op)
|
| 369 |
|
| 370 |
@staticmethod
|
| 371 |
async def get_derived_act_sections(
|
| 372 |
rule_section_id: str,
|
| 373 |
rule_jurisdiction: str,
|
| 374 |
) -> list[dict]:
|
| 375 |
+
async def _op(driver: AsyncDriver) -> list[dict]:
|
| 376 |
+
async with driver.session() as session:
|
| 377 |
+
result = await session.run(
|
| 378 |
+
"""
|
| 379 |
+
MATCH (rule_sec:Section {section_id: $section_id, jurisdiction: $jurisdiction})
|
| 380 |
+
-[:DERIVED_FROM]->(act_sec:Section)
|
| 381 |
+
WHERE act_sec.is_active = true
|
| 382 |
+
RETURN DISTINCT act_sec.section_id AS section_id,
|
| 383 |
+
act_sec.title AS title,
|
| 384 |
+
act_sec.chunk_id AS chunk_id,
|
| 385 |
+
act_sec.jurisdiction AS jurisdiction,
|
| 386 |
+
act_sec.doc_name AS doc_name
|
| 387 |
+
""",
|
| 388 |
+
section_id=rule_section_id,
|
| 389 |
+
jurisdiction=rule_jurisdiction,
|
| 390 |
+
)
|
| 391 |
+
return await result.data()
|
| 392 |
+
|
| 393 |
+
return await _with_reconnected_driver(_op)
|
| 394 |
|
| 395 |
@staticmethod
|
| 396 |
async def get_deriving_rule_sections(
|
| 397 |
act_section_id: str,
|
| 398 |
act_jurisdiction: str = "CENTRAL",
|
| 399 |
) -> list[dict]:
|
| 400 |
+
async def _op(driver: AsyncDriver) -> list[dict]:
|
| 401 |
+
async with driver.session() as session:
|
| 402 |
+
result = await session.run(
|
| 403 |
+
"""
|
| 404 |
+
MATCH (rule_sec:Section)-[:DERIVED_FROM]->
|
| 405 |
+
(act_sec:Section {section_id: $section_id, jurisdiction: $jurisdiction})
|
| 406 |
+
WHERE rule_sec.is_active = true
|
| 407 |
+
RETURN DISTINCT rule_sec.section_id AS section_id,
|
| 408 |
+
rule_sec.title AS title,
|
| 409 |
+
rule_sec.chunk_id AS chunk_id,
|
| 410 |
+
rule_sec.jurisdiction AS jurisdiction,
|
| 411 |
+
rule_sec.doc_name AS doc_name
|
| 412 |
+
""",
|
| 413 |
+
section_id=act_section_id,
|
| 414 |
+
jurisdiction=act_jurisdiction,
|
| 415 |
+
)
|
| 416 |
+
return await result.data()
|
| 417 |
+
|
| 418 |
+
return await _with_reconnected_driver(_op)
|
| 419 |
|
| 420 |
@staticmethod
|
| 421 |
async def get_sections_for_document(doc_id: str) -> list[dict]:
|
|
|
|
| 435 |
|
| 436 |
@staticmethod
|
| 437 |
async def get_topology() -> tuple[list[dict], list[dict]]:
|
| 438 |
+
async def _op(driver: AsyncDriver) -> tuple[list[dict], list[dict]]:
|
| 439 |
+
async with driver.session() as session:
|
| 440 |
+
edges_result = await session.run(
|
| 441 |
+
"""
|
| 442 |
+
MATCH (s:Section)-[r]->(t:Section)
|
| 443 |
+
WHERE type(r) IN ['REFERENCES', 'DERIVED_FROM']
|
| 444 |
+
AND s.is_active = true AND t.is_active = true
|
| 445 |
+
RETURN s.chunk_id AS source, t.chunk_id AS target, type(r) AS edge_type
|
| 446 |
+
"""
|
| 447 |
+
)
|
| 448 |
+
nodes_result = await session.run(
|
| 449 |
+
"""
|
| 450 |
+
MATCH (s:Section)-[r]-()
|
| 451 |
+
WHERE type(r) IN ['REFERENCES', 'DERIVED_FROM']
|
| 452 |
+
AND s.is_active = true
|
| 453 |
+
WITH s, count(r) AS conn_count
|
| 454 |
+
RETURN DISTINCT
|
| 455 |
+
s.chunk_id AS chunk_id,
|
| 456 |
+
s.section_id AS section_id,
|
| 457 |
+
s.title AS title,
|
| 458 |
+
s.jurisdiction AS jurisdiction,
|
| 459 |
+
s.doc_name AS doc_name,
|
| 460 |
+
s.is_active AS is_active,
|
| 461 |
+
conn_count AS connection_count
|
| 462 |
+
"""
|
| 463 |
+
)
|
| 464 |
+
edges = await edges_result.data()
|
| 465 |
+
nodes = await nodes_result.data()
|
| 466 |
+
return nodes, edges
|
| 467 |
+
|
| 468 |
+
return await _with_reconnected_driver(_op)
|
| 469 |
|
| 470 |
@staticmethod
|
| 471 |
async def graph_stats() -> dict:
|
| 472 |
+
async def _op(driver: AsyncDriver) -> dict:
|
| 473 |
+
async with driver.session() as session:
|
| 474 |
+
result = await session.run(
|
| 475 |
+
"""
|
| 476 |
+
RETURN
|
| 477 |
+
count { MATCH (d:Document) RETURN d } AS docs,
|
| 478 |
+
count { MATCH (s:Section) RETURN s } AS sections,
|
| 479 |
+
count { MATCH ()-[:REFERENCES]->() RETURN 1 } AS refs,
|
| 480 |
+
count { MATCH ()-[:HAS_SECTION]->() RETURN 1 } AS has_sec,
|
| 481 |
+
count { MATCH ()-[:DERIVED_FROM]->() RETURN 1 } AS derived_from
|
| 482 |
+
"""
|
| 483 |
+
)
|
| 484 |
+
record = await result.single()
|
| 485 |
+
return dict(record) if record else {}
|
| 486 |
+
|
| 487 |
+
return await _with_reconnected_driver(_op)
|
| 488 |
|
tests/unit/agent/test_nodes.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
|
|
|
| 3 |
import json
|
| 4 |
-
import uuid
|
| 5 |
from unittest.mock import MagicMock, patch
|
| 6 |
|
| 7 |
import pytest
|
|
@@ -14,9 +14,9 @@ def test_reranker_settings_defaults():
|
|
| 14 |
"""Settings ship with safe, sensible defaults — no .env needed."""
|
| 15 |
from civicsetu.config.settings import Settings
|
| 16 |
s = Settings()
|
| 17 |
-
assert s.reranker_model == "
|
| 18 |
assert s.reranker_score_threshold == 0.05
|
| 19 |
-
assert s.reranker_score_gap == 0.
|
| 20 |
|
| 21 |
|
| 22 |
def test_reranker_settings_env_override(monkeypatch):
|
|
@@ -51,13 +51,15 @@ def test_neo4j_username_env_alias_override(monkeypatch):
|
|
| 51 |
get_settings.cache_clear()
|
| 52 |
|
| 53 |
|
| 54 |
-
|
|
|
|
| 55 |
from civicsetu.agent.nodes import reranker_node
|
| 56 |
-
result = reranker_node(_base_state(retrieved_chunks=[], reranked_chunks=[]))
|
| 57 |
assert result["reranked_chunks"] == []
|
| 58 |
|
| 59 |
|
| 60 |
-
|
|
|
|
| 61 |
from civicsetu.agent.nodes import reranker_node
|
| 62 |
|
| 63 |
pinned = _make_rc(section_id="18", is_pinned=True)
|
|
@@ -74,14 +76,15 @@ def test_reranker_pinned_chunks_always_first():
|
|
| 74 |
reranked_chunks=[],
|
| 75 |
query="test query",
|
| 76 |
)
|
| 77 |
-
result = reranker_node(state)
|
| 78 |
|
| 79 |
reranked = result["reranked_chunks"]
|
| 80 |
assert reranked[0].is_pinned is True
|
| 81 |
assert reranked[0].chunk.section_id == "18"
|
| 82 |
|
| 83 |
|
| 84 |
-
|
|
|
|
| 85 |
from civicsetu.agent.nodes import reranker_node
|
| 86 |
|
| 87 |
pinned_chunks = [_make_rc(section_id=str(i), is_pinned=True) for i in range(4)]
|
|
@@ -95,13 +98,14 @@ def test_reranker_keeps_pinned_chunks_up_to_context_limit():
|
|
| 95 |
reranked_chunks=[],
|
| 96 |
query="test query",
|
| 97 |
)
|
| 98 |
-
result = reranker_node(state)
|
| 99 |
|
| 100 |
pinned_in_result = [c for c in result["reranked_chunks"] if c.is_pinned]
|
| 101 |
assert len(pinned_in_result) == 4
|
| 102 |
|
| 103 |
|
| 104 |
-
|
|
|
|
| 105 |
from civicsetu.agent.nodes import reranker_node
|
| 106 |
|
| 107 |
chunk = _make_rc(section_id="18")
|
|
@@ -118,7 +122,7 @@ def test_reranker_deduplicates_by_chunk_id():
|
|
| 118 |
reranked_chunks=[],
|
| 119 |
query="test query",
|
| 120 |
)
|
| 121 |
-
result = reranker_node(state)
|
| 122 |
|
| 123 |
all_ids = [str(c.chunk.chunk_id) for c in result["reranked_chunks"]]
|
| 124 |
assert len(all_ids) == len(set(all_ids))
|
|
@@ -178,7 +182,8 @@ def test_pin_relevance_prefers_specific_subclauses_over_base_header_on_tie():
|
|
| 178 |
assert [c.chunk.section_id for c in ranked] == ["7(6)", "7"]
|
| 179 |
|
| 180 |
|
| 181 |
-
|
|
|
|
| 182 |
from civicsetu.agent.nodes import _prepend_pinned_sections
|
| 183 |
from civicsetu.models.enums import Jurisdiction
|
| 184 |
|
|
@@ -186,18 +191,17 @@ def test_prepend_pinned_sections_promotes_existing_matches(monkeypatch):
|
|
| 186 |
rule_5 = _make_rc(section_id="5", jurisdiction=Jurisdiction.KARNATAKA)
|
| 187 |
section_4 = _make_rc(section_id="4(14)", jurisdiction=Jurisdiction.CENTRAL)
|
| 188 |
|
| 189 |
-
def
|
| 190 |
-
coro.close()
|
| 191 |
return []
|
| 192 |
|
| 193 |
-
monkeypatch.setattr("civicsetu.agent.nodes.
|
| 194 |
|
| 195 |
-
result = _prepend_pinned_sections(
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
[noisy, rule_5, section_4],
|
| 202 |
)
|
| 203 |
|
|
@@ -206,7 +210,8 @@ def test_prepend_pinned_sections_promotes_existing_matches(monkeypatch):
|
|
| 206 |
assert result[1].is_pinned is True
|
| 207 |
|
| 208 |
|
| 209 |
-
|
|
|
|
| 210 |
from civicsetu.agent.nodes import reranker_node
|
| 211 |
from civicsetu.models.enums import Jurisdiction
|
| 212 |
|
|
@@ -215,11 +220,12 @@ def test_reranker_node_trims_eval_context_to_pinned_families():
|
|
| 215 |
noisy_central = _make_rc(section_id="4(10)", jurisdiction=Jurisdiction.CENTRAL)
|
| 216 |
noisy_state = _make_rc(section_id="3", jurisdiction=Jurisdiction.UTTAR_PRADESH)
|
| 217 |
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
|
|
|
| 223 |
_base_state(
|
| 224 |
query="Which Karnataka rule implements project account maintenance?",
|
| 225 |
rewritten_query="Which Karnataka rule implements project account maintenance?",
|
|
@@ -244,7 +250,8 @@ def _llm_response(**kwargs) -> str:
|
|
| 244 |
return json.dumps(payload)
|
| 245 |
|
| 246 |
|
| 247 |
-
|
|
|
|
| 248 |
from civicsetu.agent.nodes import generator_node
|
| 249 |
|
| 250 |
chunks = [_make_rc(section_id=str(i)) for i in range(1, 6)]
|
|
@@ -256,14 +263,15 @@ def test_generator_cites_only_referenced_chunks():
|
|
| 256 |
query="What does Section 1 say?",
|
| 257 |
reranked_chunks=chunks,
|
| 258 |
)
|
| 259 |
-
result = generator_node(state)
|
| 260 |
|
| 261 |
assert len(result["citations"]) == 2
|
| 262 |
cited_ids = {c.section_id for c in result["citations"]}
|
| 263 |
assert cited_ids == {"1", "3"}
|
| 264 |
|
| 265 |
|
| 266 |
-
|
|
|
|
| 267 |
from civicsetu.agent.nodes import generator_node
|
| 268 |
|
| 269 |
chunks = [_make_rc(section_id=str(i)) for i in range(1, 4)]
|
|
@@ -271,13 +279,14 @@ def test_generator_fallback_when_cited_chunks_empty():
|
|
| 271 |
|
| 272 |
with patch("civicsetu.agent.nodes._llm_call", return_value=llm_out):
|
| 273 |
state = _base_state(reranked_chunks=chunks)
|
| 274 |
-
result = generator_node(state)
|
| 275 |
|
| 276 |
# Fallback: all 3 chunks cited
|
| 277 |
assert len(result["citations"]) == 3
|
| 278 |
|
| 279 |
|
| 280 |
-
|
|
|
|
| 281 |
from civicsetu.agent.nodes import generator_node
|
| 282 |
|
| 283 |
chunks = [_make_rc(section_id=str(i)) for i in range(1, 4)]
|
|
@@ -286,48 +295,52 @@ def test_generator_filters_invalid_indices():
|
|
| 286 |
|
| 287 |
with patch("civicsetu.agent.nodes._llm_call", return_value=llm_out):
|
| 288 |
state = _base_state(reranked_chunks=chunks)
|
| 289 |
-
result = generator_node(state)
|
| 290 |
|
| 291 |
# Only index 2 is valid → 1 citation
|
| 292 |
assert len(result["citations"]) == 1
|
| 293 |
assert result["citations"][0].section_id == "2"
|
| 294 |
|
| 295 |
|
| 296 |
-
|
|
|
|
| 297 |
from civicsetu.agent.nodes import generator_node
|
| 298 |
|
| 299 |
-
result = generator_node(_base_state(reranked_chunks=[]))
|
| 300 |
assert result["citations"] == []
|
| 301 |
assert result["confidence_score"] == 0.0
|
| 302 |
|
| 303 |
|
| 304 |
-
|
|
|
|
| 305 |
from civicsetu.agent.nodes import generator_node
|
| 306 |
|
| 307 |
chunks = [_make_rc(section_id="18")]
|
| 308 |
with patch("civicsetu.agent.nodes._llm_call", return_value="not json {{{{"):
|
| 309 |
state = _base_state(reranked_chunks=chunks)
|
| 310 |
-
result = generator_node(state)
|
| 311 |
# Salvage path: non-empty malformed text is returned as answer with 0.3 confidence
|
| 312 |
assert result["confidence_score"] == 0.3
|
| 313 |
assert len(result["citations"]) == 1 # salvage cites all provided chunks
|
| 314 |
|
| 315 |
|
| 316 |
-
|
|
|
|
| 317 |
from civicsetu.agent.nodes import generator_node
|
| 318 |
|
| 319 |
chunks = [_make_rc(section_id="18"), _make_rc(section_id="31")]
|
| 320 |
raw = "Promoter must register project, disclose details, and honor timelines."
|
| 321 |
|
| 322 |
with patch("civicsetu.agent.nodes._llm_call", return_value=raw):
|
| 323 |
-
result = generator_node(_base_state(reranked_chunks=chunks))
|
| 324 |
|
| 325 |
assert result["raw_response"] == raw
|
| 326 |
assert result["confidence_score"] == 0.3
|
| 327 |
assert len(result["citations"]) == 2
|
| 328 |
|
| 329 |
|
| 330 |
-
|
|
|
|
| 331 |
from civicsetu.agent.nodes import classifier_node
|
| 332 |
from civicsetu.models.enums import QueryType
|
| 333 |
|
|
@@ -341,13 +354,14 @@ Thinking:
|
|
| 341 |
"""
|
| 342 |
|
| 343 |
with patch("civicsetu.agent.nodes._llm_call", return_value=wrapped):
|
| 344 |
-
result = classifier_node(_base_state(query="What does Section 18 say?"))
|
| 345 |
|
| 346 |
assert result["query_type"] == QueryType.CROSS_REFERENCE
|
| 347 |
assert result["rewritten_query"] == "Section 18 refund obligations under RERA Act"
|
| 348 |
|
| 349 |
|
| 350 |
-
|
|
|
|
| 351 |
from civicsetu.agent.nodes import generator_node
|
| 352 |
|
| 353 |
chunks = [_make_rc(section_id="18")]
|
|
@@ -360,14 +374,15 @@ Here is structured answer.
|
|
| 360 |
"""
|
| 361 |
|
| 362 |
with patch("civicsetu.agent.nodes._llm_call", return_value=wrapped):
|
| 363 |
-
result = generator_node(_base_state(reranked_chunks=chunks))
|
| 364 |
|
| 365 |
assert result["raw_response"] == "Refund due under Section 18."
|
| 366 |
assert result["confidence_score"] == 0.8
|
| 367 |
assert len(result["citations"]) == 1
|
| 368 |
|
| 369 |
|
| 370 |
-
|
|
|
|
| 371 |
import civicsetu.agent.nodes as nodes_mod
|
| 372 |
|
| 373 |
monkeypatch.setenv("OPENAI_API_BASE", "https://api.osmapi.com/v1")
|
|
@@ -376,14 +391,16 @@ def test_llm_call_uses_json_mode_for_osmapi(monkeypatch):
|
|
| 376 |
fake_response.choices = [MagicMock(message=MagicMock(content='{"ok": true}'))]
|
| 377 |
fake_response.usage = None
|
| 378 |
|
| 379 |
-
with patch("civicsetu.agent.nodes.litellm.
|
| 380 |
-
|
|
|
|
| 381 |
|
| 382 |
assert result == '{"ok": true}'
|
| 383 |
assert completion.call_args.kwargs["response_format"] == {"type": "json_object"}
|
| 384 |
|
| 385 |
|
| 386 |
-
|
|
|
|
| 387 |
import civicsetu.agent.nodes as nodes_mod
|
| 388 |
|
| 389 |
monkeypatch.setenv("OPENAI_API_BASE", "https://api.osmapi.com/v1")
|
|
@@ -395,16 +412,18 @@ def test_llm_call_does_not_force_no_reasoning_for_osmapi(monkeypatch):
|
|
| 395 |
original_models = nodes_mod.THINKING_MODELS[:]
|
| 396 |
nodes_mod.THINKING_MODELS[:] = ["openai/gpt-4o-mini"]
|
| 397 |
|
| 398 |
-
with patch("civicsetu.agent.nodes.litellm.
|
|
|
|
| 399 |
try:
|
| 400 |
-
nodes_mod._llm_call("prompt", "system")
|
| 401 |
finally:
|
| 402 |
nodes_mod.THINKING_MODELS[:] = original_models
|
| 403 |
|
| 404 |
assert "extra_body" not in completion.call_args.kwargs
|
| 405 |
|
| 406 |
|
| 407 |
-
|
|
|
|
| 408 |
import civicsetu.agent.nodes as nodes_mod
|
| 409 |
|
| 410 |
fake_response = MagicMock()
|
|
@@ -414,9 +433,10 @@ def test_llm_call_does_not_attach_deepseek_kwargs_to_non_deepseek_models():
|
|
| 414 |
original_models = nodes_mod.THINKING_MODELS[:]
|
| 415 |
nodes_mod.THINKING_MODELS[:] = ["groq/llama-3.3-70b-versatile"]
|
| 416 |
|
| 417 |
-
with patch("civicsetu.agent.nodes.litellm.
|
|
|
|
| 418 |
try:
|
| 419 |
-
nodes_mod._llm_call("prompt", "system", tier="thinking")
|
| 420 |
finally:
|
| 421 |
nodes_mod.THINKING_MODELS[:] = original_models
|
| 422 |
|
|
@@ -424,7 +444,8 @@ def test_llm_call_does_not_attach_deepseek_kwargs_to_non_deepseek_models():
|
|
| 424 |
assert "extra_body" not in completion.call_args.kwargs
|
| 425 |
|
| 426 |
|
| 427 |
-
|
|
|
|
| 428 |
from civicsetu.agent.nodes import generator_node
|
| 429 |
|
| 430 |
# Two chunks with same (section_id, doc_name) — different chunk_ids
|
|
@@ -434,18 +455,19 @@ def test_generator_deduplicates_citations():
|
|
| 434 |
|
| 435 |
with patch("civicsetu.agent.nodes._llm_call", return_value=llm_out):
|
| 436 |
state = _base_state(reranked_chunks=[chunk_a, chunk_b])
|
| 437 |
-
result = generator_node(state)
|
| 438 |
|
| 439 |
assert len(result["citations"]) == 1
|
| 440 |
|
| 441 |
|
| 442 |
-
|
|
|
|
| 443 |
from civicsetu.agent.nodes import generator_node
|
| 444 |
from civicsetu.models.enums import QueryType
|
| 445 |
|
| 446 |
captured = {}
|
| 447 |
|
| 448 |
-
def fake_llm_call(prompt: str, system: str, temperature: float = 0.0) -> str:
|
| 449 |
captured["prompt"] = prompt
|
| 450 |
captured["system"] = system
|
| 451 |
return _llm_response(answer="If a builder misses the deadline, the buyer gets a remedy.")
|
|
@@ -455,7 +477,7 @@ def test_generator_system_prompt_uses_plain_language_persona():
|
|
| 455 |
query_type=QueryType.FACT_LOOKUP,
|
| 456 |
reranked_chunks=[_make_rc(section_id="18")],
|
| 457 |
)
|
| 458 |
-
generator_node(state)
|
| 459 |
|
| 460 |
assert "plain-language guide to Indian RERA laws" in captured["system"]
|
| 461 |
assert "explain what the law means in practice" in captured["system"]
|
|
@@ -474,12 +496,13 @@ def test_generator_system_prompt_uses_plain_language_persona():
|
|
| 474 |
("temporal", "Lead with the specific time period or deadline"),
|
| 475 |
],
|
| 476 |
)
|
| 477 |
-
|
|
|
|
| 478 |
from civicsetu.agent.nodes import generator_node
|
| 479 |
|
| 480 |
captured = {}
|
| 481 |
|
| 482 |
-
def fake_llm_call(prompt: str, system: str, temperature: float = 0.0) -> str:
|
| 483 |
captured["system"] = system
|
| 484 |
return _llm_response()
|
| 485 |
|
|
@@ -488,7 +511,7 @@ def test_generator_system_prompt_includes_query_type_tone_hint(query_type, expec
|
|
| 488 |
query_type=query_type,
|
| 489 |
reranked_chunks=[_make_rc(section_id="18")],
|
| 490 |
)
|
| 491 |
-
generator_node(state)
|
| 492 |
|
| 493 |
assert expected_hint in captured["system"]
|
| 494 |
|
|
@@ -502,12 +525,12 @@ def test_get_ranker_uses_settings_model():
|
|
| 502 |
reranker_mod._ranker = None # clear module-level cache
|
| 503 |
|
| 504 |
with patch("civicsetu.retrieval.reranker.settings") as mock_settings:
|
| 505 |
-
mock_settings.reranker_model = "
|
| 506 |
with patch("flashrank.Ranker") as MockRanker:
|
| 507 |
MockRanker.return_value = MagicMock()
|
| 508 |
reranker_mod._get_ranker()
|
| 509 |
MockRanker.assert_called_once_with(
|
| 510 |
-
model_name="
|
| 511 |
)
|
| 512 |
reranker_mod._ranker = None # clean up
|
| 513 |
|
|
@@ -602,7 +625,8 @@ def test_score_gap_new_threshold_still_cuts_on_cliff():
|
|
| 602 |
|
| 603 |
# ── reranker_node threshold + gap filtering ───────────────────────────────────
|
| 604 |
|
| 605 |
-
|
|
|
|
| 606 |
"""Chunks scoring below reranker_score_threshold must not appear in output."""
|
| 607 |
from civicsetu.agent.nodes import reranker_node
|
| 608 |
from unittest.mock import patch, MagicMock
|
|
@@ -626,14 +650,15 @@ def test_reranker_drops_below_threshold():
|
|
| 626 |
reranked_chunks=[],
|
| 627 |
query="test query",
|
| 628 |
)
|
| 629 |
-
result = reranker_node(state)
|
| 630 |
|
| 631 |
section_ids = [c.chunk.section_id for c in result["reranked_chunks"]]
|
| 632 |
assert "1" in section_ids
|
| 633 |
assert "2" not in section_ids
|
| 634 |
|
| 635 |
|
| 636 |
-
|
|
|
|
| 637 |
"""A large score cliff stops inclusion even if remaining chunks exceed threshold."""
|
| 638 |
from civicsetu.agent.nodes import reranker_node
|
| 639 |
from unittest.mock import patch, MagicMock
|
|
@@ -660,7 +685,7 @@ def test_reranker_applies_score_gap():
|
|
| 660 |
reranked_chunks=[],
|
| 661 |
query="test query",
|
| 662 |
)
|
| 663 |
-
result = reranker_node(state)
|
| 664 |
|
| 665 |
section_ids = [c.chunk.section_id for c in result["reranked_chunks"]]
|
| 666 |
assert "1" in section_ids
|
|
@@ -668,7 +693,8 @@ def test_reranker_applies_score_gap():
|
|
| 668 |
assert "3" not in section_ids
|
| 669 |
|
| 670 |
|
| 671 |
-
|
|
|
|
| 672 |
"""reranker_filtered log event must report correct before/after/dropped counts."""
|
| 673 |
from civicsetu.agent.nodes import reranker_node
|
| 674 |
from unittest.mock import patch, MagicMock
|
|
@@ -693,7 +719,7 @@ def test_reranker_filtered_count_logged():
|
|
| 693 |
reranked_chunks=[],
|
| 694 |
query="test query",
|
| 695 |
)
|
| 696 |
-
result = reranker_node(state)
|
| 697 |
|
| 698 |
# Verify reranked_chunks output
|
| 699 |
assert len(result["reranked_chunks"]) == 1
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
+
import asyncio
|
| 4 |
import json
|
|
|
|
| 5 |
from unittest.mock import MagicMock, patch
|
| 6 |
|
| 7 |
import pytest
|
|
|
|
| 14 |
"""Settings ship with safe, sensible defaults — no .env needed."""
|
| 15 |
from civicsetu.config.settings import Settings
|
| 16 |
s = Settings()
|
| 17 |
+
assert s.reranker_model == "ms-marco-MiniLM-L-12-v2"
|
| 18 |
assert s.reranker_score_threshold == 0.05
|
| 19 |
+
assert s.reranker_score_gap == 0.95
|
| 20 |
|
| 21 |
|
| 22 |
def test_reranker_settings_env_override(monkeypatch):
|
|
|
|
| 51 |
get_settings.cache_clear()
|
| 52 |
|
| 53 |
|
| 54 |
+
@pytest.mark.asyncio
|
| 55 |
+
async def test_reranker_empty_chunks():
|
| 56 |
from civicsetu.agent.nodes import reranker_node
|
| 57 |
+
result = await reranker_node(_base_state(retrieved_chunks=[], reranked_chunks=[]))
|
| 58 |
assert result["reranked_chunks"] == []
|
| 59 |
|
| 60 |
|
| 61 |
+
@pytest.mark.asyncio
|
| 62 |
+
async def test_reranker_pinned_chunks_always_first():
|
| 63 |
from civicsetu.agent.nodes import reranker_node
|
| 64 |
|
| 65 |
pinned = _make_rc(section_id="18", is_pinned=True)
|
|
|
|
| 76 |
reranked_chunks=[],
|
| 77 |
query="test query",
|
| 78 |
)
|
| 79 |
+
result = await reranker_node(state)
|
| 80 |
|
| 81 |
reranked = result["reranked_chunks"]
|
| 82 |
assert reranked[0].is_pinned is True
|
| 83 |
assert reranked[0].chunk.section_id == "18"
|
| 84 |
|
| 85 |
|
| 86 |
+
@pytest.mark.asyncio
|
| 87 |
+
async def test_reranker_keeps_pinned_chunks_up_to_context_limit():
|
| 88 |
from civicsetu.agent.nodes import reranker_node
|
| 89 |
|
| 90 |
pinned_chunks = [_make_rc(section_id=str(i), is_pinned=True) for i in range(4)]
|
|
|
|
| 98 |
reranked_chunks=[],
|
| 99 |
query="test query",
|
| 100 |
)
|
| 101 |
+
result = await reranker_node(state)
|
| 102 |
|
| 103 |
pinned_in_result = [c for c in result["reranked_chunks"] if c.is_pinned]
|
| 104 |
assert len(pinned_in_result) == 4
|
| 105 |
|
| 106 |
|
| 107 |
+
@pytest.mark.asyncio
|
| 108 |
+
async def test_reranker_deduplicates_by_chunk_id():
|
| 109 |
from civicsetu.agent.nodes import reranker_node
|
| 110 |
|
| 111 |
chunk = _make_rc(section_id="18")
|
|
|
|
| 122 |
reranked_chunks=[],
|
| 123 |
query="test query",
|
| 124 |
)
|
| 125 |
+
result = await reranker_node(state)
|
| 126 |
|
| 127 |
all_ids = [str(c.chunk.chunk_id) for c in result["reranked_chunks"]]
|
| 128 |
assert len(all_ids) == len(set(all_ids))
|
|
|
|
| 182 |
assert [c.chunk.section_id for c in ranked] == ["7(6)", "7"]
|
| 183 |
|
| 184 |
|
| 185 |
+
@pytest.mark.asyncio
|
| 186 |
+
async def test_prepend_pinned_sections_promotes_existing_matches(monkeypatch):
|
| 187 |
from civicsetu.agent.nodes import _prepend_pinned_sections
|
| 188 |
from civicsetu.models.enums import Jurisdiction
|
| 189 |
|
|
|
|
| 191 |
rule_5 = _make_rc(section_id="5", jurisdiction=Jurisdiction.KARNATAKA)
|
| 192 |
section_4 = _make_rc(section_id="4(14)", jurisdiction=Jurisdiction.CENTRAL)
|
| 193 |
|
| 194 |
+
async def fake_fetch(refs, jur, existing, hint=""):
|
|
|
|
| 195 |
return []
|
| 196 |
|
| 197 |
+
monkeypatch.setattr("civicsetu.agent.nodes._fetch_pinned_sections", fake_fetch)
|
| 198 |
|
| 199 |
+
result = await _prepend_pinned_sections(
|
| 200 |
+
_base_state(
|
| 201 |
+
pinned_section_refs=["Rule 5", "Section 4"],
|
| 202 |
+
pinned_section_jurisdiction=Jurisdiction.KARNATAKA,
|
| 203 |
+
pinned_section_hint="separate bank account seventy per cent",
|
| 204 |
+
),
|
| 205 |
[noisy, rule_5, section_4],
|
| 206 |
)
|
| 207 |
|
|
|
|
| 210 |
assert result[1].is_pinned is True
|
| 211 |
|
| 212 |
|
| 213 |
+
@pytest.mark.asyncio
|
| 214 |
+
async def test_reranker_node_trims_eval_context_to_pinned_families():
|
| 215 |
from civicsetu.agent.nodes import reranker_node
|
| 216 |
from civicsetu.models.enums import Jurisdiction
|
| 217 |
|
|
|
|
| 220 |
noisy_central = _make_rc(section_id="4(10)", jurisdiction=Jurisdiction.CENTRAL)
|
| 221 |
noisy_state = _make_rc(section_id="3", jurisdiction=Jurisdiction.UTTAR_PRADESH)
|
| 222 |
|
| 223 |
+
# Use a side effect to return only the input chunks that were not filtered out
|
| 224 |
+
def mock_rerank(chunks, query):
|
| 225 |
+
return chunks
|
| 226 |
+
|
| 227 |
+
with patch("civicsetu.retrieval.reranker.Reranker.rerank", side_effect=mock_rerank):
|
| 228 |
+
result = await reranker_node(
|
| 229 |
_base_state(
|
| 230 |
query="Which Karnataka rule implements project account maintenance?",
|
| 231 |
rewritten_query="Which Karnataka rule implements project account maintenance?",
|
|
|
|
| 250 |
return json.dumps(payload)
|
| 251 |
|
| 252 |
|
| 253 |
+
@pytest.mark.asyncio
|
| 254 |
+
async def test_generator_cites_only_referenced_chunks():
|
| 255 |
from civicsetu.agent.nodes import generator_node
|
| 256 |
|
| 257 |
chunks = [_make_rc(section_id=str(i)) for i in range(1, 6)]
|
|
|
|
| 263 |
query="What does Section 1 say?",
|
| 264 |
reranked_chunks=chunks,
|
| 265 |
)
|
| 266 |
+
result = await generator_node(state)
|
| 267 |
|
| 268 |
assert len(result["citations"]) == 2
|
| 269 |
cited_ids = {c.section_id for c in result["citations"]}
|
| 270 |
assert cited_ids == {"1", "3"}
|
| 271 |
|
| 272 |
|
| 273 |
+
@pytest.mark.asyncio
|
| 274 |
+
async def test_generator_fallback_when_cited_chunks_empty():
|
| 275 |
from civicsetu.agent.nodes import generator_node
|
| 276 |
|
| 277 |
chunks = [_make_rc(section_id=str(i)) for i in range(1, 4)]
|
|
|
|
| 279 |
|
| 280 |
with patch("civicsetu.agent.nodes._llm_call", return_value=llm_out):
|
| 281 |
state = _base_state(reranked_chunks=chunks)
|
| 282 |
+
result = await generator_node(state)
|
| 283 |
|
| 284 |
# Fallback: all 3 chunks cited
|
| 285 |
assert len(result["citations"]) == 3
|
| 286 |
|
| 287 |
|
| 288 |
+
@pytest.mark.asyncio
|
| 289 |
+
async def test_generator_filters_invalid_indices():
|
| 290 |
from civicsetu.agent.nodes import generator_node
|
| 291 |
|
| 292 |
chunks = [_make_rc(section_id=str(i)) for i in range(1, 4)]
|
|
|
|
| 295 |
|
| 296 |
with patch("civicsetu.agent.nodes._llm_call", return_value=llm_out):
|
| 297 |
state = _base_state(reranked_chunks=chunks)
|
| 298 |
+
result = await generator_node(state)
|
| 299 |
|
| 300 |
# Only index 2 is valid → 1 citation
|
| 301 |
assert len(result["citations"]) == 1
|
| 302 |
assert result["citations"][0].section_id == "2"
|
| 303 |
|
| 304 |
|
| 305 |
+
@pytest.mark.asyncio
|
| 306 |
+
async def test_generator_returns_empty_on_no_chunks():
|
| 307 |
from civicsetu.agent.nodes import generator_node
|
| 308 |
|
| 309 |
+
result = await generator_node(_base_state(reranked_chunks=[]))
|
| 310 |
assert result["citations"] == []
|
| 311 |
assert result["confidence_score"] == 0.0
|
| 312 |
|
| 313 |
|
| 314 |
+
@pytest.mark.asyncio
|
| 315 |
+
async def test_generator_handles_malformed_llm_json():
|
| 316 |
from civicsetu.agent.nodes import generator_node
|
| 317 |
|
| 318 |
chunks = [_make_rc(section_id="18")]
|
| 319 |
with patch("civicsetu.agent.nodes._llm_call", return_value="not json {{{{"):
|
| 320 |
state = _base_state(reranked_chunks=chunks)
|
| 321 |
+
result = await generator_node(state)
|
| 322 |
# Salvage path: non-empty malformed text is returned as answer with 0.3 confidence
|
| 323 |
assert result["confidence_score"] == 0.3
|
| 324 |
assert len(result["citations"]) == 1 # salvage cites all provided chunks
|
| 325 |
|
| 326 |
|
| 327 |
+
@pytest.mark.asyncio
|
| 328 |
+
async def test_generator_salvages_plain_text_when_json_missing():
|
| 329 |
from civicsetu.agent.nodes import generator_node
|
| 330 |
|
| 331 |
chunks = [_make_rc(section_id="18"), _make_rc(section_id="31")]
|
| 332 |
raw = "Promoter must register project, disclose details, and honor timelines."
|
| 333 |
|
| 334 |
with patch("civicsetu.agent.nodes._llm_call", return_value=raw):
|
| 335 |
+
result = await generator_node(_base_state(reranked_chunks=chunks))
|
| 336 |
|
| 337 |
assert result["raw_response"] == raw
|
| 338 |
assert result["confidence_score"] == 0.3
|
| 339 |
assert len(result["citations"]) == 2
|
| 340 |
|
| 341 |
|
| 342 |
+
@pytest.mark.asyncio
|
| 343 |
+
async def test_classifier_extracts_json_from_reasoning_wrapper():
|
| 344 |
from civicsetu.agent.nodes import classifier_node
|
| 345 |
from civicsetu.models.enums import QueryType
|
| 346 |
|
|
|
|
| 354 |
"""
|
| 355 |
|
| 356 |
with patch("civicsetu.agent.nodes._llm_call", return_value=wrapped):
|
| 357 |
+
result = await classifier_node(_base_state(query="What does Section 18 say?"))
|
| 358 |
|
| 359 |
assert result["query_type"] == QueryType.CROSS_REFERENCE
|
| 360 |
assert result["rewritten_query"] == "Section 18 refund obligations under RERA Act"
|
| 361 |
|
| 362 |
|
| 363 |
+
@pytest.mark.asyncio
|
| 364 |
+
async def test_generator_extracts_json_from_reasoning_wrapper():
|
| 365 |
from civicsetu.agent.nodes import generator_node
|
| 366 |
|
| 367 |
chunks = [_make_rc(section_id="18")]
|
|
|
|
| 374 |
"""
|
| 375 |
|
| 376 |
with patch("civicsetu.agent.nodes._llm_call", return_value=wrapped):
|
| 377 |
+
result = await generator_node(_base_state(reranked_chunks=chunks))
|
| 378 |
|
| 379 |
assert result["raw_response"] == "Refund due under Section 18."
|
| 380 |
assert result["confidence_score"] == 0.8
|
| 381 |
assert len(result["citations"]) == 1
|
| 382 |
|
| 383 |
|
| 384 |
+
@pytest.mark.asyncio
|
| 385 |
+
async def test_llm_call_uses_json_mode_for_osmapi(monkeypatch):
|
| 386 |
import civicsetu.agent.nodes as nodes_mod
|
| 387 |
|
| 388 |
monkeypatch.setenv("OPENAI_API_BASE", "https://api.osmapi.com/v1")
|
|
|
|
| 391 |
fake_response.choices = [MagicMock(message=MagicMock(content='{"ok": true}'))]
|
| 392 |
fake_response.usage = None
|
| 393 |
|
| 394 |
+
with patch("civicsetu.agent.nodes.litellm.acompletion", new=MagicMock(return_value=asyncio.Future())) as completion:
|
| 395 |
+
completion.return_value.set_result(fake_response)
|
| 396 |
+
result = await nodes_mod._llm_call("prompt", "system")
|
| 397 |
|
| 398 |
assert result == '{"ok": true}'
|
| 399 |
assert completion.call_args.kwargs["response_format"] == {"type": "json_object"}
|
| 400 |
|
| 401 |
|
| 402 |
+
@pytest.mark.asyncio
|
| 403 |
+
async def test_llm_call_does_not_force_no_reasoning_for_osmapi(monkeypatch):
|
| 404 |
import civicsetu.agent.nodes as nodes_mod
|
| 405 |
|
| 406 |
monkeypatch.setenv("OPENAI_API_BASE", "https://api.osmapi.com/v1")
|
|
|
|
| 412 |
original_models = nodes_mod.THINKING_MODELS[:]
|
| 413 |
nodes_mod.THINKING_MODELS[:] = ["openai/gpt-4o-mini"]
|
| 414 |
|
| 415 |
+
with patch("civicsetu.agent.nodes.litellm.acompletion", new=MagicMock(return_value=asyncio.Future())) as completion:
|
| 416 |
+
completion.return_value.set_result(fake_response)
|
| 417 |
try:
|
| 418 |
+
await nodes_mod._llm_call("prompt", "system")
|
| 419 |
finally:
|
| 420 |
nodes_mod.THINKING_MODELS[:] = original_models
|
| 421 |
|
| 422 |
assert "extra_body" not in completion.call_args.kwargs
|
| 423 |
|
| 424 |
|
| 425 |
+
@pytest.mark.asyncio
|
| 426 |
+
async def test_llm_call_does_not_attach_deepseek_kwargs_to_non_deepseek_models():
|
| 427 |
import civicsetu.agent.nodes as nodes_mod
|
| 428 |
|
| 429 |
fake_response = MagicMock()
|
|
|
|
| 433 |
original_models = nodes_mod.THINKING_MODELS[:]
|
| 434 |
nodes_mod.THINKING_MODELS[:] = ["groq/llama-3.3-70b-versatile"]
|
| 435 |
|
| 436 |
+
with patch("civicsetu.agent.nodes.litellm.acompletion", new=MagicMock(return_value=asyncio.Future())) as completion:
|
| 437 |
+
completion.return_value.set_result(fake_response)
|
| 438 |
try:
|
| 439 |
+
await nodes_mod._llm_call("prompt", "system", tier="thinking")
|
| 440 |
finally:
|
| 441 |
nodes_mod.THINKING_MODELS[:] = original_models
|
| 442 |
|
|
|
|
| 444 |
assert "extra_body" not in completion.call_args.kwargs
|
| 445 |
|
| 446 |
|
| 447 |
+
@pytest.mark.asyncio
|
| 448 |
+
async def test_generator_deduplicates_citations():
|
| 449 |
from civicsetu.agent.nodes import generator_node
|
| 450 |
|
| 451 |
# Two chunks with same (section_id, doc_name) — different chunk_ids
|
|
|
|
| 455 |
|
| 456 |
with patch("civicsetu.agent.nodes._llm_call", return_value=llm_out):
|
| 457 |
state = _base_state(reranked_chunks=[chunk_a, chunk_b])
|
| 458 |
+
result = await generator_node(state)
|
| 459 |
|
| 460 |
assert len(result["citations"]) == 1
|
| 461 |
|
| 462 |
|
| 463 |
+
@pytest.mark.asyncio
|
| 464 |
+
async def test_generator_system_prompt_uses_plain_language_persona():
|
| 465 |
from civicsetu.agent.nodes import generator_node
|
| 466 |
from civicsetu.models.enums import QueryType
|
| 467 |
|
| 468 |
captured = {}
|
| 469 |
|
| 470 |
+
async def fake_llm_call(prompt: str, system: str, temperature: float = 0.0, tier="thinking") -> str:
|
| 471 |
captured["prompt"] = prompt
|
| 472 |
captured["system"] = system
|
| 473 |
return _llm_response(answer="If a builder misses the deadline, the buyer gets a remedy.")
|
|
|
|
| 477 |
query_type=QueryType.FACT_LOOKUP,
|
| 478 |
reranked_chunks=[_make_rc(section_id="18")],
|
| 479 |
)
|
| 480 |
+
await generator_node(state)
|
| 481 |
|
| 482 |
assert "plain-language guide to Indian RERA laws" in captured["system"]
|
| 483 |
assert "explain what the law means in practice" in captured["system"]
|
|
|
|
| 496 |
("temporal", "Lead with the specific time period or deadline"),
|
| 497 |
],
|
| 498 |
)
|
| 499 |
+
@pytest.mark.asyncio
|
| 500 |
+
async def test_generator_system_prompt_includes_query_type_tone_hint(query_type, expected_hint):
|
| 501 |
from civicsetu.agent.nodes import generator_node
|
| 502 |
|
| 503 |
captured = {}
|
| 504 |
|
| 505 |
+
async def fake_llm_call(prompt: str, system: str, temperature: float = 0.0, tier="thinking") -> str:
|
| 506 |
captured["system"] = system
|
| 507 |
return _llm_response()
|
| 508 |
|
|
|
|
| 511 |
query_type=query_type,
|
| 512 |
reranked_chunks=[_make_rc(section_id="18")],
|
| 513 |
)
|
| 514 |
+
await generator_node(state)
|
| 515 |
|
| 516 |
assert expected_hint in captured["system"]
|
| 517 |
|
|
|
|
| 525 |
reranker_mod._ranker = None # clear module-level cache
|
| 526 |
|
| 527 |
with patch("civicsetu.retrieval.reranker.settings") as mock_settings:
|
| 528 |
+
mock_settings.reranker_model = "ms-marco-MiniLM-L-12-v2"
|
| 529 |
with patch("flashrank.Ranker") as MockRanker:
|
| 530 |
MockRanker.return_value = MagicMock()
|
| 531 |
reranker_mod._get_ranker()
|
| 532 |
MockRanker.assert_called_once_with(
|
| 533 |
+
model_name="ms-marco-MiniLM-L-12-v2", cache_dir=".cache/flashrank"
|
| 534 |
)
|
| 535 |
reranker_mod._ranker = None # clean up
|
| 536 |
|
|
|
|
| 625 |
|
| 626 |
# ── reranker_node threshold + gap filtering ───────────────────────────────────
|
| 627 |
|
| 628 |
+
@pytest.mark.asyncio
|
| 629 |
+
async def test_reranker_drops_below_threshold():
|
| 630 |
"""Chunks scoring below reranker_score_threshold must not appear in output."""
|
| 631 |
from civicsetu.agent.nodes import reranker_node
|
| 632 |
from unittest.mock import patch, MagicMock
|
|
|
|
| 650 |
reranked_chunks=[],
|
| 651 |
query="test query",
|
| 652 |
)
|
| 653 |
+
result = await reranker_node(state)
|
| 654 |
|
| 655 |
section_ids = [c.chunk.section_id for c in result["reranked_chunks"]]
|
| 656 |
assert "1" in section_ids
|
| 657 |
assert "2" not in section_ids
|
| 658 |
|
| 659 |
|
| 660 |
+
@pytest.mark.asyncio
|
| 661 |
+
async def test_reranker_applies_score_gap():
|
| 662 |
"""A large score cliff stops inclusion even if remaining chunks exceed threshold."""
|
| 663 |
from civicsetu.agent.nodes import reranker_node
|
| 664 |
from unittest.mock import patch, MagicMock
|
|
|
|
| 685 |
reranked_chunks=[],
|
| 686 |
query="test query",
|
| 687 |
)
|
| 688 |
+
result = await reranker_node(state)
|
| 689 |
|
| 690 |
section_ids = [c.chunk.section_id for c in result["reranked_chunks"]]
|
| 691 |
assert "1" in section_ids
|
|
|
|
| 693 |
assert "3" not in section_ids
|
| 694 |
|
| 695 |
|
| 696 |
+
@pytest.mark.asyncio
|
| 697 |
+
async def test_reranker_filtered_count_logged():
|
| 698 |
"""reranker_filtered log event must report correct before/after/dropped counts."""
|
| 699 |
from civicsetu.agent.nodes import reranker_node
|
| 700 |
from unittest.mock import patch, MagicMock
|
|
|
|
| 719 |
reranked_chunks=[],
|
| 720 |
query="test query",
|
| 721 |
)
|
| 722 |
+
result = await reranker_node(state)
|
| 723 |
|
| 724 |
# Verify reranked_chunks output
|
| 725 |
assert len(result["reranked_chunks"]) == 1
|
tests/unit/api/test_query_route.py
CHANGED
|
@@ -142,14 +142,14 @@ def test_graph_topology_returns_empty_payload_when_neo4j_auth_fails(client):
|
|
| 142 |
|
| 143 |
def test_query_returns_200_with_citations(client):
|
| 144 |
test_client, mock_graph = client
|
| 145 |
-
mock_graph.
|
| 146 |
"raw_response": "Under Section 18, the promoter must...",
|
| 147 |
"citations": [_make_citation("18")],
|
| 148 |
"confidence_score": 0.9,
|
| 149 |
"query_type": QueryType.CROSS_REFERENCE,
|
| 150 |
"conflict_warnings": [],
|
| 151 |
"amendment_notice": None,
|
| 152 |
-
}
|
| 153 |
|
| 154 |
response = test_client.post("/api/v1/query", json={"query": "What does Section 18 say?"})
|
| 155 |
assert response.status_code == 200
|
|
@@ -162,14 +162,14 @@ def test_query_returns_200_with_citations(client):
|
|
| 162 |
|
| 163 |
def test_query_returns_insufficient_when_no_citations(client):
|
| 164 |
test_client, mock_graph = client
|
| 165 |
-
mock_graph.
|
| 166 |
"raw_response": "I don't know",
|
| 167 |
"citations": [],
|
| 168 |
"confidence_score": 0.2,
|
| 169 |
"query_type": QueryType.FACT_LOOKUP,
|
| 170 |
"conflict_warnings": [],
|
| 171 |
"amendment_notice": None,
|
| 172 |
-
}
|
| 173 |
|
| 174 |
response = test_client.post("/api/v1/query", json={"query": "Some obscure question here"})
|
| 175 |
assert response.status_code == 200
|
|
@@ -194,14 +194,14 @@ def test_query_rejects_top_k_out_of_range(client):
|
|
| 194 |
|
| 195 |
def test_query_response_always_has_disclaimer(client):
|
| 196 |
test_client, mock_graph = client
|
| 197 |
-
mock_graph.
|
| 198 |
"raw_response": "Under Section 18...",
|
| 199 |
"citations": [_make_citation()],
|
| 200 |
"confidence_score": 0.8,
|
| 201 |
"query_type": QueryType.FACT_LOOKUP,
|
| 202 |
"conflict_warnings": [],
|
| 203 |
"amendment_notice": None,
|
| 204 |
-
}
|
| 205 |
|
| 206 |
response = test_client.post("/api/v1/query", json={"query": "What does Section 18 say?"})
|
| 207 |
body = response.json()
|
|
@@ -211,14 +211,14 @@ def test_query_response_always_has_disclaimer(client):
|
|
| 211 |
|
| 212 |
def test_query_reuses_provided_session_id(client):
|
| 213 |
test_client, mock_graph = client
|
| 214 |
-
mock_graph.
|
| 215 |
"raw_response": "Under Section 18...",
|
| 216 |
"citations": [_make_citation()],
|
| 217 |
"confidence_score": 0.8,
|
| 218 |
"query_type": QueryType.FACT_LOOKUP,
|
| 219 |
"conflict_warnings": [],
|
| 220 |
"amendment_notice": None,
|
| 221 |
-
}
|
| 222 |
|
| 223 |
response = test_client.post(
|
| 224 |
"/api/v1/query",
|
|
@@ -228,7 +228,7 @@ def test_query_reuses_provided_session_id(client):
|
|
| 228 |
body = response.json()
|
| 229 |
assert response.status_code == 200
|
| 230 |
assert body["session_id"] == "my-session-abc"
|
| 231 |
-
_, config = mock_graph.
|
| 232 |
assert config["configurable"]["thread_id"] == "my-session-abc"
|
| 233 |
|
| 234 |
|
|
@@ -393,14 +393,14 @@ def test_graph_section_content_resolves_graph_chunk_id_case_insensitive_status(c
|
|
| 393 |
|
| 394 |
def test_section_context_query_sets_skip_classifier_and_source_section(client):
|
| 395 |
test_client, mock_graph = client
|
| 396 |
-
mock_graph.
|
| 397 |
"raw_response": "Section 18 requires refund with interest.",
|
| 398 |
"citations": [_make_citation("18")],
|
| 399 |
"confidence_score": 0.88,
|
| 400 |
"query_type": QueryType.CROSS_REFERENCE,
|
| 401 |
"conflict_warnings": [],
|
| 402 |
"amendment_notice": None,
|
| 403 |
-
}
|
| 404 |
|
| 405 |
response = test_client.post(
|
| 406 |
"/api/v1/query/section-context",
|
|
@@ -415,7 +415,7 @@ def test_section_context_query_sets_skip_classifier_and_source_section(client):
|
|
| 415 |
assert response.status_code == 200
|
| 416 |
body = response.json()
|
| 417 |
assert body["session_id"] == "section-thread-1"
|
| 418 |
-
initial_state, config = mock_graph.
|
| 419 |
assert initial_state["skip_classifier"] is True
|
| 420 |
assert initial_state["source_section_id"] == "18"
|
| 421 |
assert initial_state["query_type"] == QueryType.CROSS_REFERENCE
|
|
|
|
| 142 |
|
| 143 |
def test_query_returns_200_with_citations(client):
|
| 144 |
test_client, mock_graph = client
|
| 145 |
+
mock_graph.ainvoke = AsyncMock(return_value={
|
| 146 |
"raw_response": "Under Section 18, the promoter must...",
|
| 147 |
"citations": [_make_citation("18")],
|
| 148 |
"confidence_score": 0.9,
|
| 149 |
"query_type": QueryType.CROSS_REFERENCE,
|
| 150 |
"conflict_warnings": [],
|
| 151 |
"amendment_notice": None,
|
| 152 |
+
})
|
| 153 |
|
| 154 |
response = test_client.post("/api/v1/query", json={"query": "What does Section 18 say?"})
|
| 155 |
assert response.status_code == 200
|
|
|
|
| 162 |
|
| 163 |
def test_query_returns_insufficient_when_no_citations(client):
|
| 164 |
test_client, mock_graph = client
|
| 165 |
+
mock_graph.ainvoke = AsyncMock(return_value={
|
| 166 |
"raw_response": "I don't know",
|
| 167 |
"citations": [],
|
| 168 |
"confidence_score": 0.2,
|
| 169 |
"query_type": QueryType.FACT_LOOKUP,
|
| 170 |
"conflict_warnings": [],
|
| 171 |
"amendment_notice": None,
|
| 172 |
+
})
|
| 173 |
|
| 174 |
response = test_client.post("/api/v1/query", json={"query": "Some obscure question here"})
|
| 175 |
assert response.status_code == 200
|
|
|
|
| 194 |
|
| 195 |
def test_query_response_always_has_disclaimer(client):
|
| 196 |
test_client, mock_graph = client
|
| 197 |
+
mock_graph.ainvoke = AsyncMock(return_value={
|
| 198 |
"raw_response": "Under Section 18...",
|
| 199 |
"citations": [_make_citation()],
|
| 200 |
"confidence_score": 0.8,
|
| 201 |
"query_type": QueryType.FACT_LOOKUP,
|
| 202 |
"conflict_warnings": [],
|
| 203 |
"amendment_notice": None,
|
| 204 |
+
})
|
| 205 |
|
| 206 |
response = test_client.post("/api/v1/query", json={"query": "What does Section 18 say?"})
|
| 207 |
body = response.json()
|
|
|
|
| 211 |
|
| 212 |
def test_query_reuses_provided_session_id(client):
|
| 213 |
test_client, mock_graph = client
|
| 214 |
+
mock_graph.ainvoke = AsyncMock(return_value={
|
| 215 |
"raw_response": "Under Section 18...",
|
| 216 |
"citations": [_make_citation()],
|
| 217 |
"confidence_score": 0.8,
|
| 218 |
"query_type": QueryType.FACT_LOOKUP,
|
| 219 |
"conflict_warnings": [],
|
| 220 |
"amendment_notice": None,
|
| 221 |
+
})
|
| 222 |
|
| 223 |
response = test_client.post(
|
| 224 |
"/api/v1/query",
|
|
|
|
| 228 |
body = response.json()
|
| 229 |
assert response.status_code == 200
|
| 230 |
assert body["session_id"] == "my-session-abc"
|
| 231 |
+
_, config = mock_graph.ainvoke.call_args.args
|
| 232 |
assert config["configurable"]["thread_id"] == "my-session-abc"
|
| 233 |
|
| 234 |
|
|
|
|
| 393 |
|
| 394 |
def test_section_context_query_sets_skip_classifier_and_source_section(client):
|
| 395 |
test_client, mock_graph = client
|
| 396 |
+
mock_graph.ainvoke = AsyncMock(return_value={
|
| 397 |
"raw_response": "Section 18 requires refund with interest.",
|
| 398 |
"citations": [_make_citation("18")],
|
| 399 |
"confidence_score": 0.88,
|
| 400 |
"query_type": QueryType.CROSS_REFERENCE,
|
| 401 |
"conflict_warnings": [],
|
| 402 |
"amendment_notice": None,
|
| 403 |
+
})
|
| 404 |
|
| 405 |
response = test_client.post(
|
| 406 |
"/api/v1/query/section-context",
|
|
|
|
| 415 |
assert response.status_code == 200
|
| 416 |
body = response.json()
|
| 417 |
assert body["session_id"] == "section-thread-1"
|
| 418 |
+
initial_state, config = mock_graph.ainvoke.call_args.args
|
| 419 |
assert initial_state["skip_classifier"] is True
|
| 420 |
assert initial_state["source_section_id"] == "18"
|
| 421 |
assert initial_state["query_type"] == QueryType.CROSS_REFERENCE
|
tests/unit/stores/test_graph_store.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from contextlib import asynccontextmanager
|
| 4 |
+
from unittest.mock import AsyncMock
|
| 5 |
+
|
| 6 |
+
import pytest
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def _make_driver(session):
|
| 10 |
+
driver = AsyncMock()
|
| 11 |
+
|
| 12 |
+
@asynccontextmanager
|
| 13 |
+
async def _session_cm():
|
| 14 |
+
yield session
|
| 15 |
+
|
| 16 |
+
driver.session = _session_cm
|
| 17 |
+
return driver
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@pytest.mark.asyncio
|
| 21 |
+
async def test_graph_topology_retries_after_transient_driver_error(monkeypatch):
|
| 22 |
+
from civicsetu.stores import graph_store
|
| 23 |
+
|
| 24 |
+
first_session = AsyncMock()
|
| 25 |
+
first_session.run.side_effect = RuntimeError("the connection is closed")
|
| 26 |
+
|
| 27 |
+
edges_result = AsyncMock()
|
| 28 |
+
edges_result.data.return_value = [{"source": "a", "target": "b", "edge_type": "REFERENCES"}]
|
| 29 |
+
nodes_result = AsyncMock()
|
| 30 |
+
nodes_result.data.return_value = [{"chunk_id": "a", "section_id": "18"}]
|
| 31 |
+
|
| 32 |
+
second_session = AsyncMock()
|
| 33 |
+
second_session.run.side_effect = [edges_result, nodes_result]
|
| 34 |
+
|
| 35 |
+
first_driver = _make_driver(first_session)
|
| 36 |
+
second_driver = _make_driver(second_session)
|
| 37 |
+
get_driver = AsyncMock(side_effect=[first_driver, second_driver])
|
| 38 |
+
reset_driver = AsyncMock()
|
| 39 |
+
|
| 40 |
+
monkeypatch.setattr(graph_store, "_get_driver", get_driver)
|
| 41 |
+
monkeypatch.setattr(graph_store, "_reset_driver", reset_driver)
|
| 42 |
+
|
| 43 |
+
nodes, edges = await graph_store.GraphStore.get_topology()
|
| 44 |
+
|
| 45 |
+
assert nodes == [{"chunk_id": "a", "section_id": "18"}]
|
| 46 |
+
assert edges == [{"source": "a", "target": "b", "edge_type": "REFERENCES"}]
|
| 47 |
+
assert get_driver.await_count == 2
|
| 48 |
+
reset_driver.assert_awaited_once()
|