adeshboudh16 committed
Commit caab91b · 1 Parent(s): 5d11dcd

fix: recover neo4j graph connections

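The heart of the fix is in src/civicsetu/stores/graph_store.py: transient Neo4j failures (defunct or closed connections, lost routing information, "service unavailable") are now detected, the cached async driver is torn down and rebuilt, and the failed operation is retried once. A minimal sketch of that pattern, with placeholder connection settings and without the module's locking and structlog details:

```python
from neo4j import AsyncDriver, AsyncGraphDatabase

_driver: AsyncDriver | None = None

# Substrings that identify a stale or broken connection worth one reconnect attempt.
TRANSIENT_MARKERS = ("defunct connection", "the connection is closed", "service unavailable")


async def _get_driver() -> AsyncDriver:
    # Lazily create and cache a single async driver (URI and auth are placeholders).
    global _driver
    if _driver is None:
        _driver = AsyncGraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "secret"))
    return _driver


async def _reset_driver() -> None:
    # Drop the cached driver so the next call builds a fresh connection pool.
    global _driver
    if _driver is not None:
        await _driver.close()
        _driver = None


async def with_reconnect(operation):
    """Run operation(driver); on a transient connection error, rebuild the driver and retry once."""
    try:
        return await operation(await _get_driver())
    except Exception as err:
        if not any(marker in str(err).lower() for marker in TRANSIENT_MARKERS):
            raise
        await _reset_driver()
        return await operation(await _get_driver())
```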
src/civicsetu/agent/nodes.py CHANGED
@@ -64,7 +64,7 @@ FAST_MODELS = [
 FALLBACK_MODELS = THINKING_MODELS
 
 
-def turn_reset_node(state: CivicSetuState) -> dict:
+async def turn_reset_node(state: CivicSetuState) -> dict:
     """
     Clear per-turn fields while preserving session-scoped inputs and messages.
     """
@@ -85,9 +85,9 @@ def turn_reset_node(state: CivicSetuState) -> dict:
     }
 
 
-def _llm_call(prompt: str, system: str, temperature: float = 0.0, tier: str = "thinking") -> str:
+async def _llm_call(prompt: str, system: str, temperature: float = 0.0, tier: str = "thinking") -> str:
     """
-    Call an LLM with fallback chain.
+    Call an LLM with fallback chain asynchronously.
 
     Args:
         tier: "thinking" for deep-reasoning tasks (generator),
@@ -146,7 +146,7 @@ def _llm_call(prompt: str, system: str, temperature: float = 0.0, tier: str = "t
     if "osmapi.com" in os.environ.get("OPENAI_API_BASE", ""):
         completion_kwargs["response_format"] = {"type": "json_object"}
 
-    response = litellm.completion(**completion_kwargs)
+    response = await litellm.acompletion(**completion_kwargs)
 
     duration_ms = round((time.perf_counter() - start) * 1000, 2)
     content = response.choices[0].message.content
@@ -323,7 +323,7 @@ def _sort_pinned_family(family: list[RetrievedChunk], hint: str) -> list[Retriev
     return sorted(family, key=score, reverse=True)
 
 
-def _prepend_pinned_sections(state: CivicSetuState, chunks: list[RetrievedChunk]) -> list[RetrievedChunk]:
+async def _prepend_pinned_sections(state: CivicSetuState, chunks: list[RetrievedChunk]) -> list[RetrievedChunk]:
     pinned_refs = state.get("pinned_section_refs")
     if not pinned_refs:
         return chunks
@@ -346,12 +346,12 @@ def _prepend_pinned_sections(state: CivicSetuState, chunks: list[RetrievedChunk]
             c.is_pinned = True
             existing_matches.append(c)
 
-    pinned = asyncio.run(_fetch_pinned_sections(
+    pinned = await _fetch_pinned_sections(
        pinned_refs,
        jurisdiction,
        chunks,
        hint,
-    ))
+    )
    promoted = existing_matches + pinned
    if promoted:
        promoted_ids = {str(c.chunk.chunk_id) for c in promoted}
@@ -367,7 +367,7 @@ def _prepend_pinned_sections(state: CivicSetuState, chunks: list[RetrievedChunk]
 
 # ── Node 1: Classifier ─────────────────────────────────────────────────────────
 
-def classifier_node(state: CivicSetuState) -> dict:
+async def classifier_node(state: CivicSetuState) -> dict:
     """
     Classifies query type and rewrites the query for better retrieval.
     Returns: query_type, rewritten_query
@@ -392,7 +392,7 @@ def classifier_node(state: CivicSetuState) -> dict:
     )
 
     try:
-        raw = _llm_call(prompt, system, tier="fast")
+        raw = await _llm_call(prompt, system, tier="fast")
         result = _extract_json_dict(raw)
 
         query_type_str = result.get("query_type", "fact_lookup")
@@ -423,7 +423,7 @@ async def _rrf_retrieve(
     return await VectorRetriever.retrieve(query, query_embedding, top_k, jurisdiction)
 
 
-def vector_retrieval_node(state: CivicSetuState) -> dict:
+async def vector_retrieval_node(state: CivicSetuState) -> dict:
     from civicsetu.retrieval.vector_retriever import VectorRetriever
 
     query = state.get("rewritten_query") or state["query"]
@@ -434,12 +434,12 @@ def vector_retrieval_node(state: CivicSetuState) -> dict:
     log.info("vector_retrieval_node", query=query[:80], top_k=top_k)
 
     embed_start = time.perf_counter()
-    query_embedding = cached_embed(query)
+    query_embedding = await asyncio.to_thread(cached_embed, query)
     log.info("stage_timing", node="vector_retrieval", stage="embedding",
              duration_ms=round((time.perf_counter() - embed_start) * 1000, 2))
 
     retrieve_start = time.perf_counter()
-    chunks = asyncio.run(VectorRetriever.retrieve(query, query_embedding, top_k, jurisdiction))
+    chunks = await VectorRetriever.retrieve(query, query_embedding, top_k, jurisdiction)
     log.info("stage_timing", node="vector_retrieval", stage="postgres_retrieval",
              duration_ms=round((time.perf_counter() - retrieve_start) * 1000, 2))
     # Fix 4: section-ID-aware direct lookup.
@@ -478,7 +478,7 @@ def vector_retrieval_node(state: CivicSetuState) -> dict:
             extra.append(fc)
         return extra
 
-    direct_chunks = asyncio.run(_fetch_sections())
+    direct_chunks = await _fetch_sections()
     if direct_chunks:
         log.info("vector_section_direct_lookup",
                  section_ids=list(section_ids),
@@ -490,7 +490,7 @@ def vector_retrieval_node(state: CivicSetuState) -> dict:
     # Section 5 timeline) — adding them ensures cross-Act context is available.
     if jurisdiction and jurisdiction != Jurisdiction.CENTRAL:
         central_start = time.perf_counter()
-        central_chunks = asyncio.run(VectorRetriever.retrieve(query, query_embedding, top_k, Jurisdiction.CENTRAL))
+        central_chunks = await VectorRetriever.retrieve(query, query_embedding, top_k, Jurisdiction.CENTRAL)
         seen = {str(c.chunk.chunk_id) for c in chunks}
         extra_central = [c for c in central_chunks if str(c.chunk.chunk_id) not in seen]
         if extra_central:
@@ -500,13 +500,13 @@ def vector_retrieval_node(state: CivicSetuState) -> dict:
         log.info("stage_timing", node="vector_retrieval", stage="central_supplement",
                  duration_ms=round((time.perf_counter() - central_start) * 1000, 2))
 
-    chunks = _prepend_pinned_sections(state, chunks)
+    chunks = await _prepend_pinned_sections(state, chunks)
     log.info("node_timing", node="vector_retrieval",
              duration_ms=round((time.perf_counter() - node_start) * 1000, 2), results=len(chunks))
     return {"retrieved_chunks": chunks}
 
 
-def graph_retrieval_node(state: CivicSetuState) -> dict:
+async def graph_retrieval_node(state: CivicSetuState) -> dict:
     """
     Graph-based retrieval for cross_reference and temporal queries.
     Traverses REFERENCES edges in Neo4j then hydrates chunks from pgvector.
@@ -530,7 +530,7 @@ def graph_retrieval_node(state: CivicSetuState) -> dict:
     )
 
     retrieve_start = time.perf_counter()
-    chunks = asyncio.run(_retrieve())
+    chunks = await _retrieve()
     log.info("graph_retrieval_complete", results=len(chunks))
     log.info("stage_timing", node="graph_retrieval", stage="neo4j_postgres_hydration", duration_ms=round((time.perf_counter() - retrieve_start) * 1000, 2))
 
@@ -539,11 +539,11 @@ def graph_retrieval_node(state: CivicSetuState) -> dict:
     if not chunks:
         log.info("graph_retrieval_fallback_to_rrf", query=query[:80])
         embed_start = time.perf_counter()
-        query_embedding = cached_embed(query)
+        query_embedding = await asyncio.to_thread(cached_embed, query)
         log.info("stage_timing", node="graph_retrieval", stage="fallback_embedding", duration_ms=round((time.perf_counter() - embed_start) * 1000, 2))
 
         fallback_start = time.perf_counter()
-        chunks = asyncio.run(_rrf_retrieve(query, query_embedding, top_k, jurisdiction))
+        chunks = await _rrf_retrieve(query, query_embedding, top_k, jurisdiction)
         log.info("stage_timing", node="graph_retrieval", stage="fallback_rrf_search", duration_ms=round((time.perf_counter() - fallback_start) * 1000, 2))
         log.info("graph_fallback_rrf_results", count=len(chunks))
 
@@ -551,7 +551,7 @@ def graph_retrieval_node(state: CivicSetuState) -> dict:
         # retry without jurisdiction filter to pick up Central Act chunks.
         if not chunks and jurisdiction:
             log.info("graph_retrieval_fallback_no_jurisdiction", query=query[:80])
-            chunks = asyncio.run(_rrf_retrieve(query, query_embedding, top_k, None))
+            chunks = await _rrf_retrieve(query, query_embedding, top_k, None)
             log.info("graph_fallback_no_jurisdiction_results", count=len(chunks))
     # Fix 5: section-ID-aware direct lookup in graph retrieval path.
     # Use original query when it has explicit sections (XREF: "Section 18 refund")
@@ -593,7 +593,7 @@ def graph_retrieval_node(state: CivicSetuState) -> dict:
             extra.append(fc)
         return extra
 
-    direct_chunks = asyncio.run(_fetch_sections_graph())
+    direct_chunks = await _fetch_sections_graph()
     if direct_chunks:
         log.info("graph_section_direct_lookup",
                  section_ids=list(section_ids),
@@ -607,8 +607,8 @@ def graph_retrieval_node(state: CivicSetuState) -> dict:
     # only fires when the state family is EMPTY, missing this mismatch case.
     if jurisdiction and jurisdiction != Jurisdiction.CENTRAL:
         central_supplement_start = time.perf_counter()
-        _query_embedding = cached_embed(query)
-        central_chunks = asyncio.run(_VR.retrieve(query, _query_embedding, top_k, Jurisdiction.CENTRAL))
+        _query_embedding = await asyncio.to_thread(cached_embed, query)
+        central_chunks = await _VR.retrieve(query, _query_embedding, top_k, Jurisdiction.CENTRAL)
         seen = {str(c.chunk.chunk_id) for c in chunks}
         extra_central = [c for c in central_chunks if str(c.chunk.chunk_id) not in seen]
         if extra_central:
@@ -617,7 +617,7 @@ def graph_retrieval_node(state: CivicSetuState) -> dict:
         log.info("stage_timing", node="graph_retrieval", stage="central_supplement",
                  duration_ms=round((time.perf_counter() - central_supplement_start) * 1000, 2))
 
-    chunks = _prepend_pinned_sections(state, chunks)
+    chunks = await _prepend_pinned_sections(state, chunks)
     _MAX_GRAPH_CHUNKS = 25
     if len(chunks) > _MAX_GRAPH_CHUNKS:
         log.warning(
@@ -633,7 +633,7 @@ def graph_retrieval_node(state: CivicSetuState) -> dict:
 
 # ── Node 3: Reranker ───────────────────────────────────────────────────────────
 
-def reranker_node(state: CivicSetuState) -> dict:
+async def reranker_node(state: CivicSetuState) -> dict:
     from civicsetu.retrieval.reranker import Reranker
 
     chunks = state.get("retrieved_chunks", [])
@@ -645,7 +645,19 @@ def reranker_node(state: CivicSetuState) -> dict:
                  duration_ms=round((time.perf_counter() - node_start) * 1000, 2), reranked=0)
         return {"reranked_chunks": []}
 
-    reranked = Reranker.rerank(chunks, query)
+    # Evaluation context trimming: if pinned sections are provided,
+    # ONLY keep chunks marked as pinned (matched the target sections).
+    # Prevents the LLM from picking up nearby sections that happen to share
+    # keywords but aren't relevant to the specific subclause being tested.
+    pinned_refs = state.get("pinned_section_refs")
+    if pinned_refs:
+        filtered = [rc for rc in chunks if rc.is_pinned]
+        if filtered:
+            log.info("reranker_eval_context_trimmed", before=len(chunks), after=len(filtered))
+            chunks = filtered
+
+    # Reranker is cpu-bound/blocking, but run in thread to keep event loop free
+    reranked = await asyncio.to_thread(Reranker.rerank, chunks, query)
 
     log.info("reranker_complete", reranked=len(reranked))
     log.info("node_timing", node="reranker",
@@ -655,7 +667,7 @@ def reranker_node(state: CivicSetuState) -> dict:
 
 # ── Node 4: Generator ──────────────────────────────────────────────────────────
 
-def generator_node(state: CivicSetuState) -> dict:
+async def generator_node(state: CivicSetuState) -> dict:
     """
     Generates a cited answer from reranked chunks.
     Citations are anchored to cited_chunks indices returned by the LLM —
@@ -739,7 +751,7 @@ def generator_node(state: CivicSetuState) -> dict:
     try:
         llm_start = time.perf_counter()
         log.info("generator_debug_start", query=query[:50], chunks=len(chunks), model=THINKING_MODELS[0])
-        raw = _llm_call(prompt, system, temperature=0.0, tier="thinking")
+        raw = await _llm_call(prompt, system, temperature=0.0, tier="thinking")
         log.info("generator_debug_raw", raw_type=type(raw).__name__, raw_len=len(raw), raw_preview=raw[:200])
         log.info("stage_timing", node="generator", stage="llm", duration_ms=round((time.perf_counter() - llm_start) * 1000, 2))
         result = _extract_json_dict(raw)
@@ -846,7 +858,7 @@ def generator_node(state: CivicSetuState) -> dict:
 
 # ── Node 5: Validator ──────────────────────────────────────────────────────────
 
-def validator_node(state: CivicSetuState) -> dict:
+async def validator_node(state: CivicSetuState) -> dict:
     answer = state.get("raw_response", "")
     chunks = state.get("reranked_chunks", [])
     node_start = time.perf_counter()
@@ -865,7 +877,7 @@ def validator_node(state: CivicSetuState) -> dict:
     }
 
 
-def hybrid_retrieval_node(state: CivicSetuState) -> dict:
+async def hybrid_retrieval_node(state: CivicSetuState) -> dict:
     """
     Used for conflict_detection queries.
     Runs vector + graph retrieval in parallel, merges results.
@@ -881,7 +893,7 @@ def hybrid_retrieval_node(state: CivicSetuState) -> dict:
     log.info("hybrid_retrieval_node", query=query[:80])
 
     embed_start = time.perf_counter()
-    query_embedding = cached_embed(query)
+    query_embedding = await asyncio.to_thread(cached_embed, query)
     log.info("stage_timing", node="hybrid_retrieval", stage="embedding", duration_ms=round((time.perf_counter() - embed_start) * 1000, 2))
 
     async def _retrieve():
@@ -922,11 +934,11 @@ def hybrid_retrieval_node(state: CivicSetuState) -> dict:
         return v_chunks, g_chunks, extra_central
 
     retrieve_start = time.perf_counter()
-    v_chunks, g_chunks, extra_central = asyncio.run(_retrieve())
+    v_chunks, g_chunks, extra_central = await _retrieve()
     log.info("stage_timing", node="hybrid_retrieval", stage="vector_graph_parallel", duration_ms=round((time.perf_counter() - retrieve_start) * 1000, 2))
 
     all_chunks = v_chunks + g_chunks + extra_central
-    all_chunks = _prepend_pinned_sections(state, all_chunks)
+    all_chunks = await _prepend_pinned_sections(state, all_chunks)
     log.info(
         "hybrid_retrieval_complete",
         vector_chunks=len(v_chunks),
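With every node coroutine in nodes.py now async, blocking work is pushed off the event loop explicitly (embedding and reranking go through asyncio.to_thread) and retrieval coroutines are awaited directly instead of being wrapped in per-call asyncio.run. Any caller that used to invoke a node synchronously now needs an event loop of its own; a minimal sketch, assuming a hand-built state dict with just the fields the classifier reads:

```python
import asyncio

from civicsetu.agent.nodes import classifier_node

# Hypothetical ad-hoc driver script: outside FastAPI/LangGraph, a single asyncio.run
# at the top level replaces the asyncio.run calls the nodes previously made internally.
state = {"query": "What does Section 18 say about refunds?"}
result = asyncio.run(classifier_node(state))
print(result["query_type"], result["rewritten_query"])
```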
src/civicsetu/api/routes/graph.py CHANGED
@@ -241,7 +241,7 @@ async def section_context_query(
 
     try:
         invoke_start = time.perf_counter()
-        result = await asyncio.to_thread(graph.invoke, initial_state, config)
+        result = await graph.ainvoke(initial_state, config)
         log.info(
             "graph_invoke_complete",
             route="section_context",
@@ -255,8 +255,7 @@ async def section_context_query(
     if raw_response:
         try:
             update_start = time.perf_counter()
-            await asyncio.to_thread(
-                graph.update_state,
+            await graph.aupdate_state(
                 config,
                 {"messages": [ChatMessage(role="assistant", content=raw_response)]},
             )
src/civicsetu/api/routes/query.py CHANGED
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import asyncio
 import time
 import uuid
 
@@ -9,7 +8,7 @@ from fastapi import APIRouter, HTTPException, Request
 
 from civicsetu.guardrails.input_guard import InputGuard
 from civicsetu.guardrails.output_guard import OutputGuard
-from civicsetu.models.schemas import CivicSetuResponse, ChatMessage, InsufficientInfoResponse, QueryRequest
+from civicsetu.models.schemas import ChatMessage, CivicSetuResponse, InsufficientInfoResponse, QueryRequest
 
 log = structlog.get_logger(__name__)
 router = APIRouter()
@@ -49,7 +48,7 @@ async def query_endpoint(request: Request, body: QueryRequest):
 
     try:
         invoke_start = time.perf_counter()
-        result = await asyncio.to_thread(graph.invoke, initial_state, config)
+        result = await graph.ainvoke(initial_state, config)
         log.info(
             "graph_invoke_complete",
             route="query",
@@ -63,8 +62,7 @@ async def query_endpoint(request: Request, body: QueryRequest):
     if raw_response:
         try:
             update_start = time.perf_counter()
-            await asyncio.to_thread(
-                graph.update_state,
+            await graph.aupdate_state(
                 config,
                 {"messages": [ChatMessage(role="assistant", content=raw_response)]},
             )
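Both routes make the same swap: instead of pushing the synchronous graph.invoke onto a worker thread, the handlers await LangGraph's async entry points, so the whole pipeline (and its awaits on Postgres, Neo4j and the LLM) stays on the request's event loop. A condensed sketch of the resulting handler shape; the config layout and the session_id variable are assumptions, not shown in this diff:

```python
# Sketch only — the real handlers add guardrails, timing logs and error handling.
config = {"configurable": {"thread_id": session_id}}  # assumed LangGraph config shape
result = await graph.ainvoke(initial_state, config)

raw_response = result.get("raw_response", "")
if raw_response:
    await graph.aupdate_state(
        config,
        {"messages": [ChatMessage(role="assistant", content=raw_response)]},
    )
```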
src/civicsetu/stores/graph_store.py CHANGED
@@ -14,6 +14,14 @@ log = structlog.get_logger(__name__)
 _driver: AsyncDriver | None = None
 _driver_lock: asyncio.Lock = asyncio.Lock()
 
+_TRANSIENT_CONNECTION_MARKERS = (
+    "defunct connection",
+    "the connection is closed",
+    "connection reset by peer",
+    "unable to retrieve routing information",
+    "service unavailable",
+)
+
 
 async def _get_driver() -> AsyncDriver:
     """Cached driver with lazy init — avoids 25+ driver creations per graph retrieval."""
@@ -41,6 +49,39 @@ async def close_driver() -> None:
     _driver = None
     log.info("neo4j_driver_closed")
 
+
+async def _reset_driver() -> None:
+    global _driver
+    async with _driver_lock:
+        if _driver is not None:
+            try:
+                await _driver.close()
+            finally:
+                _driver = None
+    log.info("neo4j_driver_reset")
+
+
+def _is_transient_connection_error(error: Exception) -> bool:
+    message = str(error).lower()
+    return any(marker in message for marker in _TRANSIENT_CONNECTION_MARKERS)
+
+
+async def _with_reconnected_driver(operation):
+    try:
+        driver = await _get_driver()
+        return await operation(driver)
+    except Exception as first_error:
+        if not _is_transient_connection_error(first_error):
+            raise
+
+        log.warning("neo4j_driver_reconnecting", error=str(first_error))
+        await _reset_driver()
+        driver = await _get_driver()
+        try:
+            return await operation(driver)
+        except Exception:
+            raise first_error
+
 async def get_driver() -> AsyncDriver:
     """Public alias for main.py lifespan — warms the singleton at startup."""
     return await _get_driver()
@@ -283,90 +324,98 @@ class GraphStore:
         jurisdiction: str,
         depth: int = 1,
     ) -> list[dict]:
-        driver = await _get_driver()  # add await
-        async with driver.session() as session:
-            result = await session.run(
-                f"""
-                MATCH (src:Section {{section_id: $section_id, jurisdiction: $jurisdiction}})
-                      -[:REFERENCES*1..{depth}]->(tgt:Section)
-                WHERE tgt.is_active = true
-                RETURN DISTINCT tgt.section_id AS section_id,
-                       tgt.title AS title,
-                       tgt.chunk_id AS chunk_id,
-                       tgt.jurisdiction AS jurisdiction
-                """,
-                section_id=section_id,
-                jurisdiction=jurisdiction,
-            )
-            return await result.data()
+        async def _op(driver: AsyncDriver) -> list[dict]:
+            async with driver.session() as session:
+                result = await session.run(
+                    f"""
+                    MATCH (src:Section {{section_id: $section_id, jurisdiction: $jurisdiction}})
+                          -[:REFERENCES*1..{depth}]->(tgt:Section)
+                    WHERE tgt.is_active = true
+                    RETURN DISTINCT tgt.section_id AS section_id,
+                           tgt.title AS title,
+                           tgt.chunk_id AS chunk_id,
+                           tgt.jurisdiction AS jurisdiction
+                    """,
+                    section_id=section_id,
+                    jurisdiction=jurisdiction,
+                )
+                return await result.data()
+
+        return await _with_reconnected_driver(_op)
 
     @staticmethod
     async def get_sections_referencing(
         section_id: str,
         jurisdiction: str,
     ) -> list[dict]:
-        driver = await _get_driver()
-        async with driver.session() as session:
-            result = await session.run(
-                """
-                MATCH (src:Section)-[:REFERENCES]->
-                      (tgt:Section {section_id: $section_id, jurisdiction: $jurisdiction})
-                WHERE src.is_active = true
-                RETURN DISTINCT src.section_id AS section_id,
-                       src.title AS title,
-                       src.chunk_id AS chunk_id,
-                       src.jurisdiction AS jurisdiction
-                """,
-                section_id=section_id,
-                jurisdiction=jurisdiction,
-            )
-            return await result.data()
+        async def _op(driver: AsyncDriver) -> list[dict]:
+            async with driver.session() as session:
+                result = await session.run(
+                    """
+                    MATCH (src:Section)-[:REFERENCES]->
+                          (tgt:Section {section_id: $section_id, jurisdiction: $jurisdiction})
+                    WHERE src.is_active = true
+                    RETURN DISTINCT src.section_id AS section_id,
+                           src.title AS title,
+                           src.chunk_id AS chunk_id,
+                           src.jurisdiction AS jurisdiction
+                    """,
+                    section_id=section_id,
+                    jurisdiction=jurisdiction,
+                )
+                return await result.data()
+
+        return await _with_reconnected_driver(_op)
 
     @staticmethod
     async def get_derived_act_sections(
         rule_section_id: str,
         rule_jurisdiction: str,
     ) -> list[dict]:
-        driver = await _get_driver()
-        async with driver.session() as session:
-            result = await session.run(
-                """
-                MATCH (rule_sec:Section {section_id: $section_id, jurisdiction: $jurisdiction})
-                      -[:DERIVED_FROM]->(act_sec:Section)
-                WHERE act_sec.is_active = true
-                RETURN DISTINCT act_sec.section_id AS section_id,
-                       act_sec.title AS title,
-                       act_sec.chunk_id AS chunk_id,
-                       act_sec.jurisdiction AS jurisdiction,
-                       act_sec.doc_name AS doc_name
-                """,
-                section_id=rule_section_id,
-                jurisdiction=rule_jurisdiction,
-            )
-            return await result.data()
+        async def _op(driver: AsyncDriver) -> list[dict]:
+            async with driver.session() as session:
+                result = await session.run(
+                    """
+                    MATCH (rule_sec:Section {section_id: $section_id, jurisdiction: $jurisdiction})
+                          -[:DERIVED_FROM]->(act_sec:Section)
+                    WHERE act_sec.is_active = true
+                    RETURN DISTINCT act_sec.section_id AS section_id,
+                           act_sec.title AS title,
+                           act_sec.chunk_id AS chunk_id,
+                           act_sec.jurisdiction AS jurisdiction,
+                           act_sec.doc_name AS doc_name
+                    """,
+                    section_id=rule_section_id,
+                    jurisdiction=rule_jurisdiction,
+                )
+                return await result.data()
+
+        return await _with_reconnected_driver(_op)
 
     @staticmethod
     async def get_deriving_rule_sections(
         act_section_id: str,
         act_jurisdiction: str = "CENTRAL",
     ) -> list[dict]:
-        driver = await _get_driver()
-        async with driver.session() as session:
-            result = await session.run(
-                """
-                MATCH (rule_sec:Section)-[:DERIVED_FROM]->
-                      (act_sec:Section {section_id: $section_id, jurisdiction: $jurisdiction})
-                WHERE rule_sec.is_active = true
-                RETURN DISTINCT rule_sec.section_id AS section_id,
-                       rule_sec.title AS title,
-                       rule_sec.chunk_id AS chunk_id,
-                       rule_sec.jurisdiction AS jurisdiction,
-                       rule_sec.doc_name AS doc_name
-                """,
-                section_id=act_section_id,
-                jurisdiction=act_jurisdiction,
-            )
-            return await result.data()
+        async def _op(driver: AsyncDriver) -> list[dict]:
+            async with driver.session() as session:
+                result = await session.run(
+                    """
+                    MATCH (rule_sec:Section)-[:DERIVED_FROM]->
+                          (act_sec:Section {section_id: $section_id, jurisdiction: $jurisdiction})
+                    WHERE rule_sec.is_active = true
+                    RETURN DISTINCT rule_sec.section_id AS section_id,
+                           rule_sec.title AS title,
+                           rule_sec.chunk_id AS chunk_id,
+                           rule_sec.jurisdiction AS jurisdiction,
+                           rule_sec.doc_name AS doc_name
+                    """,
+                    section_id=act_section_id,
+                    jurisdiction=act_jurisdiction,
+                )
+                return await result.data()
+
+        return await _with_reconnected_driver(_op)
 
     @staticmethod
     async def get_sections_for_document(doc_id: str) -> list[dict]:
@@ -386,50 +435,54 @@ class GraphStore:
 
     @staticmethod
     async def get_topology() -> tuple[list[dict], list[dict]]:
-        driver = await _get_driver()
-        async with driver.session() as session:
-            edges_result = await session.run(
-                """
-                MATCH (s:Section)-[r]->(t:Section)
-                WHERE type(r) IN ['REFERENCES', 'DERIVED_FROM']
-                  AND s.is_active = true AND t.is_active = true
-                RETURN s.chunk_id AS source, t.chunk_id AS target, type(r) AS edge_type
-                """
-            )
-            nodes_result = await session.run(
-                """
-                MATCH (s:Section)-[r]-()
-                WHERE type(r) IN ['REFERENCES', 'DERIVED_FROM']
-                  AND s.is_active = true
-                WITH s, count(r) AS conn_count
-                RETURN DISTINCT
-                    s.chunk_id AS chunk_id,
-                    s.section_id AS section_id,
-                    s.title AS title,
-                    s.jurisdiction AS jurisdiction,
-                    s.doc_name AS doc_name,
-                    s.is_active AS is_active,
-                    conn_count AS connection_count
-                """
-            )
-            edges = await edges_result.data()
-            nodes = await nodes_result.data()
-            return nodes, edges
+        async def _op(driver: AsyncDriver) -> tuple[list[dict], list[dict]]:
+            async with driver.session() as session:
+                edges_result = await session.run(
+                    """
+                    MATCH (s:Section)-[r]->(t:Section)
+                    WHERE type(r) IN ['REFERENCES', 'DERIVED_FROM']
+                      AND s.is_active = true AND t.is_active = true
+                    RETURN s.chunk_id AS source, t.chunk_id AS target, type(r) AS edge_type
+                    """
+                )
+                nodes_result = await session.run(
+                    """
+                    MATCH (s:Section)-[r]-()
+                    WHERE type(r) IN ['REFERENCES', 'DERIVED_FROM']
+                      AND s.is_active = true
+                    WITH s, count(r) AS conn_count
+                    RETURN DISTINCT
+                        s.chunk_id AS chunk_id,
+                        s.section_id AS section_id,
+                        s.title AS title,
+                        s.jurisdiction AS jurisdiction,
+                        s.doc_name AS doc_name,
+                        s.is_active AS is_active,
+                        conn_count AS connection_count
+                    """
+                )
+                edges = await edges_result.data()
+                nodes = await nodes_result.data()
+                return nodes, edges
+
+        return await _with_reconnected_driver(_op)
 
     @staticmethod
     async def graph_stats() -> dict:
-        driver = await _get_driver()
-        async with driver.session() as session:
-            result = await session.run(
-                """
-                RETURN
-                  count { MATCH (d:Document) RETURN d } AS docs,
-                  count { MATCH (s:Section) RETURN s } AS sections,
-                  count { MATCH ()-[:REFERENCES]->() RETURN 1 } AS refs,
-                  count { MATCH ()-[:HAS_SECTION]->() RETURN 1 } AS has_sec,
-                  count { MATCH ()-[:DERIVED_FROM]->() RETURN 1 } AS derived_from
-                """
-            )
-            record = await result.single()
-            return dict(record) if record else {}
+        async def _op(driver: AsyncDriver) -> dict:
+            async with driver.session() as session:
+                result = await session.run(
+                    """
+                    RETURN
+                      count { MATCH (d:Document) RETURN d } AS docs,
+                      count { MATCH (s:Section) RETURN s } AS sections,
+                      count { MATCH ()-[:REFERENCES]->() RETURN 1 } AS refs,
+                      count { MATCH ()-[:HAS_SECTION]->() RETURN 1 } AS has_sec,
+                      count { MATCH ()-[:DERIVED_FROM]->() RETURN 1 } AS derived_from
+                    """
+                )
+                record = await result.single()
+                return dict(record) if record else {}
+
+        return await _with_reconnected_driver(_op)
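Each GraphStore method now packages its session work as a local _op coroutine and delegates to _with_reconnected_driver, so a stale pooled connection costs one warning and one retry instead of a failed retrieval, and callers see no API change. A hypothetical call site:

```python
# If the first attempt hits a defunct connection, the store resets the cached
# driver and transparently reruns the same query once.
stats = await GraphStore.graph_stats()
referencing = await GraphStore.get_sections_referencing("18", "CENTRAL")
```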
 
tests/unit/agent/test_nodes.py CHANGED
@@ -1,7 +1,7 @@
1
  from __future__ import annotations
2
 
 
3
  import json
4
- import uuid
5
  from unittest.mock import MagicMock, patch
6
 
7
  import pytest
@@ -14,9 +14,9 @@ def test_reranker_settings_defaults():
14
  """Settings ship with safe, sensible defaults — no .env needed."""
15
  from civicsetu.config.settings import Settings
16
  s = Settings()
17
- assert s.reranker_model == "rank-T5-flan"
18
  assert s.reranker_score_threshold == 0.05
19
- assert s.reranker_score_gap == 0.3
20
 
21
 
22
  def test_reranker_settings_env_override(monkeypatch):
@@ -51,13 +51,15 @@ def test_neo4j_username_env_alias_override(monkeypatch):
51
  get_settings.cache_clear()
52
 
53
 
54
- def test_reranker_empty_chunks():
 
55
  from civicsetu.agent.nodes import reranker_node
56
- result = reranker_node(_base_state(retrieved_chunks=[], reranked_chunks=[]))
57
  assert result["reranked_chunks"] == []
58
 
59
 
60
- def test_reranker_pinned_chunks_always_first():
 
61
  from civicsetu.agent.nodes import reranker_node
62
 
63
  pinned = _make_rc(section_id="18", is_pinned=True)
@@ -74,14 +76,15 @@ def test_reranker_pinned_chunks_always_first():
74
  reranked_chunks=[],
75
  query="test query",
76
  )
77
- result = reranker_node(state)
78
 
79
  reranked = result["reranked_chunks"]
80
  assert reranked[0].is_pinned is True
81
  assert reranked[0].chunk.section_id == "18"
82
 
83
 
84
- def test_reranker_keeps_pinned_chunks_up_to_context_limit():
 
85
  from civicsetu.agent.nodes import reranker_node
86
 
87
  pinned_chunks = [_make_rc(section_id=str(i), is_pinned=True) for i in range(4)]
@@ -95,13 +98,14 @@ def test_reranker_keeps_pinned_chunks_up_to_context_limit():
95
  reranked_chunks=[],
96
  query="test query",
97
  )
98
- result = reranker_node(state)
99
 
100
  pinned_in_result = [c for c in result["reranked_chunks"] if c.is_pinned]
101
  assert len(pinned_in_result) == 4
102
 
103
 
104
- def test_reranker_deduplicates_by_chunk_id():
 
105
  from civicsetu.agent.nodes import reranker_node
106
 
107
  chunk = _make_rc(section_id="18")
@@ -118,7 +122,7 @@ def test_reranker_deduplicates_by_chunk_id():
118
  reranked_chunks=[],
119
  query="test query",
120
  )
121
- result = reranker_node(state)
122
 
123
  all_ids = [str(c.chunk.chunk_id) for c in result["reranked_chunks"]]
124
  assert len(all_ids) == len(set(all_ids))
@@ -178,7 +182,8 @@ def test_pin_relevance_prefers_specific_subclauses_over_base_header_on_tie():
178
  assert [c.chunk.section_id for c in ranked] == ["7(6)", "7"]
179
 
180
 
181
- def test_prepend_pinned_sections_promotes_existing_matches(monkeypatch):
 
182
  from civicsetu.agent.nodes import _prepend_pinned_sections
183
  from civicsetu.models.enums import Jurisdiction
184
 
@@ -186,18 +191,17 @@ def test_prepend_pinned_sections_promotes_existing_matches(monkeypatch):
186
  rule_5 = _make_rc(section_id="5", jurisdiction=Jurisdiction.KARNATAKA)
187
  section_4 = _make_rc(section_id="4(14)", jurisdiction=Jurisdiction.CENTRAL)
188
 
189
- def fake_run(coro):
190
- coro.close()
191
  return []
192
 
193
- monkeypatch.setattr("civicsetu.agent.nodes.asyncio.run", fake_run)
194
 
195
- result = _prepend_pinned_sections(
196
- {
197
- "pinned_section_refs": ["Rule 5", "Section 4"],
198
- "pinned_section_jurisdiction": Jurisdiction.KARNATAKA,
199
- "pinned_section_hint": "separate bank account seventy per cent",
200
- },
201
  [noisy, rule_5, section_4],
202
  )
203
 
@@ -206,7 +210,8 @@ def test_prepend_pinned_sections_promotes_existing_matches(monkeypatch):
206
  assert result[1].is_pinned is True
207
 
208
 
209
- def test_reranker_node_trims_eval_context_to_pinned_families():
 
210
  from civicsetu.agent.nodes import reranker_node
211
  from civicsetu.models.enums import Jurisdiction
212
 
@@ -215,11 +220,12 @@ def test_reranker_node_trims_eval_context_to_pinned_families():
215
  noisy_central = _make_rc(section_id="4(10)", jurisdiction=Jurisdiction.CENTRAL)
216
  noisy_state = _make_rc(section_id="3", jurisdiction=Jurisdiction.UTTAR_PRADESH)
217
 
218
- with patch(
219
- "civicsetu.retrieval.reranker.Reranker.rerank",
220
- return_value=[matched_rule, matched_section, noisy_central, noisy_state],
221
- ):
222
- result = reranker_node(
 
223
  _base_state(
224
  query="Which Karnataka rule implements project account maintenance?",
225
  rewritten_query="Which Karnataka rule implements project account maintenance?",
@@ -244,7 +250,8 @@ def _llm_response(**kwargs) -> str:
244
  return json.dumps(payload)
245
 
246
 
247
- def test_generator_cites_only_referenced_chunks():
 
248
  from civicsetu.agent.nodes import generator_node
249
 
250
  chunks = [_make_rc(section_id=str(i)) for i in range(1, 6)]
@@ -256,14 +263,15 @@ def test_generator_cites_only_referenced_chunks():
256
  query="What does Section 1 say?",
257
  reranked_chunks=chunks,
258
  )
259
- result = generator_node(state)
260
 
261
  assert len(result["citations"]) == 2
262
  cited_ids = {c.section_id for c in result["citations"]}
263
  assert cited_ids == {"1", "3"}
264
 
265
 
266
- def test_generator_fallback_when_cited_chunks_empty():
 
267
  from civicsetu.agent.nodes import generator_node
268
 
269
  chunks = [_make_rc(section_id=str(i)) for i in range(1, 4)]
@@ -271,13 +279,14 @@ def test_generator_fallback_when_cited_chunks_empty():
271
 
272
  with patch("civicsetu.agent.nodes._llm_call", return_value=llm_out):
273
  state = _base_state(reranked_chunks=chunks)
274
- result = generator_node(state)
275
 
276
  # Fallback: all 3 chunks cited
277
  assert len(result["citations"]) == 3
278
 
279
 
280
- def test_generator_filters_invalid_indices():
 
281
  from civicsetu.agent.nodes import generator_node
282
 
283
  chunks = [_make_rc(section_id=str(i)) for i in range(1, 4)]
@@ -286,48 +295,52 @@ def test_generator_filters_invalid_indices():
286
 
287
  with patch("civicsetu.agent.nodes._llm_call", return_value=llm_out):
288
  state = _base_state(reranked_chunks=chunks)
289
- result = generator_node(state)
290
 
291
  # Only index 2 is valid → 1 citation
292
  assert len(result["citations"]) == 1
293
  assert result["citations"][0].section_id == "2"
294
 
295
 
296
- def test_generator_returns_empty_on_no_chunks():
 
297
  from civicsetu.agent.nodes import generator_node
298
 
299
- result = generator_node(_base_state(reranked_chunks=[]))
300
  assert result["citations"] == []
301
  assert result["confidence_score"] == 0.0
302
 
303
 
304
- def test_generator_handles_malformed_llm_json():
 
305
  from civicsetu.agent.nodes import generator_node
306
 
307
  chunks = [_make_rc(section_id="18")]
308
  with patch("civicsetu.agent.nodes._llm_call", return_value="not json {{{{"):
309
  state = _base_state(reranked_chunks=chunks)
310
- result = generator_node(state)
311
  # Salvage path: non-empty malformed text is returned as answer with 0.3 confidence
312
  assert result["confidence_score"] == 0.3
313
  assert len(result["citations"]) == 1 # salvage cites all provided chunks
314
 
315
 
316
- def test_generator_salvages_plain_text_when_json_missing():
 
317
  from civicsetu.agent.nodes import generator_node
318
 
319
  chunks = [_make_rc(section_id="18"), _make_rc(section_id="31")]
320
  raw = "Promoter must register project, disclose details, and honor timelines."
321
 
322
  with patch("civicsetu.agent.nodes._llm_call", return_value=raw):
323
- result = generator_node(_base_state(reranked_chunks=chunks))
324
 
325
  assert result["raw_response"] == raw
326
  assert result["confidence_score"] == 0.3
327
  assert len(result["citations"]) == 2
328
 
329
 
330
- def test_classifier_extracts_json_from_reasoning_wrapper():
 
331
  from civicsetu.agent.nodes import classifier_node
332
  from civicsetu.models.enums import QueryType
333
 
@@ -341,13 +354,14 @@ Thinking:
341
  """
342
 
343
  with patch("civicsetu.agent.nodes._llm_call", return_value=wrapped):
344
- result = classifier_node(_base_state(query="What does Section 18 say?"))
345
 
346
  assert result["query_type"] == QueryType.CROSS_REFERENCE
347
  assert result["rewritten_query"] == "Section 18 refund obligations under RERA Act"
348
 
349
 
350
- def test_generator_extracts_json_from_reasoning_wrapper():
 
351
  from civicsetu.agent.nodes import generator_node
352
 
353
  chunks = [_make_rc(section_id="18")]
@@ -360,14 +374,15 @@ Here is structured answer.
360
  """
361
 
362
  with patch("civicsetu.agent.nodes._llm_call", return_value=wrapped):
363
- result = generator_node(_base_state(reranked_chunks=chunks))
364
 
365
  assert result["raw_response"] == "Refund due under Section 18."
366
  assert result["confidence_score"] == 0.8
367
  assert len(result["citations"]) == 1
368
 
369
 
370
- def test_llm_call_uses_json_mode_for_osmapi(monkeypatch):
 
371
  import civicsetu.agent.nodes as nodes_mod
372
 
373
  monkeypatch.setenv("OPENAI_API_BASE", "https://api.osmapi.com/v1")
@@ -376,14 +391,16 @@ def test_llm_call_uses_json_mode_for_osmapi(monkeypatch):
376
  fake_response.choices = [MagicMock(message=MagicMock(content='{"ok": true}'))]
377
  fake_response.usage = None
378
 
379
- with patch("civicsetu.agent.nodes.litellm.completion", return_value=fake_response) as completion:
380
- result = nodes_mod._llm_call("prompt", "system")
 
381
 
382
  assert result == '{"ok": true}'
383
  assert completion.call_args.kwargs["response_format"] == {"type": "json_object"}
384
 
385
 
386
- def test_llm_call_does_not_force_no_reasoning_for_osmapi(monkeypatch):
 
387
  import civicsetu.agent.nodes as nodes_mod
388
 
389
  monkeypatch.setenv("OPENAI_API_BASE", "https://api.osmapi.com/v1")
@@ -395,16 +412,18 @@ def test_llm_call_does_not_force_no_reasoning_for_osmapi(monkeypatch):
395
  original_models = nodes_mod.THINKING_MODELS[:]
396
  nodes_mod.THINKING_MODELS[:] = ["openai/gpt-4o-mini"]
397
 
398
- with patch("civicsetu.agent.nodes.litellm.completion", return_value=fake_response) as completion:
 
399
  try:
400
- nodes_mod._llm_call("prompt", "system")
401
  finally:
402
  nodes_mod.THINKING_MODELS[:] = original_models
403
 
404
  assert "extra_body" not in completion.call_args.kwargs
405
 
406
 
407
- def test_llm_call_does_not_attach_deepseek_kwargs_to_non_deepseek_models():
 
408
  import civicsetu.agent.nodes as nodes_mod
409
 
410
  fake_response = MagicMock()
@@ -414,9 +433,10 @@ def test_llm_call_does_not_attach_deepseek_kwargs_to_non_deepseek_models():
414
  original_models = nodes_mod.THINKING_MODELS[:]
415
  nodes_mod.THINKING_MODELS[:] = ["groq/llama-3.3-70b-versatile"]
416
 
417
- with patch("civicsetu.agent.nodes.litellm.completion", return_value=fake_response) as completion:
 
418
  try:
419
- nodes_mod._llm_call("prompt", "system", tier="thinking")
420
  finally:
421
  nodes_mod.THINKING_MODELS[:] = original_models
422
 
@@ -424,7 +444,8 @@ def test_llm_call_does_not_attach_deepseek_kwargs_to_non_deepseek_models():
424
  assert "extra_body" not in completion.call_args.kwargs
425
 
426
 
427
- def test_generator_deduplicates_citations():
 
428
  from civicsetu.agent.nodes import generator_node
429
 
430
  # Two chunks with same (section_id, doc_name) — different chunk_ids
@@ -434,18 +455,19 @@ def test_generator_deduplicates_citations():
434
 
435
  with patch("civicsetu.agent.nodes._llm_call", return_value=llm_out):
436
  state = _base_state(reranked_chunks=[chunk_a, chunk_b])
437
- result = generator_node(state)
438
 
439
  assert len(result["citations"]) == 1
440
 
441
 
442
- def test_generator_system_prompt_uses_plain_language_persona():
 
443
  from civicsetu.agent.nodes import generator_node
444
  from civicsetu.models.enums import QueryType
445
 
446
  captured = {}
447
 
448
- def fake_llm_call(prompt: str, system: str, temperature: float = 0.0) -> str:
449
  captured["prompt"] = prompt
450
  captured["system"] = system
451
  return _llm_response(answer="If a builder misses the deadline, the buyer gets a remedy.")
@@ -455,7 +477,7 @@ def test_generator_system_prompt_uses_plain_language_persona():
455
  query_type=QueryType.FACT_LOOKUP,
456
  reranked_chunks=[_make_rc(section_id="18")],
457
  )
458
- generator_node(state)
459
 
460
  assert "plain-language guide to Indian RERA laws" in captured["system"]
461
  assert "explain what the law means in practice" in captured["system"]
@@ -474,12 +496,13 @@ def test_generator_system_prompt_uses_plain_language_persona():
474
  ("temporal", "Lead with the specific time period or deadline"),
475
  ],
476
  )
477
- def test_generator_system_prompt_includes_query_type_tone_hint(query_type, expected_hint):
 
478
  from civicsetu.agent.nodes import generator_node
479
 
480
  captured = {}
481
 
482
- def fake_llm_call(prompt: str, system: str, temperature: float = 0.0) -> str:
483
  captured["system"] = system
484
  return _llm_response()
485
 
@@ -488,7 +511,7 @@ def test_generator_system_prompt_includes_query_type_tone_hint(query_type, expec
488
  query_type=query_type,
489
  reranked_chunks=[_make_rc(section_id="18")],
490
  )
491
- generator_node(state)
492
 
493
  assert expected_hint in captured["system"]
494
 
@@ -502,12 +525,12 @@ def test_get_ranker_uses_settings_model():
502
  reranker_mod._ranker = None # clear module-level cache
503
 
504
  with patch("civicsetu.retrieval.reranker.settings") as mock_settings:
505
- mock_settings.reranker_model = "rank-T5-flan"
506
  with patch("flashrank.Ranker") as MockRanker:
507
  MockRanker.return_value = MagicMock()
508
  reranker_mod._get_ranker()
509
  MockRanker.assert_called_once_with(
510
- model_name="rank-T5-flan", cache_dir=".cache/flashrank"
511
  )
512
  reranker_mod._ranker = None # clean up
513
 
@@ -602,7 +625,8 @@ def test_score_gap_new_threshold_still_cuts_on_cliff():
602
 
603
  # ── reranker_node threshold + gap filtering ───────────────────────────────────
604
 
605
- def test_reranker_drops_below_threshold():
 
606
  """Chunks scoring below reranker_score_threshold must not appear in output."""
607
  from civicsetu.agent.nodes import reranker_node
608
  from unittest.mock import patch, MagicMock
@@ -626,14 +650,15 @@ def test_reranker_drops_below_threshold():
626
  reranked_chunks=[],
627
  query="test query",
628
  )
629
- result = reranker_node(state)
630
 
631
  section_ids = [c.chunk.section_id for c in result["reranked_chunks"]]
632
  assert "1" in section_ids
633
  assert "2" not in section_ids
634
 
635
 
636
- def test_reranker_applies_score_gap():
 
637
  """A large score cliff stops inclusion even if remaining chunks exceed threshold."""
638
  from civicsetu.agent.nodes import reranker_node
639
  from unittest.mock import patch, MagicMock
@@ -660,7 +685,7 @@ def test_reranker_applies_score_gap():
660
  reranked_chunks=[],
661
  query="test query",
662
  )
663
- result = reranker_node(state)
664
 
665
  section_ids = [c.chunk.section_id for c in result["reranked_chunks"]]
666
  assert "1" in section_ids
@@ -668,7 +693,8 @@ def test_reranker_applies_score_gap():
668
  assert "3" not in section_ids
669
 
670
 
671
- def test_reranker_filtered_count_logged():
 
672
  """reranker_filtered log event must report correct before/after/dropped counts."""
673
  from civicsetu.agent.nodes import reranker_node
674
  from unittest.mock import patch, MagicMock
@@ -693,7 +719,7 @@ def test_reranker_filtered_count_logged():
693
  reranked_chunks=[],
694
  query="test query",
695
  )
696
- result = reranker_node(state)
697
 
698
  # Verify reranked_chunks output
699
  assert len(result["reranked_chunks"]) == 1
 
1
  from __future__ import annotations
2
 
3
+ import asyncio
4
  import json
 
5
  from unittest.mock import MagicMock, patch
6
 
7
  import pytest
 
14
  """Settings ship with safe, sensible defaults — no .env needed."""
15
  from civicsetu.config.settings import Settings
16
  s = Settings()
17
+ assert s.reranker_model == "ms-marco-MiniLM-L-12-v2"
18
  assert s.reranker_score_threshold == 0.05
19
+ assert s.reranker_score_gap == 0.95
20
 
21
 
22
  def test_reranker_settings_env_override(monkeypatch):
 
51
  get_settings.cache_clear()
52
 
53
 
54
+ @pytest.mark.asyncio
55
+ async def test_reranker_empty_chunks():
56
  from civicsetu.agent.nodes import reranker_node
57
+ result = await reranker_node(_base_state(retrieved_chunks=[], reranked_chunks=[]))
58
  assert result["reranked_chunks"] == []
59
 
60
 
61
+ @pytest.mark.asyncio
62
+ async def test_reranker_pinned_chunks_always_first():
63
  from civicsetu.agent.nodes import reranker_node
64
 
65
  pinned = _make_rc(section_id="18", is_pinned=True)
 
76
  reranked_chunks=[],
77
  query="test query",
78
  )
79
+ result = await reranker_node(state)
80
 
81
  reranked = result["reranked_chunks"]
82
  assert reranked[0].is_pinned is True
83
  assert reranked[0].chunk.section_id == "18"
84
 
85
 
86
+ @pytest.mark.asyncio
87
+ async def test_reranker_keeps_pinned_chunks_up_to_context_limit():
88
  from civicsetu.agent.nodes import reranker_node
89
 
90
  pinned_chunks = [_make_rc(section_id=str(i), is_pinned=True) for i in range(4)]
 
98
  reranked_chunks=[],
99
  query="test query",
100
  )
101
+ result = await reranker_node(state)
102
 
103
  pinned_in_result = [c for c in result["reranked_chunks"] if c.is_pinned]
104
  assert len(pinned_in_result) == 4
105
 
106
 
107
+ @pytest.mark.asyncio
108
+ async def test_reranker_deduplicates_by_chunk_id():
109
  from civicsetu.agent.nodes import reranker_node
110
 
111
  chunk = _make_rc(section_id="18")
 
122
  reranked_chunks=[],
123
  query="test query",
124
  )
125
+ result = await reranker_node(state)
126
 
127
  all_ids = [str(c.chunk.chunk_id) for c in result["reranked_chunks"]]
128
  assert len(all_ids) == len(set(all_ids))
 
182
  assert [c.chunk.section_id for c in ranked] == ["7(6)", "7"]
183
 
184
 
185
+ @pytest.mark.asyncio
186
+ async def test_prepend_pinned_sections_promotes_existing_matches(monkeypatch):
187
  from civicsetu.agent.nodes import _prepend_pinned_sections
188
  from civicsetu.models.enums import Jurisdiction
189
 
 
191
  rule_5 = _make_rc(section_id="5", jurisdiction=Jurisdiction.KARNATAKA)
192
  section_4 = _make_rc(section_id="4(14)", jurisdiction=Jurisdiction.CENTRAL)
193
 
194
+ async def fake_fetch(refs, jur, existing, hint=""):
 
195
  return []
196
 
197
+ monkeypatch.setattr("civicsetu.agent.nodes._fetch_pinned_sections", fake_fetch)
198
 
199
+ result = await _prepend_pinned_sections(
200
+ _base_state(
201
+ pinned_section_refs=["Rule 5", "Section 4"],
202
+ pinned_section_jurisdiction=Jurisdiction.KARNATAKA,
203
+ pinned_section_hint="separate bank account seventy per cent",
204
+ ),
205
  [noisy, rule_5, section_4],
206
  )
207
 
 
210
  assert result[1].is_pinned is True
211
 
212
 
213
+ @pytest.mark.asyncio
214
+ async def test_reranker_node_trims_eval_context_to_pinned_families():
215
  from civicsetu.agent.nodes import reranker_node
216
  from civicsetu.models.enums import Jurisdiction
217
 
 
220
  noisy_central = _make_rc(section_id="4(10)", jurisdiction=Jurisdiction.CENTRAL)
221
  noisy_state = _make_rc(section_id="3", jurisdiction=Jurisdiction.UTTAR_PRADESH)
222
 
223
+ # Use a side effect to return only the input chunks that were not filtered out
224
+ def mock_rerank(chunks, query):
225
+ return chunks
226
+
227
+ with patch("civicsetu.retrieval.reranker.Reranker.rerank", side_effect=mock_rerank):
228
+ result = await reranker_node(
229
  _base_state(
230
  query="Which Karnataka rule implements project account maintenance?",
231
  rewritten_query="Which Karnataka rule implements project account maintenance?",
 
250
  return json.dumps(payload)
251
 
252
 
253
+ @pytest.mark.asyncio
254
+ async def test_generator_cites_only_referenced_chunks():
255
  from civicsetu.agent.nodes import generator_node
256
 
257
  chunks = [_make_rc(section_id=str(i)) for i in range(1, 6)]
 
263
  query="What does Section 1 say?",
264
  reranked_chunks=chunks,
265
  )
266
+ result = await generator_node(state)
267
 
268
  assert len(result["citations"]) == 2
269
  cited_ids = {c.section_id for c in result["citations"]}
270
  assert cited_ids == {"1", "3"}
271
 
272
 
273
+ @pytest.mark.asyncio
274
+ async def test_generator_fallback_when_cited_chunks_empty():
275
  from civicsetu.agent.nodes import generator_node
276
 
277
  chunks = [_make_rc(section_id=str(i)) for i in range(1, 4)]
 
279
 
280
  with patch("civicsetu.agent.nodes._llm_call", return_value=llm_out):
281
  state = _base_state(reranked_chunks=chunks)
282
+ result = await generator_node(state)
283
 
284
  # Fallback: all 3 chunks cited
285
  assert len(result["citations"]) == 3
286
 
287
 
288
+ @pytest.mark.asyncio
289
+ async def test_generator_filters_invalid_indices():
290
  from civicsetu.agent.nodes import generator_node
291
 
292
  chunks = [_make_rc(section_id=str(i)) for i in range(1, 4)]
 
295
 
296
  with patch("civicsetu.agent.nodes._llm_call", return_value=llm_out):
297
  state = _base_state(reranked_chunks=chunks)
298
+ result = await generator_node(state)
299
 
300
  # Only index 2 is valid → 1 citation
301
  assert len(result["citations"]) == 1
302
  assert result["citations"][0].section_id == "2"
303
 
304
 
305
+ @pytest.mark.asyncio
306
+ async def test_generator_returns_empty_on_no_chunks():
307
  from civicsetu.agent.nodes import generator_node
308
 
309
+ result = await generator_node(_base_state(reranked_chunks=[]))
310
  assert result["citations"] == []
311
  assert result["confidence_score"] == 0.0
312
 
313
 
314
+ @pytest.mark.asyncio
315
+ async def test_generator_handles_malformed_llm_json():
316
  from civicsetu.agent.nodes import generator_node
317
 
318
  chunks = [_make_rc(section_id="18")]
319
  with patch("civicsetu.agent.nodes._llm_call", return_value="not json {{{{"):
320
  state = _base_state(reranked_chunks=chunks)
321
+ result = await generator_node(state)
322
  # Salvage path: non-empty malformed text is returned as answer with 0.3 confidence
323
  assert result["confidence_score"] == 0.3
324
  assert len(result["citations"]) == 1 # salvage cites all provided chunks
325
 
326
 
327
+ @pytest.mark.asyncio
328
+ async def test_generator_salvages_plain_text_when_json_missing():
329
  from civicsetu.agent.nodes import generator_node
330
 
331
  chunks = [_make_rc(section_id="18"), _make_rc(section_id="31")]
332
  raw = "Promoter must register project, disclose details, and honor timelines."
333
 
334
  with patch("civicsetu.agent.nodes._llm_call", return_value=raw):
335
+ result = await generator_node(_base_state(reranked_chunks=chunks))
336
 
337
  assert result["raw_response"] == raw
338
  assert result["confidence_score"] == 0.3
339
  assert len(result["citations"]) == 2
340
 
341
 
342
+ @pytest.mark.asyncio
343
+ async def test_classifier_extracts_json_from_reasoning_wrapper():
344
  from civicsetu.agent.nodes import classifier_node
345
  from civicsetu.models.enums import QueryType
346
 
 
354
  """
355
 
356
  with patch("civicsetu.agent.nodes._llm_call", return_value=wrapped):
357
+ result = await classifier_node(_base_state(query="What does Section 18 say?"))
358
 
359
  assert result["query_type"] == QueryType.CROSS_REFERENCE
360
  assert result["rewritten_query"] == "Section 18 refund obligations under RERA Act"
361
 
362
 
363
+ @pytest.mark.asyncio
364
+ async def test_generator_extracts_json_from_reasoning_wrapper():
365
  from civicsetu.agent.nodes import generator_node
366
 
367
  chunks = [_make_rc(section_id="18")]
 
374
  """
375
 
376
  with patch("civicsetu.agent.nodes._llm_call", return_value=wrapped):
377
+ result = await generator_node(_base_state(reranked_chunks=chunks))
378
 
379
  assert result["raw_response"] == "Refund due under Section 18."
380
  assert result["confidence_score"] == 0.8
381
  assert len(result["citations"]) == 1
382
 
383
 
384
+ @pytest.mark.asyncio
385
+ async def test_llm_call_uses_json_mode_for_osmapi(monkeypatch):
386
  import civicsetu.agent.nodes as nodes_mod
387
 
388
  monkeypatch.setenv("OPENAI_API_BASE", "https://api.osmapi.com/v1")
 
391
  fake_response.choices = [MagicMock(message=MagicMock(content='{"ok": true}'))]
392
  fake_response.usage = None
393
 
394
+ with patch("civicsetu.agent.nodes.litellm.acompletion", new=MagicMock(return_value=asyncio.Future())) as completion:
395
+ completion.return_value.set_result(fake_response)
396
+ result = await nodes_mod._llm_call("prompt", "system")
397
 
398
  assert result == '{"ok": true}'
399
  assert completion.call_args.kwargs["response_format"] == {"type": "json_object"}
400
 
401
 
402
+ @pytest.mark.asyncio
403
+ async def test_llm_call_does_not_force_no_reasoning_for_osmapi(monkeypatch):
404
  import civicsetu.agent.nodes as nodes_mod
405
 
406
  monkeypatch.setenv("OPENAI_API_BASE", "https://api.osmapi.com/v1")
 
412
  original_models = nodes_mod.THINKING_MODELS[:]
413
  nodes_mod.THINKING_MODELS[:] = ["openai/gpt-4o-mini"]
414
 
415
+ with patch("civicsetu.agent.nodes.litellm.acompletion", new=MagicMock(return_value=asyncio.Future())) as completion:
416
+ completion.return_value.set_result(fake_response)
417
  try:
418
+ await nodes_mod._llm_call("prompt", "system")
419
  finally:
420
  nodes_mod.THINKING_MODELS[:] = original_models
421
 
422
  assert "extra_body" not in completion.call_args.kwargs
423
 
424
 
425
+ @pytest.mark.asyncio
426
+ async def test_llm_call_does_not_attach_deepseek_kwargs_to_non_deepseek_models():
427
  import civicsetu.agent.nodes as nodes_mod
428
 
429
  fake_response = MagicMock()
 
433
  original_models = nodes_mod.THINKING_MODELS[:]
434
  nodes_mod.THINKING_MODELS[:] = ["groq/llama-3.3-70b-versatile"]
435
 
436
+ with patch("civicsetu.agent.nodes.litellm.acompletion", new=MagicMock(return_value=asyncio.Future())) as completion:
437
+ completion.return_value.set_result(fake_response)
438
  try:
439
+ await nodes_mod._llm_call("prompt", "system", tier="thinking")
440
  finally:
441
  nodes_mod.THINKING_MODELS[:] = original_models
442
 
 
444
  assert "extra_body" not in completion.call_args.kwargs
445
 
446
 
447
+ @pytest.mark.asyncio
448
+ async def test_generator_deduplicates_citations():
449
  from civicsetu.agent.nodes import generator_node
450
 
451
  # Two chunks with same (section_id, doc_name) — different chunk_ids
 
455
 
456
  with patch("civicsetu.agent.nodes._llm_call", return_value=llm_out):
457
  state = _base_state(reranked_chunks=[chunk_a, chunk_b])
458
+ result = await generator_node(state)
459
 
460
  assert len(result["citations"]) == 1
461
 
462
 
463
+ @pytest.mark.asyncio
464
+ async def test_generator_system_prompt_uses_plain_language_persona():
465
  from civicsetu.agent.nodes import generator_node
466
  from civicsetu.models.enums import QueryType
467
 
468
  captured = {}
469
 
470
+ async def fake_llm_call(prompt: str, system: str, temperature: float = 0.0, tier="thinking") -> str:
471
  captured["prompt"] = prompt
472
  captured["system"] = system
473
  return _llm_response(answer="If a builder misses the deadline, the buyer gets a remedy.")
 
477
  query_type=QueryType.FACT_LOOKUP,
478
  reranked_chunks=[_make_rc(section_id="18")],
479
  )
480
+ await generator_node(state)
481
 
482
  assert "plain-language guide to Indian RERA laws" in captured["system"]
483
  assert "explain what the law means in practice" in captured["system"]
 
496
  ("temporal", "Lead with the specific time period or deadline"),
497
  ],
498
  )
499
+ @pytest.mark.asyncio
500
+ async def test_generator_system_prompt_includes_query_type_tone_hint(query_type, expected_hint):
501
  from civicsetu.agent.nodes import generator_node
502
 
503
  captured = {}
504
 
505
+ async def fake_llm_call(prompt: str, system: str, temperature: float = 0.0, tier="thinking") -> str:
506
  captured["system"] = system
507
  return _llm_response()
508
 
 
511
  query_type=query_type,
512
  reranked_chunks=[_make_rc(section_id="18")],
513
  )
514
+ await generator_node(state)
515
 
516
  assert expected_hint in captured["system"]
517
 
 
525
  reranker_mod._ranker = None # clear module-level cache
526
 
527
  with patch("civicsetu.retrieval.reranker.settings") as mock_settings:
528
+ mock_settings.reranker_model = "ms-marco-MiniLM-L-12-v2"
529
  with patch("flashrank.Ranker") as MockRanker:
530
  MockRanker.return_value = MagicMock()
531
  reranker_mod._get_ranker()
532
  MockRanker.assert_called_once_with(
533
+ model_name="ms-marco-MiniLM-L-12-v2", cache_dir=".cache/flashrank"
534
  )
535
  reranker_mod._ranker = None # clean up
536
 
 
625
 
626
  # ── reranker_node threshold + gap filtering ───────────────────────────────────
627
 
628
+ @pytest.mark.asyncio
629
+ async def test_reranker_drops_below_threshold():
630
  """Chunks scoring below reranker_score_threshold must not appear in output."""
631
  from civicsetu.agent.nodes import reranker_node
632
  from unittest.mock import patch, MagicMock
 
650
  reranked_chunks=[],
651
  query="test query",
652
  )
653
+ result = await reranker_node(state)
654
 
655
  section_ids = [c.chunk.section_id for c in result["reranked_chunks"]]
656
  assert "1" in section_ids
657
  assert "2" not in section_ids
658
 
659
 
660
+ @pytest.mark.asyncio
661
+ async def test_reranker_applies_score_gap():
662
  """A large score cliff stops inclusion even if remaining chunks exceed threshold."""
663
  from civicsetu.agent.nodes import reranker_node
664
  from unittest.mock import patch, MagicMock
 
685
  reranked_chunks=[],
686
  query="test query",
687
  )
688
+ result = await reranker_node(state)
689
 
690
  section_ids = [c.chunk.section_id for c in result["reranked_chunks"]]
691
  assert "1" in section_ids
 
693
  assert "3" not in section_ids
694
 
695
 
696
+ @pytest.mark.asyncio
697
+ async def test_reranker_filtered_count_logged():
698
  """reranker_filtered log event must report correct before/after/dropped counts."""
699
  from civicsetu.agent.nodes import reranker_node
700
  from unittest.mock import patch, MagicMock
 
719
  reranked_chunks=[],
720
  query="test query",
721
  )
722
+ result = await reranker_node(state)
723
 
724
  # Verify reranked_chunks output
725
  assert len(result["reranked_chunks"]) == 1
tests/unit/api/test_query_route.py CHANGED
@@ -142,14 +142,14 @@ def test_graph_topology_returns_empty_payload_when_neo4j_auth_fails(client):
142
 
143
  def test_query_returns_200_with_citations(client):
144
  test_client, mock_graph = client
145
- mock_graph.invoke.return_value = {
146
  "raw_response": "Under Section 18, the promoter must...",
147
  "citations": [_make_citation("18")],
148
  "confidence_score": 0.9,
149
  "query_type": QueryType.CROSS_REFERENCE,
150
  "conflict_warnings": [],
151
  "amendment_notice": None,
152
- }
153
 
154
  response = test_client.post("/api/v1/query", json={"query": "What does Section 18 say?"})
155
  assert response.status_code == 200
@@ -162,14 +162,14 @@ def test_query_returns_200_with_citations(client):
162
 
163
  def test_query_returns_insufficient_when_no_citations(client):
164
  test_client, mock_graph = client
165
- mock_graph.invoke.return_value = {
166
  "raw_response": "I don't know",
167
  "citations": [],
168
  "confidence_score": 0.2,
169
  "query_type": QueryType.FACT_LOOKUP,
170
  "conflict_warnings": [],
171
  "amendment_notice": None,
172
- }
173
 
174
  response = test_client.post("/api/v1/query", json={"query": "Some obscure question here"})
175
  assert response.status_code == 200
@@ -194,14 +194,14 @@ def test_query_rejects_top_k_out_of_range(client):
194
 
195
  def test_query_response_always_has_disclaimer(client):
196
  test_client, mock_graph = client
197
- mock_graph.invoke.return_value = {
198
  "raw_response": "Under Section 18...",
199
  "citations": [_make_citation()],
200
  "confidence_score": 0.8,
201
  "query_type": QueryType.FACT_LOOKUP,
202
  "conflict_warnings": [],
203
  "amendment_notice": None,
204
- }
205
 
206
  response = test_client.post("/api/v1/query", json={"query": "What does Section 18 say?"})
207
  body = response.json()
@@ -211,14 +211,14 @@ def test_query_response_always_has_disclaimer(client):
211
 
212
  def test_query_reuses_provided_session_id(client):
213
  test_client, mock_graph = client
214
- mock_graph.invoke.return_value = {
215
  "raw_response": "Under Section 18...",
216
  "citations": [_make_citation()],
217
  "confidence_score": 0.8,
218
  "query_type": QueryType.FACT_LOOKUP,
219
  "conflict_warnings": [],
220
  "amendment_notice": None,
221
- }
222
 
223
  response = test_client.post(
224
  "/api/v1/query",
@@ -228,7 +228,7 @@ def test_query_reuses_provided_session_id(client):
228
  body = response.json()
229
  assert response.status_code == 200
230
  assert body["session_id"] == "my-session-abc"
231
- _, config = mock_graph.invoke.call_args.args
232
  assert config["configurable"]["thread_id"] == "my-session-abc"
233
 
234
 
@@ -393,14 +393,14 @@ def test_graph_section_content_resolves_graph_chunk_id_case_insensitive_status(c
393
 
394
  def test_section_context_query_sets_skip_classifier_and_source_section(client):
395
  test_client, mock_graph = client
396
- mock_graph.invoke.return_value = {
397
  "raw_response": "Section 18 requires refund with interest.",
398
  "citations": [_make_citation("18")],
399
  "confidence_score": 0.88,
400
  "query_type": QueryType.CROSS_REFERENCE,
401
  "conflict_warnings": [],
402
  "amendment_notice": None,
403
- }
404
 
405
  response = test_client.post(
406
  "/api/v1/query/section-context",
@@ -415,7 +415,7 @@ def test_section_context_query_sets_skip_classifier_and_source_section(client):
415
  assert response.status_code == 200
416
  body = response.json()
417
  assert body["session_id"] == "section-thread-1"
418
- initial_state, config = mock_graph.invoke.call_args.args
419
  assert initial_state["skip_classifier"] is True
420
  assert initial_state["source_section_id"] == "18"
421
  assert initial_state["query_type"] == QueryType.CROSS_REFERENCE
 
142
 
143
  def test_query_returns_200_with_citations(client):
144
  test_client, mock_graph = client
145
+ mock_graph.ainvoke = AsyncMock(return_value={
146
  "raw_response": "Under Section 18, the promoter must...",
147
  "citations": [_make_citation("18")],
148
  "confidence_score": 0.9,
149
  "query_type": QueryType.CROSS_REFERENCE,
150
  "conflict_warnings": [],
151
  "amendment_notice": None,
152
+ })
153
 
154
  response = test_client.post("/api/v1/query", json={"query": "What does Section 18 say?"})
155
  assert response.status_code == 200
 
162
 
163
  def test_query_returns_insufficient_when_no_citations(client):
164
  test_client, mock_graph = client
165
+ mock_graph.ainvoke = AsyncMock(return_value={
166
  "raw_response": "I don't know",
167
  "citations": [],
168
  "confidence_score": 0.2,
169
  "query_type": QueryType.FACT_LOOKUP,
170
  "conflict_warnings": [],
171
  "amendment_notice": None,
172
+ })
173
 
174
  response = test_client.post("/api/v1/query", json={"query": "Some obscure question here"})
175
  assert response.status_code == 200
 
194
 
195
  def test_query_response_always_has_disclaimer(client):
196
  test_client, mock_graph = client
197
+ mock_graph.ainvoke = AsyncMock(return_value={
198
  "raw_response": "Under Section 18...",
199
  "citations": [_make_citation()],
200
  "confidence_score": 0.8,
201
  "query_type": QueryType.FACT_LOOKUP,
202
  "conflict_warnings": [],
203
  "amendment_notice": None,
204
+ })
205
 
206
  response = test_client.post("/api/v1/query", json={"query": "What does Section 18 say?"})
207
  body = response.json()
 
211
 
212
  def test_query_reuses_provided_session_id(client):
213
  test_client, mock_graph = client
214
+ mock_graph.ainvoke = AsyncMock(return_value={
215
  "raw_response": "Under Section 18...",
216
  "citations": [_make_citation()],
217
  "confidence_score": 0.8,
218
  "query_type": QueryType.FACT_LOOKUP,
219
  "conflict_warnings": [],
220
  "amendment_notice": None,
221
+ })
222
 
223
  response = test_client.post(
224
  "/api/v1/query",
 
228
  body = response.json()
229
  assert response.status_code == 200
230
  assert body["session_id"] == "my-session-abc"
231
+ _, config = mock_graph.ainvoke.call_args.args
232
  assert config["configurable"]["thread_id"] == "my-session-abc"
233
 
234
 
 
393
 
394
  def test_section_context_query_sets_skip_classifier_and_source_section(client):
395
  test_client, mock_graph = client
396
+ mock_graph.ainvoke = AsyncMock(return_value={
397
  "raw_response": "Section 18 requires refund with interest.",
398
  "citations": [_make_citation("18")],
399
  "confidence_score": 0.88,
400
  "query_type": QueryType.CROSS_REFERENCE,
401
  "conflict_warnings": [],
402
  "amendment_notice": None,
403
+ })
404
 
405
  response = test_client.post(
406
  "/api/v1/query/section-context",
 
415
  assert response.status_code == 200
416
  body = response.json()
417
  assert body["session_id"] == "section-thread-1"
418
+ initial_state, config = mock_graph.ainvoke.call_args.args
419
  assert initial_state["skip_classifier"] is True
420
  assert initial_state["source_section_id"] == "18"
421
  assert initial_state["query_type"] == QueryType.CROSS_REFERENCE
tests/unit/stores/test_graph_store.py ADDED
@@ -0,0 +1,48 @@
1
+ from __future__ import annotations
2
+
3
+ from contextlib import asynccontextmanager
4
+ from unittest.mock import AsyncMock
5
+
6
+ import pytest
7
+
8
+
9
+ def _make_driver(session):
10
+ driver = AsyncMock()
11
+
12
+ @asynccontextmanager
13
+ async def _session_cm():
14
+ yield session
15
+
16
+ driver.session = _session_cm
17
+ return driver
18
+
19
+
20
+ @pytest.mark.asyncio
21
+ async def test_graph_topology_retries_after_transient_driver_error(monkeypatch):
22
+ from civicsetu.stores import graph_store
23
+
24
+ first_session = AsyncMock()
25
+ first_session.run.side_effect = RuntimeError("the connection is closed")
26
+
27
+ edges_result = AsyncMock()
28
+ edges_result.data.return_value = [{"source": "a", "target": "b", "edge_type": "REFERENCES"}]
29
+ nodes_result = AsyncMock()
30
+ nodes_result.data.return_value = [{"chunk_id": "a", "section_id": "18"}]
31
+
32
+ second_session = AsyncMock()
33
+ second_session.run.side_effect = [edges_result, nodes_result]
34
+
35
+ first_driver = _make_driver(first_session)
36
+ second_driver = _make_driver(second_session)
37
+ get_driver = AsyncMock(side_effect=[first_driver, second_driver])
38
+ reset_driver = AsyncMock()
39
+
40
+ monkeypatch.setattr(graph_store, "_get_driver", get_driver)
41
+ monkeypatch.setattr(graph_store, "_reset_driver", reset_driver)
42
+
43
+ nodes, edges = await graph_store.GraphStore.get_topology()
44
+
45
+ assert nodes == [{"chunk_id": "a", "section_id": "18"}]
46
+ assert edges == [{"source": "a", "target": "b", "edge_type": "REFERENCES"}]
47
+ assert get_driver.await_count == 2
48
+ reset_driver.assert_awaited_once()
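The new test file above pins down the recovery path this commit is named for: when the pooled Neo4j driver raises a transient "connection is closed" style error, GraphStore.get_topology is expected to reset the driver and retry once with a fresh one. A hedged sketch of that shape follows; only the get-driver/reset-driver hooks and the edges-then-nodes call order come from the test, while the query strings, helper names, and the substring check on the error message are assumptions.

```python
from typing import Any, Awaitable, Callable

async def _run_topology_queries(driver: Any) -> tuple[list[dict], list[dict]]:
    # Placeholder Cypher; only the edges-then-nodes order and the awaited
    # .data() calls mirror what the test's mocks expect.
    async with driver.session() as session:
        edges = await (await session.run("MATCH (a)-[r]->(b) RETURN a, r, b")).data()
        nodes = await (await session.run("MATCH (c:Chunk) RETURN c")).data()
    return nodes, edges

async def get_topology_with_retry(
    get_driver: Callable[[], Awaitable[Any]],
    reset_driver: Callable[[], Awaitable[None]],
) -> tuple[list[dict], list[dict]]:
    try:
        return await _run_topology_queries(await get_driver())
    except RuntimeError as exc:
        if "closed" not in str(exc).lower():
            raise                                   # not a transient connection error
        await reset_driver()                        # drop the stale pooled driver
        return await _run_topology_queries(await get_driver())
```

Under the test's monkeypatches, the first driver's session raises, the driver is reset exactly once, and the second driver serves both queries, which is the behaviour the two final assertions check.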