adeshboudh16 committed
Commit a84bdca · 1 Parent(s): 75fe6a6

feat: new run_eval.py — single-pass no-phases, 3 batch_score calls total via osmapi

Files changed (1)
  1. scripts/run_eval.py  +175 -303
scripts/run_eval.py CHANGED
@@ -1,47 +1,24 @@
- # scripts/run_eval.py
  """
- RAGAS offline benchmark for CivicSetu.
-
- Two independent phases so you never re-invoke the graph just to re-score:
-
- Phase 1 — invoke graph for all queries, save raw results:
-     make eval-collect
-     # or: uv run python scripts/run_eval.py --phase 1
-
- Phase 2 — RAGAS scoring on saved results (reads eval_phase1_results.json):
-     make eval-score
-     # or: uv run python scripts/run_eval.py --phase 2
-
- Both phases in sequence:
-     make eval
-     # or: uv run python scripts/run_eval.py
-
- Env-var tuning:
-     BATCH_SIZE=3 BATCH_DELAY_SEC=60 EVAL_LIMIT=3 make eval-score
-
- Judge provider (Phase 2):
-     # Gemini free tier (15 RPM) — default; sleeps BATCH_DELAY_SEC between each metric
-     JUDGE_PROVIDER=gemini BATCH_SIZE=1 BATCH_DELAY_SEC=60 make eval-score
-
-     # Dual Gemini keys — parallel 2-worker mode (~30 RPM effective, default delay=30s)
-     # Set GEMINI_API_KEY_2 + GEMINI_API_KEY_3 in .env; script auto-detects second key.
-     BATCH_SIZE=2 make eval-score
-
-     # OpenRouter free tier — more generous RPM; set OPENROUTER_API_KEY in .env
-     JUDGE_PROVIDER=openrouter JUDGE_MODEL=stepfun/step-3.5-flash:free make eval-score
-     JUDGE_PROVIDER=openrouter make eval-score  # uses stepfun/step-3.5-flash:free by default
  """
  from __future__ import annotations

- import argparse
- # import asyncio  # not needed — sequential mode
  import json
  import math
  import os
  import sys
  import time
  import io
- # from concurrent.futures import ThreadPoolExecutor, as_completed  # not needed — sequential mode
  from datetime import datetime, timezone
  from pathlib import Path
@@ -51,146 +28,69 @@ if sys.stdout.encoding != "utf-8":

  sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

- # ── Constants ──────────────────────────────────────────────────────────────────
- # BATCH_SIZE = int(os.getenv("BATCH_SIZE", "2"))  # commented out — sequential mode
- # BATCH_DELAY_SEC = int(os.getenv("BATCH_DELAY_SEC", "60"))  # commented out — no rate-limit sleep
- PASS_THRESHOLD = float(os.getenv("PASS_THRESHOLD", "0.7"))
- EVAL_LIMIT = int(os.getenv("EVAL_LIMIT", "0")) or None  # 0 = no limit
-
- # Judge provider: osmapi (OSMAPI_API_KEY) — active
- # Embeddings: google-genai (GEMINI_API_KEY_2) still needed for AnswerRelevancy
- # Previous providers (commented out in build_judge): gemini, openrouter
- JUDGE_PROVIDER = os.getenv("JUDGE_PROVIDER", "osmapi")
- JUDGE_MODEL = os.getenv("JUDGE_MODEL", "qwen3.5-122b-a10b")
-
- ROOT = Path(__file__).parent.parent
- DATASET_PATH = ROOT / "eval" / "golden_dataset.jsonl"
- PHASE1_OUT = ROOT / "eval_phase1_results.json"
- PHASE2_OUT = ROOT / "eval_results.json"
-
-
- # ── Logging helper ─────────────────────────────────────────────────────────────
-
- def _log(msg: str, label: str = "") -> None:
-     ts = datetime.now().strftime("%H:%M:%S")
-     prefix = f"[{ts}]" + (f" [{label}]" if label else "")
-     print(f"{prefix} {msg}", flush=True)
-
-
- def _sleep_log(seconds: int, label: str = "") -> None:
-     resume = datetime.fromtimestamp(time.time() + seconds).strftime("%H:%M:%S")
-     _log(f"sleeping {seconds}s — resuming at {resume}", label)
-     time.sleep(seconds)


- # ── RAGAS judge (RAGAS 0.4.x native API via instructor) ───────────────────────

- def build_judge(gemini_key: str):
      """
-     Build RAGAS 0.4.x judge LLM + embeddings.
-
-     LLM: osmapi (OSMAPI_API_KEY) — OpenAI-compatible endpoint, free credits.
-     Embeddings: Google GenAI (GEMINI_API_KEY_2) — AnswerRelevancy needs semantic similarity.
-
-     # ── Previously used providers (commented out) ──────────────────────────
-     # JUDGE_PROVIDER=gemini:
-     #     llm_client = AsyncOpenAI(
-     #         api_key=gemini_key,
-     #         base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
-     #         timeout=120.0,
-     #     )
-     #     print(f" Judge: Gemini / {JUDGE_MODEL}")
-     #
-     # JUDGE_PROVIDER=openrouter:
-     #     openrouter_key = os.getenv("OPENROUTER_API_KEY")
-     #     llm_client = AsyncOpenAI(
-     #         api_key=openrouter_key,
-     #         base_url="https://openrouter.ai/api/v1",
-     #         timeout=120.0,
-     #     )
-     #     print(f" Judge: OpenRouter / {JUDGE_MODEL}")
-     # ───────────────────────────────────────────────────────────────────────
      """
-     from google import genai
      from openai import AsyncOpenAI
      from ragas.llms import llm_factory
      from ragas.embeddings import GoogleEmbeddings

      osmapi_key = os.getenv("OSMAPI_API_KEY")
      if not osmapi_key:
-         print(
-             "ERROR: OSMAPI_API_KEY is not set.\n"
-             "Add it to your .env: OSMAPI_API_KEY=<your-key>",
-             file=sys.stderr,
-         )
          sys.exit(1)

      llm_client = AsyncOpenAI(
          api_key=osmapi_key,
          base_url="https://api.osmapi.com/v1",
-         timeout=120.0,  # 2-min cap per call — prevents infinite hangs
      )
-     print(f" Judge: osmapi / {JUDGE_MODEL}")
-
-     # max_tokens=8192: RERA answers produce many NLI statements; the RAGAS
-     # default of 1024 causes IncompleteOutputException on complex legal answers.
-     judge_llm = llm_factory(JUDGE_MODEL, client=llm_client, max_tokens=8192)
-
-     # Embeddings: google-genai (AnswerRelevancy needs semantic similarity;
-     # osmapi has no embedding endpoint).
-     genai_client = genai.Client(api_key=gemini_key)
-     judge_embeddings = GoogleEmbeddings(client=genai_client, model="gemini-embedding-001")

      return judge_llm, judge_embeddings


- def build_judge_pool() -> list[tuple]:
-     """
-     Build a single judge (llm, embeddings) pair using GEMINI_API_KEY_2.
-     Always single-worker sequential mode to avoid auth conflicts.
-     """
-     from dotenv import load_dotenv
-     load_dotenv()
-
-     key1 = os.getenv("GEMINI_API_KEY_2")
-     if not key1:
-         print(
-             "ERROR: GEMINI_API_KEY_2 is not set (required for embeddings).\n"
-             "Add it to your .env: GEMINI_API_KEY_2=<your-gemini-key>",
-             file=sys.stderr,
-         )
-         sys.exit(1)
-
-     return [build_judge(key1)]
-
-
- # def score_batch_in_thread(batch, judge_llm, judge_embeddings, label=""):
- #     """Commented out — was used for parallel dual-worker mode."""
- #     ids = [r["id"] for r in batch]
- #     _log(f"starting {ids}", label)
- #     loop = asyncio.new_event_loop()
- #     asyncio.set_event_loop(loop)
- #     try:
- #         return score_batch(batch, judge_llm, judge_embeddings, label=label)
- #     finally:
- #         loop.close()
-
-
- # ── Dataset helpers ────────────────────────────────────────────────────────────
-
- def load_dataset(path: Path) -> list[dict]:
-     rows = []
-     for line in path.read_text(encoding="utf-8").splitlines():
-         line = line.strip()
-         if line:
-             rows.append(json.loads(line))
-     return rows
-
-
- # ── Phase 1: graph invocation ──────────────────────────────────────────────────

  def invoke_graph(graph, row: dict) -> dict:
-     """Run one golden row through the graph and return enriched result dict."""
      from civicsetu.models.enums import Jurisdiction

      jurisdiction = None
@@ -198,7 +98,7 @@ def invoke_graph(graph, row: dict) -> dict:
      try:
          jurisdiction = Jurisdiction(row["jurisdiction"])
      except ValueError:
-         jurisdiction = None

      state = {
          "query": row["query"],
@@ -219,7 +119,7 @@ def invoke_graph(graph, row: dict) -> dict:

      start = time.perf_counter()
      try:
-         result = graph.invoke(state)
          latency_ms = (time.perf_counter() - start) * 1000
          answer = result.get("raw_response") or ""
          reranked = result.get("reranked_chunks") or []
@@ -234,42 +134,24 @@ def invoke_graph(graph, row: dict) -> dict:
          confidence, query_type, error = 0.0, "error", str(exc)

      return {
-         "id": row["id"],
-         "jurisdiction": row["jurisdiction"],
-         "query_type": row["query_type"],
-         "query": row["query"],
-         "ground_truth": row["ground_truth"],
-         "answer": answer,
-         "contexts": contexts,
-         "citations_count": len(citations),
-         "confidence_score": round(confidence, 3),
          "query_type_resolved": query_type,
-         "latency_ms": round(latency_ms, 1),
-         "error": error,
      }


- def run_phase1(rows: list[dict]) -> list[dict]:
-     from civicsetu.agent.graph import get_compiled_graph
-
-     print(f"\nPhase 1: Invoking graph for {len(rows)} queries...")
-     graph = get_compiled_graph()
-     invoked: list[dict] = []
-     for i, row in enumerate(rows, 1):
-         print(f" [{i:02}/{len(rows)}] {row['id']} ...", end=" ", flush=True)
-         result = invoke_graph(graph, row)
-         invoked.append(result)
-         status = "OK" if result["answer"] else "EMPTY"
-         print(f"{status} ({result['latency_ms']:.0f}ms, conf={result['confidence_score']})")
-
-     PHASE1_OUT.write_text(json.dumps(invoked, indent=2, default=str), encoding="utf-8")
-     print(f"\nPhase 1 complete — results saved to {PHASE1_OUT}")
-     return invoked


- # ── Phase 2: RAGAS scoring ─────────────────────────────────────────────────────
-
- def _safe_metric(val, default: float = 0.0) -> float:
      try:
          f = float(val)
          return default if math.isnan(f) else f
@@ -277,54 +159,79 @@ def _safe_metric(val, default: float = 0.0) -> float:
          return default


- def score_row(row: dict, judge_llm, judge_embeddings) -> dict:
-     """Score a single row with all three RAGAS metrics. Simple sequential API calls."""
      from ragas.metrics.collections import Faithfulness, AnswerRelevancy, ContextPrecision

-     row = dict(row)
-
-     if not row["answer"] or not row["contexts"]:
-         row["faithfulness"] = row["answer_relevancy"] = row["context_precision"] = 0.0
-         row["pass"] = False
-         return row

      f_metric = Faithfulness(llm=judge_llm)
      ar_metric = AnswerRelevancy(llm=judge_llm, embeddings=judge_embeddings)
      cp_metric = ContextPrecision(llm=judge_llm)

-     f_results = f_metric.batch_score([
-         {"user_input": row["query"], "response": row["answer"], "retrieved_contexts": row["contexts"]}
      ])
      ar_results = ar_metric.batch_score([
-         {"user_input": row["query"], "response": row["answer"]}
      ])
      cp_results = cp_metric.batch_score([
-         {"user_input": row["query"], "reference": row["ground_truth"], "retrieved_contexts": row["contexts"]}
      ])

-     row["faithfulness"] = round(_safe_metric(f_results[0].value), 3)
-     row["answer_relevancy"] = round(_safe_metric(ar_results[0].value), 3)
-     row["context_precision"] = round(_safe_metric(cp_results[0].value), 3)
-     row["pass"] = (
-         row["faithfulness"] >= PASS_THRESHOLD
-         and row["answer_relevancy"] >= PASS_THRESHOLD
-         and row["context_precision"] >= PASS_THRESHOLD
-     )
-     return row


- # def score_batch(batch, judge_llm, judge_embeddings, label=""):
- #     """Commented out — was the batched+sleep scoring path for Gemini free-tier rate limits."""
- #     ... (see git history)


- def compute_group_stats(rows: list[dict], key: str) -> dict:
      if not rows:
          return {}
-     latencies = sorted(r["latency_ms"] for r in rows)
-     n = len(latencies)
-     p50 = (latencies[n // 2 - 1] + latencies[n // 2]) / 2.0 if n % 2 == 0 else latencies[n // 2]
-     p90 = latencies[min(int(n * 0.9), n - 1)]
-     p99 = latencies[min(int(n * 0.99), n - 1)]
      return {
          "faithfulness": round(sum(r["faithfulness"] for r in rows) / n, 3),
          "answer_relevancy": round(sum(r["answer_relevancy"] for r in rows) / n, 3),
@@ -332,15 +239,14 @@ def compute_group_stats(rows: list[dict], key: str) -> dict:
          "pass_rate": round(sum(1 for r in rows if r["pass"]) / n, 3),
          "p50_latency_ms": round(p50, 1),
          "p90_latency_ms": round(p90, 1),
-         "p99_latency_ms": round(p99, 1),
      }


- def print_summary(all_rows: list[dict]) -> None:
-     overall = compute_group_stats(all_rows, "overall")
-     passed = sum(1 for r in all_rows if r["pass"])
      print("\n" + "=" * 72)
-     print(f"RAGAS Evaluation Results ({len(all_rows)} queries, {passed} pass)")
      print("=" * 72)
      print(f" faithfulness : {overall['faithfulness']:.3f}")
      print(f" answer_relevancy : {overall['answer_relevancy']:.3f}")
@@ -349,85 +255,30 @@ def print_summary(all_rows: list[dict]) -> None:
      print(f" p50 latency : {overall['p50_latency_ms']:.0f} ms")
      print(f" p90 latency : {overall['p90_latency_ms']:.0f} ms")
      print()
-     for jur in sorted({r["jurisdiction"] or "MULTI" for r in all_rows}):
-         rows = [r for r in all_rows if (r["jurisdiction"] or "MULTI") == jur]
-         stats = compute_group_stats(rows, jur)
-         print(f" {jur:<20} faith={stats['faithfulness']:.2f} "
-               f"rel={stats['answer_relevancy']:.2f} "
-               f"prec={stats['context_precision']:.2f} "
-               f"pass={stats['pass_rate']:.0%} p50={stats['p50_latency_ms']:.0f}ms")
-     failures = [r for r in all_rows if not r["pass"]]
      if failures:
          print(f"\n Failures ({len(failures)}):")
          for r in failures:
              print(f" FAIL [{r['id']}] "
-                   f"faith={r['faithfulness']:.2f} rel={r['answer_relevancy']:.2f} "
-                   f"prec={r['context_precision']:.2f} err={r.get('error') or '-'}")
      print("=" * 72)


- def run_phase2(invoked: list[dict]) -> list[dict]:
-     print(f"\nPhase 2: RAGAS scoring {len(invoked)} rows (sequential)...")
-     judge_llm, judge_embeddings = build_judge_pool()[0]
-     all_scored: list[dict] = []
-
-     for i, row in enumerate(invoked, 1):
-         _log(f"[{i:02}/{len(invoked)}] {row['id']}")
-         try:
-             scored = score_row(row, judge_llm, judge_embeddings)
-         except Exception as exc:
-             _log(f" FAILED ({type(exc).__name__}: {exc}) — skipping with zeros")
-             scored = dict(row)
-             scored["faithfulness"] = scored["answer_relevancy"] = scored["context_precision"] = 0.0
-             scored["pass"] = False
-         all_scored.append(scored)
-         _log(f" faith={scored['faithfulness']:.2f} rel={scored['answer_relevancy']:.2f} prec={scored['context_precision']:.2f} {'PASS' if scored['pass'] else 'fail'}")
-
-     # ── Commented out: batched + parallel + rate-limit-sleep mode ──────────────
-     # batches = [invoked[i:i + BATCH_SIZE] for i in range(0, len(invoked), BATCH_SIZE)]
-     # if num_workers == 1:
-     #     for batch_num, batch in enumerate(batches, 1):
-     #         scored = score_batch(batch, judge_llm, judge_embeddings, label=f"B{batch_num}")
-     #         all_scored.extend(scored)
-     #         _sleep_log(BATCH_DELAY_SEC)
-     # else:  # ThreadPoolExecutor dual-worker path — see git history
-     # ───────────────────────────────────────────────────────────────────────────
-
-     print_summary(all_scored)
-
-     jurisdictions = sorted({r["jurisdiction"] or "MULTI" for r in all_scored})
-     query_types = sorted({r["query_type"] for r in all_scored})
-     report = {
-         "run_at": datetime.now(timezone.utc).isoformat(),
-         "dataset_size": len(all_scored),
-         "mode": "sequential",
-         "pass_threshold": PASS_THRESHOLD,
-         "overall": compute_group_stats(all_scored, "overall"),
-         "by_jurisdiction": {
-             jur: compute_group_stats([r for r in all_scored if (r["jurisdiction"] or "MULTI") == jur], jur)
-             for jur in jurisdictions
-         },
-         "by_query_type": {
-             qt: compute_group_stats([r for r in all_scored if r["query_type"] == qt], qt)
-             for qt in query_types
-         },
-         "rows": all_scored,
-     }
-     PHASE2_OUT.write_text(json.dumps(report, indent=2, default=str), encoding="utf-8")
-     print(f"\nFull results → {PHASE2_OUT}")
-     return all_scored


  # ── Entry point ────────────────────────────────────────────────────────────────

  def main() -> None:
-     parser = argparse.ArgumentParser(description="CivicSetu RAGAS benchmark")
-     parser.add_argument(
-         "--phase", type=int, choices=[1, 2],
-         help="1 = collect (graph invocations only), 2 = score (RAGAS only). Default: both."
-     )
-     args = parser.parse_args()
-
      if not DATASET_PATH.exists():
          print(f"ERROR: dataset not found at {DATASET_PATH}", file=sys.stderr)
          sys.exit(1)
@@ -436,28 +287,49 @@ def main() -> None:
      if EVAL_LIMIT:
          rows = rows[:EVAL_LIMIT]

-     print(f"CivicSetu RAGAS Eval — {len(rows)} queries | sequential threshold={PASS_THRESHOLD}")
-
-     if args.phase == 1:
-         run_phase1(rows)
-
-     elif args.phase == 2:
-         if not PHASE1_OUT.exists():
-             print(f"ERROR: {PHASE1_OUT} not found. Run Phase 1 first: make eval-collect", file=sys.stderr)
-             sys.exit(1)
-         invoked = json.loads(PHASE1_OUT.read_text(encoding="utf-8"))
-         if EVAL_LIMIT:
-             invoked = invoked[:EVAL_LIMIT]
-         all_scored = run_phase2(invoked)
-         if sum(1 for r in all_scored if not r["pass"]):
-             sys.exit(1)
-
-     else:
-         # Both phases
-         invoked = run_phase1(rows)
-         all_scored = run_phase2(invoked)
-         if sum(1 for r in all_scored if not r["pass"]):
-             sys.exit(1)


  if __name__ == "__main__":
  """
+ CivicSetu RAGAS evaluation: single pass, no phases.
+
+ 1. Load golden dataset
+ 2. Invoke RAG graph for every query (collect answers + contexts)
+ 3. Score all rows at once with RAGAS (3 batch_score calls total)
+ 4. Print summary + save eval_results.json
+
+ Usage:
+     uv run python scripts/run_eval.py
+     EVAL_LIMIT=5 uv run python scripts/run_eval.py  # quick smoke-test
+     JUDGE_MODEL=qwen3.5-122b-a10b uv run python scripts/run_eval.py
  """
  from __future__ import annotations

  import json
  import math
  import os
  import sys
  import time
  import io
  from datetime import datetime, timezone
  from pathlib import Path

  sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

+ ROOT = Path(__file__).parent.parent
+ DATASET_PATH = ROOT / "eval" / "golden_dataset.jsonl"
+ OUTPUT_PATH = ROOT / "eval_results.json"
+
+ PASS_THRESHOLD = float(os.getenv("PASS_THRESHOLD", "0.7"))
+ EVAL_LIMIT = int(os.getenv("EVAL_LIMIT", "0")) or None
+ JUDGE_MODEL = os.getenv("JUDGE_MODEL", "qwen3.5-122b-a10b")
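A small aside on the EVAL_LIMIT expression above: because the parsed integer is chained with "or None", an unset variable and an explicit EVAL_LIMIT=0 both mean "score everything". A two-line check of that behaviour, standalone and not part of the diff:

    # the EVAL_LIMIT expression above maps 0 / unset to None ("no limit")
    print(int("0") or None)   # None  (EVAL_LIMIT unset or 0)
    print(int("5") or None)   # 5     (EVAL_LIMIT=5)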
 
 

+ # ── Dataset ────────────────────────────────────────────────────────────────────

+ def load_dataset(path: Path) -> list[dict]:
+     return [
+         json.loads(line)
+         for line in path.read_text(encoding="utf-8").splitlines()
+         if line.strip()
+     ]
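load_dataset above expects eval/golden_dataset.jsonl with one JSON object per line. The keys the rest of the script reads are id, jurisdiction, query_type, query and ground_truth; the row below is purely illustrative and its values are invented:

    import json

    # hypothetical golden-dataset row; only the key names come from the script
    row = {
        "id": "rera-001",
        "jurisdiction": "maharashtra",   # should map to a civicsetu Jurisdiction value, otherwise it is ignored
        "query_type": "procedure",
        "query": "How do I file a complaint against a builder for delayed possession?",
        "ground_truth": "A complaint can be filed with the state RERA authority against a registered project.",
    }
    print(json.dumps(row, ensure_ascii=False))   # one line of golden_dataset.jsonl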
 
 
 
 
 
 


+ # ── Judge setup ────────────────────────────────────────────────────────────────

+ def build_judge():
      """
+     LLM  : osmapi (OSMAPI_API_KEY), OpenAI-compatible endpoint.
+     Embed: Google GenAI (GEMINI_API_KEY_2) — needed for AnswerRelevancy.
      """
+     from dotenv import load_dotenv
+     load_dotenv()
+
      from openai import AsyncOpenAI
      from ragas.llms import llm_factory
      from ragas.embeddings import GoogleEmbeddings
+     from google import genai

      osmapi_key = os.getenv("OSMAPI_API_KEY")
      if not osmapi_key:
+         print("ERROR: OSMAPI_API_KEY not set in .env", file=sys.stderr)
+         sys.exit(1)
+
+     gemini_key = os.getenv("GEMINI_API_KEY_2")
+     if not gemini_key:
+         print("ERROR: GEMINI_API_KEY_2 not set in .env (needed for embeddings)", file=sys.stderr)
          sys.exit(1)

      llm_client = AsyncOpenAI(
          api_key=osmapi_key,
          base_url="https://api.osmapi.com/v1",
+         timeout=120.0,
+     )
+     judge_llm = llm_factory(JUDGE_MODEL, client=llm_client, max_tokens=8192)
+     judge_embeddings = GoogleEmbeddings(
+         client=genai.Client(api_key=gemini_key),
+         model="gemini-embedding-001",
      )

+     print(f" Judge LLM  : osmapi / {JUDGE_MODEL}")
+     print(f" Embeddings : Google gemini-embedding-001")
      return judge_llm, judge_embeddings
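build_judge() exits unless both OSMAPI_API_KEY (judge LLM) and GEMINI_API_KEY_2 (embeddings for AnswerRelevancy) are present. A minimal pre-flight check in the same style, assuming python-dotenv is already installed as the function requires:

    # quick .env sanity check before a long eval run (standalone sketch)
    import os
    from dotenv import load_dotenv

    load_dotenv()
    missing = [k for k in ("OSMAPI_API_KEY", "GEMINI_API_KEY_2") if not os.getenv(k)]
    print("missing keys:", ", ".join(missing) if missing else "none")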


+ # ── Graph invocation ───────────────────────────────────────────────────────────

  def invoke_graph(graph, row: dict) -> dict:
      from civicsetu.models.enums import Jurisdiction

      jurisdiction = None
      try:
          jurisdiction = Jurisdiction(row["jurisdiction"])
      except ValueError:
+         pass

      state = {
          "query": row["query"],

      start = time.perf_counter()
      try:
+         result = graph.invoke(state)
          latency_ms = (time.perf_counter() - start) * 1000
          answer = result.get("raw_response") or ""
          reranked = result.get("reranked_chunks") or []

          confidence, query_type, error = 0.0, "error", str(exc)

      return {
+         "id": row["id"],
+         "jurisdiction": row["jurisdiction"],
+         "query_type": row["query_type"],
+         "query": row["query"],
+         "ground_truth": row["ground_truth"],
+         "answer": answer,
+         "contexts": contexts,
+         "citations_count": len(citations),
+         "confidence_score": round(confidence, 3),
          "query_type_resolved": query_type,
+         "latency_ms": round(latency_ms, 1),
+         "error": error,
      }


+ # ── RAGAS scoring ──────────────────────────────────────────────────────────────

+ def _safe(val, default: float = 0.0) -> float:
      try:
          f = float(val)
          return default if math.isnan(f) else f
          return default


+ def score_all(rows: list[dict], judge_llm, judge_embeddings) -> list[dict]:
+     """Score all rows at once: 3 batch_score calls total (one per metric)."""
      from ragas.metrics.collections import Faithfulness, AnswerRelevancy, ContextPrecision

+     scoreable = [r for r in rows if r["answer"] and r["contexts"]]
+     skipped = [r for r in rows if not (r["answer"] and r["contexts"])]

+     if skipped:
+         print(f" Skipping {len(skipped)} rows (no answer/context): "
+               f"{[r['id'] for r in skipped]}")

      f_metric = Faithfulness(llm=judge_llm)
      ar_metric = AnswerRelevancy(llm=judge_llm, embeddings=judge_embeddings)
      cp_metric = ContextPrecision(llm=judge_llm)

+     print(f" faithfulness ({len(scoreable)} rows) ...", end=" ", flush=True)
+     t0 = time.perf_counter()
+     f_results = f_metric.batch_score([
+         {"user_input": r["query"], "response": r["answer"], "retrieved_contexts": r["contexts"]}
+         for r in scoreable
      ])
+     print(f"done ({time.perf_counter() - t0:.1f}s)")
+
+     print(f" answer_relevancy ...", end=" ", flush=True)
+     t0 = time.perf_counter()
      ar_results = ar_metric.batch_score([
+         {"user_input": r["query"], "response": r["answer"]}
+         for r in scoreable
      ])
+     print(f"done ({time.perf_counter() - t0:.1f}s)")
+
+     print(f" context_precision ...", end=" ", flush=True)
+     t0 = time.perf_counter()
      cp_results = cp_metric.batch_score([
+         {"user_input": r["query"], "reference": r["ground_truth"], "retrieved_contexts": r["contexts"]}
+         for r in scoreable
      ])
+     print(f"done ({time.perf_counter() - t0:.1f}s)")
+
+     scored_map: dict[str, dict] = {}
+     for row, f_r, ar_r, cp_r in zip(scoreable, f_results, ar_results, cp_results):
+         row = dict(row)
+         row["faithfulness"] = round(_safe(f_r.value), 3)
+         row["answer_relevancy"] = round(_safe(ar_r.value), 3)
+         row["context_precision"] = round(_safe(cp_r.value), 3)
+         row["pass"] = (
+             row["faithfulness"] >= PASS_THRESHOLD
+             and row["answer_relevancy"] >= PASS_THRESHOLD
+             and row["context_precision"] >= PASS_THRESHOLD
+         )
+         scored_map[row["id"]] = row

+     result = []
+     for row in rows:
+         if row["id"] not in scored_map:
+             row = dict(row)
+             row["faithfulness"] = row["answer_relevancy"] = row["context_precision"] = 0.0
+             row["pass"] = False
+         else:
+             row = scored_map[row["id"]]
+         result.append(row)
+     return result
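The pass flag computed in score_all requires all three metrics to clear PASS_THRESHOLD individually, not on average. A tiny worked example with made-up scores and the script's default threshold of 0.7:

    PASS_THRESHOLD = 0.7
    scores = {"faithfulness": 0.82, "answer_relevancy": 0.75, "context_precision": 0.64}   # invented values
    passed = all(scores[m] >= PASS_THRESHOLD for m in scores)
    print(passed)   # False: context_precision alone drags the row below the bar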


+ # ── Summary ────────────────────────────────────────────────────────────────────

+ def _group_stats(rows: list[dict]) -> dict:
      if not rows:
          return {}
+     n = len(rows)
+     lat = sorted(r["latency_ms"] for r in rows)
+     p50 = (lat[n // 2 - 1] + lat[n // 2]) / 2.0 if n % 2 == 0 else lat[n // 2]
+     p90 = lat[min(int(n * 0.9), n - 1)]
      return {
          "faithfulness": round(sum(r["faithfulness"] for r in rows) / n, 3),
          "answer_relevancy": round(sum(r["answer_relevancy"] for r in rows) / n, 3),
          "pass_rate": round(sum(1 for r in rows if r["pass"]) / n, 3),
          "p50_latency_ms": round(p50, 1),
          "p90_latency_ms": round(p90, 1),
      }
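_group_stats picks p50/p90 by index on the sorted latency list rather than interpolating. The same arithmetic applied to five invented latencies, as a standalone check:

    # same index-based percentile picks as _group_stats, on made-up latencies (ms)
    lat = sorted([120.0, 340.0, 95.0, 410.0, 220.0])
    n = len(lat)                                   # 5, odd, so p50 is the middle element
    p50 = (lat[n // 2 - 1] + lat[n // 2]) / 2.0 if n % 2 == 0 else lat[n // 2]
    p90 = lat[min(int(n * 0.9), n - 1)]            # index 4, the largest value here
    print(p50, p90)                                # 220.0 410.0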


+ def print_summary(rows: list[dict]) -> None:
+     overall = _group_stats(rows)
+     passed = sum(1 for r in rows if r["pass"])
      print("\n" + "=" * 72)
+     print(f"RAGAS Results ({len(rows)} queries, {passed} pass @ threshold={PASS_THRESHOLD})")
      print("=" * 72)
      print(f" faithfulness : {overall['faithfulness']:.3f}")
      print(f" answer_relevancy : {overall['answer_relevancy']:.3f}")

      print(f" p50 latency : {overall['p50_latency_ms']:.0f} ms")
      print(f" p90 latency : {overall['p90_latency_ms']:.0f} ms")
      print()
+
+     for jur in sorted({r["jurisdiction"] or "MULTI" for r in rows}):
+         jrows = [r for r in rows if (r["jurisdiction"] or "MULTI") == jur]
+         s = _group_stats(jrows)
+         print(f" {jur:<20} faith={s['faithfulness']:.2f} "
+               f"rel={s['answer_relevancy']:.2f} "
+               f"prec={s['context_precision']:.2f} "
+               f"pass={s['pass_rate']:.0%}")
+
+     failures = [r for r in rows if not r["pass"]]
      if failures:
          print(f"\n Failures ({len(failures)}):")
          for r in failures:
              print(f" FAIL [{r['id']}] "
+                   f"faith={r['faithfulness']:.2f} "
+                   f"rel={r['answer_relevancy']:.2f} "
+                   f"prec={r['context_precision']:.2f} "
+                   f"err={r.get('error') or '-'}")
      print("=" * 72)


  # ── Entry point ────────────────────────────────────────────────────────────────

  def main() -> None:
      if not DATASET_PATH.exists():
          print(f"ERROR: dataset not found at {DATASET_PATH}", file=sys.stderr)
          sys.exit(1)

      if EVAL_LIMIT:
          rows = rows[:EVAL_LIMIT]

+     print(f"CivicSetu RAGAS Eval — {len(rows)} queries | model={JUDGE_MODEL} | threshold={PASS_THRESHOLD}")
+
+     # ── Step 1: collect ────────────────────────────────────────────────────────
+     from civicsetu.agent.graph import get_compiled_graph
+     graph = get_compiled_graph()
+
+     print(f"\nStep 1/2 — invoking graph ({len(rows)} queries)...")
+     invoked: list[dict] = []
+     for i, row in enumerate(rows, 1):
+         print(f" [{i:02}/{len(rows)}] {row['id']} ...", end=" ", flush=True)
+         result = invoke_graph(graph, row)
+         invoked.append(result)
+         status = "OK" if result["answer"] else "EMPTY"
+         print(f"{status} ({result['latency_ms']:.0f}ms conf={result['confidence_score']})")
+
+     # ── Step 2: score ──────────────────────────────────────────────────────────
+     print(f"\nStep 2/2 — RAGAS scoring (3 batch calls)...")
+     judge_llm, judge_embeddings = build_judge()
+     scored = score_all(invoked, judge_llm, judge_embeddings)
+
+     # ── Save + print ───────────────────────────────────────────────────────────
+     print_summary(scored)
+
+     jurisdictions = sorted({r["jurisdiction"] or "MULTI" for r in scored})
+     query_types = sorted({r["query_type"] for r in scored})
+     report = {
+         "run_at": datetime.now(timezone.utc).isoformat(),
+         "dataset_size": len(scored),
+         "judge_model": JUDGE_MODEL,
+         "pass_threshold": PASS_THRESHOLD,
+         "overall": _group_stats(scored),
+         "by_jurisdiction": {
+             jur: _group_stats([r for r in scored if (r["jurisdiction"] or "MULTI") == jur])
+             for jur in jurisdictions
+         },
+         "by_query_type": {
+             qt: _group_stats([r for r in scored if r["query_type"] == qt])
+             for qt in query_types
+         },
+         "rows": scored,
+     }
+     OUTPUT_PATH.write_text(json.dumps(report, indent=2, default=str), encoding="utf-8")
+     print(f"\nFull results → {OUTPUT_PATH}")


  if __name__ == "__main__":