civicsetu / scripts /test_e2e_queries.py
adeshboudh16
feat(phase4): multi-state RERA expansion + pipeline hardening
2d7bf65
# scripts/test_e2e_queries.py
"""
End-to-end query tests across all 4 states.
Tests vector retrieval, graph retrieval, and multi-jurisdiction paths.
Run:
uv run python scripts/test_e2e_queries.py
"""
from __future__ import annotations
import sys, json, time
import io
from pathlib import Path
if sys.stdout.encoding != "utf-8":
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8")
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
import structlog
from civicsetu.agent.graph import get_compiled_graph
from civicsetu.models.enums import Jurisdiction
log = structlog.get_logger(__name__)
# ── Test cases ────────────────────────────────────────────────────────────────
# Format: (label, query, jurisdiction_filter, expected_query_type, must_contain_section)
TEST_CASES = [
# ── Fact lookup β€” vector path ─────────────────────────────────────────────
(
"FACT_CENTRAL_01",
"What are the obligations of a promoter under RERA?",
None,
"fact_lookup",
"4", # RERA Act Section 4 β€” promoter obligations
),
(
"FACT_MH_01",
"What documents must a promoter submit for project registration in Maharashtra?",
Jurisdiction.MAHARASHTRA,
"fact_lookup",
None,
),
(
"FACT_UP_01",
"What is the timeline for refund of amount paid by allottee in Uttar Pradesh?",
Jurisdiction.UTTAR_PRADESH,
"fact_lookup",
None,
),
(
"FACT_KA_01",
"What are the grounds for revocation of project registration in Karnataka?",
Jurisdiction.KARNATAKA,
"fact_lookup",
None,
),
(
"FACT_TN_01",
"What information must a real estate agent provide to buyers in Tamil Nadu?",
Jurisdiction.TAMIL_NADU,
"fact_lookup",
None,
),
# ── Cross-reference β€” graph path ──────────────────────────────────────────
(
"XREF_01",
"What does section 18 of RERA say about refund obligations?",
None,
"cross_reference",
"18",
),
(
"XREF_02",
"Which state rules implement section 9 of the RERA Act on agent registration?",
None,
"cross_reference",
"9",
),
(
"XREF_MH_01",
"Which MahaRERA rule derives from section 4 of the central RERA Act?",
Jurisdiction.MAHARASHTRA,
"cross_reference",
None,
),
# ── Penalty lookup β€” graph path ───────────────────────────────────────────
(
"PENALTY_01",
"What is the penalty for non-registration of a real estate project under RERA?",
None,
"penalty_lookup",
None,
),
(
"PENALTY_02",
"What are the penalties for a promoter who fails to register under RERA?",
None,
"penalty_lookup",
None,
),
# ── Multi-jurisdiction comparison ─────────────────────────────────────────
(
"MULTI_01",
"How does Karnataka handle extension of project registration compared to the central RERA Act?",
None,
"cross_reference",
None,
),
(
"MULTI_02",
"What is the rate of interest payable on refunds under Tamil Nadu RERA rules?",
Jurisdiction.TAMIL_NADU,
"fact_lookup",
None,
),
]
def run_test(graph, label, query, jurisdiction_filter, expected_type, must_contain):
start = time.time()
state = {
"query": query,
"jurisdiction_filter": jurisdiction_filter,
"top_k": 5,
"session_id": f"e2e_{label}",
"retrieved_chunks": [],
"reranked_chunks": [],
"citations": [],
"confidence_score": 0.0,
"conflict_warnings": [],
"amendment_notice": None,
"retry_count": 0,
"hallucination_flag": False,
"error": None,
}
try:
result = graph.invoke(state)
elapsed = time.time() - start
query_type = result.get("query_type", "UNKNOWN")
confidence = result.get("confidence_score", 0.0)
answer = result.get("raw_response", "") or ""
citations = result.get("citations", [])
error = result.get("error")
# Checks
type_ok = expected_type in str(query_type).lower() if expected_type else True
citation_ok = len(citations) > 0
section_ok = True
if must_contain:
section_ok = any(
must_contain == str(c.section_id) for c in citations
) or must_contain in answer
status = "PASS" if (citation_ok and section_ok and not error) else "FAIL"
return {
"label": label,
"status": status,
"query_type": str(query_type),
"type_match": type_ok,
"confidence": round(confidence, 3),
"citations": len(citations),
"section_check": f"{must_contain}={'OK' if section_ok else 'MISSING'}" if must_contain else "N/A",
"elapsed_s": round(elapsed, 2),
"error": error,
"answer_preview": answer[:120].replace("\n", " ") if answer else "",
}
except Exception as e:
return {
"label": label,
"status": "ERROR",
"query_type": "N/A",
"type_match": False,
"confidence": 0.0,
"citations": 0,
"section_check": "N/A",
"elapsed_s": round(time.time() - start, 2),
"error": str(e),
"answer_preview": "",
}
def main():
print("Compiling LangGraph...")
graph = get_compiled_graph()
print(f"Running {len(TEST_CASES)} test cases...\n")
results = []
for label, query, jfilter, etype, must_section in TEST_CASES:
print(f" [{label}] ...", end=" ", flush=True)
r = run_test(graph, label, query, jfilter, etype, must_section)
results.append(r)
status_icon = "PASS" if r["status"] == "PASS" else "FAIL"
print(f"{status_icon} ({r['elapsed_s']}s, conf={r['confidence']}, cit={r['citations']})")
# ── Summary ───────────────────────────────────────────────────────────────
passed = sum(1 for r in results if r["status"] == "PASS")
failed = sum(1 for r in results if r["status"] == "FAIL")
errors = sum(1 for r in results if r["status"] == "ERROR")
avg_conf = sum(r["confidence"] for r in results) / len(results)
avg_time = sum(r["elapsed_s"] for r in results) / len(results)
print("\n" + "=" * 70)
print(f"E2E Test Results: {passed} PASS {failed} FAIL {errors} ERROR")
print(f"Avg confidence : {avg_conf:.3f}")
print(f"Avg latency : {avg_time:.2f}s")
print("=" * 70)
# Print failures detail
for r in results:
if r["status"] != "PASS":
print(f"\n{r['status']}: {r['label']}")
print(f" query_type : {r['query_type']}")
print(f" section_check: {r['section_check']}")
print(f" citations : {r['citations']}")
print(f" error : {r['error']}")
print(f" answer : {r['answer_preview']}")
# Save full results
out = Path("e2e_results.json")
out.write_text(json.dumps(results, indent=2, default=str))
print(f"\nFull results β†’ {out}")
if failed + errors > 0:
sys.exit(1)
if __name__ == "__main__":
main()