adeshboudh16 committed
Commit 14e5cbf · Parent(s): 2139758

eval: add RAGAS benchmark runner script

Files changed (1): scripts/run_eval.py (+339, -0)
scripts/run_eval.py ADDED
@@ -0,0 +1,339 @@
+# scripts/run_eval.py
+"""
+RAGAS offline benchmark for CivicSetu.
+
+Calls graph.invoke() directly (no HTTP server required) to capture
+reranked_chunks for context fields, then scores with RAGAS in batches
+to stay within free-tier API rate limits.
+
+Run:
+    uv run python scripts/run_eval.py
+
+Tune rate limits:
+    BATCH_SIZE=3 BATCH_DELAY_SEC=60 uv run python scripts/run_eval.py
+
+Limit rows for smoke testing:
+    EVAL_LIMIT=3 BATCH_SIZE=3 BATCH_DELAY_SEC=5 uv run python scripts/run_eval.py
+"""
+from __future__ import annotations
+
+import io
+import json
+import os
+import sys
+import time
+from datetime import datetime, timezone
+from pathlib import Path
+
+# Reopen stdout/stderr as UTF-8 so non-ASCII glyphs print cleanly (e.g. on Windows
+# consoles that default to a legacy code page).
+if sys.stdout.encoding != "utf-8":
+    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")
+    sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8")
+
+sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
+
+# ── Rate-limit constants (override via env vars) ───────────────────────────────
+BATCH_SIZE = int(os.getenv("BATCH_SIZE", "3"))
+BATCH_DELAY_SEC = int(os.getenv("BATCH_DELAY_SEC", "60"))
+PASS_THRESHOLD = float(os.getenv("PASS_THRESHOLD", "0.7"))
+EVAL_LIMIT = int(os.getenv("EVAL_LIMIT", "0")) or None  # 0 = no limit
+# Judge model: set JUDGE_MODEL to override. Must be a plain google-generativeai
+# model name (not a LiteLLM-prefixed name), e.g. "gemini-2.0-flash-lite".
+JUDGE_MODEL = os.getenv("JUDGE_MODEL", "gemini-2.0-flash-lite")
+
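+# Worked budget example (assumed free-tier limits): with the defaults above, a
+# 30-row dataset scores in 10 batches of 3, i.e. 9 inter-batch sleeps ≈ 9 minutes
+# of waiting, plus graph and judge latency.
+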
+# ── Imports (after sys.path tweak, so civicsetu resolves) ──────────────────────
+from civicsetu.agent.graph import get_compiled_graph
+from civicsetu.models.enums import Jurisdiction
+
+from datasets import Dataset
+from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
+from ragas import evaluate
+from ragas.embeddings import LangchainEmbeddingsWrapper
+from ragas.llms import LangchainLLMWrapper
+from ragas.metrics import answer_relevancy, context_precision, faithfulness
+
+
+def build_judge() -> tuple[LangchainLLMWrapper, LangchainEmbeddingsWrapper]:
+    """Build the RAGAS judge LLM and embeddings from GEMINI_API_KEY_2 (a separate
+    key, to avoid rate-limit contention with the RAG system's own GEMINI_API_KEY)."""
+    api_key = os.environ["GEMINI_API_KEY_2"]
+    llm = LangchainLLMWrapper(
+        ChatGoogleGenerativeAI(
+            model=JUDGE_MODEL,
+            google_api_key=api_key,
+            temperature=1.0,  # newer Gemini models expect temperature >= 1.0
+        )
+    )
+    embeddings = LangchainEmbeddingsWrapper(
+        GoogleGenerativeAIEmbeddings(
+            model="models/embedding-001",
+            google_api_key=api_key,
+        )
+    )
+    return llm, embeddings
+
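+# The judge model is swappable via the env override; e.g. (any model name valid
+# for the key should work):
+#   JUDGE_MODEL=gemini-2.0-flash uv run python scripts/run_eval.py
+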
+
+def load_dataset(path: Path) -> list[dict]:
+    """Read the golden dataset: one JSON object per non-empty line (JSONL)."""
+    rows = []
+    for line in path.read_text(encoding="utf-8").splitlines():
+        line = line.strip()
+        if line:
+            rows.append(json.loads(line))
+    return rows
+
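+# A representative golden row (hypothetical values; only these five keys are read
+# by this script):
+#   {"id": "mh-001", "jurisdiction": "maharashtra", "query_type": "factual",
+#    "query": "What is the property tax rate?", "ground_truth": "..."}
+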
+
+def invoke_graph(graph, row: dict) -> dict:
+    """Run one golden row through the graph and return enriched result dict."""
+    jurisdiction = None
+    if row.get("jurisdiction"):
+        try:
+            jurisdiction = Jurisdiction(row["jurisdiction"])
+        except ValueError:
+            jurisdiction = None
+
+    state = {
+        "query": row["query"],
+        "jurisdiction_filter": jurisdiction,
+        "top_k": 5,
+        "session_id": f"eval_{row['id']}",
+        "retrieved_chunks": [],
+        "reranked_chunks": [],
+        "citations": [],
+        "confidence_score": 0.0,
+        "conflict_warnings": [],
+        "amendment_notice": None,
+        "retry_count": 0,
+        "hallucination_flag": False,
+        "error": None,
+    }
+
+    start = time.perf_counter()
+    try:
+        result = graph.invoke(state)
+        latency_ms = (time.perf_counter() - start) * 1000
+
+        answer = result.get("raw_response") or ""
+        reranked = result.get("reranked_chunks") or []
+        contexts = [rc.chunk.text for rc in reranked if rc.chunk.text]
+        citations = result.get("citations") or []
+        confidence = result.get("confidence_score") or 0.0
+        query_type = str(result.get("query_type") or "unknown")
+        error = result.get("error")
+
+    except Exception as exc:
+        latency_ms = (time.perf_counter() - start) * 1000
+        answer = ""
+        contexts = []
+        citations = []
+        confidence = 0.0
+        query_type = "error"
+        error = str(exc)
+
+    return {
+        "id": row["id"],
+        "jurisdiction": row["jurisdiction"],
+        "query_type": row["query_type"],
+        "query": row["query"],
+        "ground_truth": row["ground_truth"],
+        "answer": answer,
+        "contexts": contexts,
+        "citations_count": len(citations),
+        "confidence_score": round(confidence, 3),
+        "query_type_resolved": query_type,
+        "latency_ms": round(latency_ms, 1),
+        "error": error,
+    }
+
+
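+# RAGAS metric recap (roughly, per the ragas docs): faithfulness scores how well
+# the answer is grounded in the retrieved contexts, answer_relevancy how directly
+# the answer addresses the question, and context_precision whether the top-ranked
+# contexts are actually relevant.
+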
+def score_batch(
+    batch: list[dict],
+    judge_llm: LangchainLLMWrapper,
+    judge_embeddings: LangchainEmbeddingsWrapper,
+) -> list[dict]:
+    """Run RAGAS on a batch and return rows annotated with metric scores."""
+    # Filter out errored rows — RAGAS needs non-empty answer and contexts
+    scoreable = [r for r in batch if r["answer"] and r["contexts"]]
+    skipped = [r for r in batch if not (r["answer"] and r["contexts"])]
+
+    scored = []
+    if scoreable:
+        ds = Dataset.from_list(
+            [
+                {
+                    "question": r["query"],
+                    "answer": r["answer"],
+                    "contexts": r["contexts"],
+                    "ground_truth": r["ground_truth"],
+                }
+                for r in scoreable
+            ]
+        )
+        result = evaluate(
+            ds,
+            metrics=[faithfulness, answer_relevancy, context_precision],
+            llm=judge_llm,
+            embeddings=judge_embeddings,
+            raise_exceptions=False,
+        )
+        result_df = result.to_pandas()
+
+        for row, (_, scores) in zip(scoreable, result_df.iterrows()):
+            row = dict(row)
+            row["faithfulness"] = round(float(scores.get("faithfulness", 0.0)), 3)
+            row["answer_relevancy"] = round(float(scores.get("answer_relevancy", 0.0)), 3)
+            row["context_precision"] = round(float(scores.get("context_precision", 0.0)), 3)
+            row["pass"] = (
+                row["faithfulness"] >= PASS_THRESHOLD
+                and row["answer_relevancy"] >= PASS_THRESHOLD
+                and row["context_precision"] >= PASS_THRESHOLD
+            )
+            scored.append(row)
+
+    # Skipped rows always get zero scores, even when the whole batch errored
+    for row in skipped:
+        row = dict(row)
+        row["faithfulness"] = 0.0
+        row["answer_relevancy"] = 0.0
+        row["context_precision"] = 0.0
+        row["pass"] = False
+        scored.append(row)
+
+    return scored
+
+
+def compute_group_stats(rows: list[dict], key: str) -> dict:
+    """Aggregate RAGAS metrics and latency for a group of rows."""
+    if not rows:
+        return {}
+    latencies = [r["latency_ms"] for r in rows]
+    latencies_sorted = sorted(latencies)
+    n = len(latencies_sorted)
+    p50 = latencies_sorted[n // 2]
+    p90 = latencies_sorted[min(int(n * 0.9), n - 1)]
+    p99 = latencies_sorted[min(int(n * 0.99), n - 1)]
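+    # Nearest-rank sketch: with n = 20, p50 → index 10, p90 → index 18, and
+    # p99 → index 19 (clamped to the last element for small n).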
+    return {
+        "faithfulness": round(sum(r["faithfulness"] for r in rows) / n, 3),
+        "answer_relevancy": round(sum(r["answer_relevancy"] for r in rows) / n, 3),
+        "context_precision": round(sum(r["context_precision"] for r in rows) / n, 3),
+        "pass_rate": round(sum(1 for r in rows if r["pass"]) / n, 3),
+        "p50_latency_ms": round(p50, 1),
+        "p90_latency_ms": round(p90, 1),
+        "p99_latency_ms": round(p99, 1),
+    }
+
+
+def print_summary(all_rows: list[dict]) -> None:
+    overall = compute_group_stats(all_rows, "overall")
+    passed = sum(1 for r in all_rows if r["pass"])
+    print("\n" + "=" * 72)
+    print(f"RAGAS Evaluation Results ({len(all_rows)} queries, {passed} pass)")
+    print("=" * 72)
+    print(f"  faithfulness      : {overall['faithfulness']:.3f}")
+    print(f"  answer_relevancy  : {overall['answer_relevancy']:.3f}")
+    print(f"  context_precision : {overall['context_precision']:.3f}")
+    print(f"  pass_rate         : {overall['pass_rate']:.1%}")
+    print(f"  p50 latency       : {overall['p50_latency_ms']:.0f} ms")
+    print(f"  p90 latency       : {overall['p90_latency_ms']:.0f} ms")
+    print()
+    # Per-jurisdiction
+    jurisdictions = sorted({r["jurisdiction"] or "MULTI" for r in all_rows})
+    print("  By jurisdiction:")
+    for jur in jurisdictions:
+        rows = [r for r in all_rows if (r["jurisdiction"] or "MULTI") == jur]
+        stats = compute_group_stats(rows, jur)
+        print(f"    {jur:<20} faith={stats['faithfulness']:.2f} "
+              f"rel={stats['answer_relevancy']:.2f} "
+              f"prec={stats['context_precision']:.2f} "
+              f"pass={stats['pass_rate']:.0%} p50={stats['p50_latency_ms']:.0f}ms")
+    print()
+    # Failures
+    failures = [r for r in all_rows if not r["pass"]]
+    if failures:
+        print(f"  Failures ({len(failures)}):")
+        for r in failures:
+            print(f"    FAIL [{r['id']}] "
+                  f"faith={r['faithfulness']:.2f} "
+                  f"rel={r['answer_relevancy']:.2f} "
+                  f"prec={r['context_precision']:.2f} "
+                  f"err={r.get('error') or '-'}")
+    print("=" * 72)
+
+
+def main() -> None:
+    dataset_path = Path(__file__).parent.parent / "eval" / "golden_dataset.jsonl"
+    if not dataset_path.exists():
+        print(f"ERROR: dataset not found at {dataset_path}", file=sys.stderr)
+        sys.exit(1)
+
+    rows = load_dataset(dataset_path)
+    if EVAL_LIMIT:
+        rows = rows[:EVAL_LIMIT]
+    print(f"CivicSetu RAGAS Eval — {len(rows)} queries, batch_size={BATCH_SIZE}, delay={BATCH_DELAY_SEC}s")
+
+    # Phase 1: run all queries through graph (no rate-limit needed here)
+    print("\nPhase 1: Invoking graph for all queries...")
+    graph = get_compiled_graph()
+    invoked: list[dict] = []
+    for i, row in enumerate(rows, 1):
+        print(f"  [{i:02}/{len(rows)}] {row['id']} ...", end=" ", flush=True)
+        result = invoke_graph(graph, row)
+        invoked.append(result)
+        status = "OK" if result["answer"] else "EMPTY"
+        print(f"{status} ({result['latency_ms']:.0f}ms, conf={result['confidence_score']})")
+
+    # Phase 2: RAGAS scoring in batches
+    print("\nPhase 2: RAGAS scoring in batches...")
+    judge_llm, judge_embeddings = build_judge()
+    all_scored: list[dict] = []
+
+    batches = [invoked[i:i + BATCH_SIZE] for i in range(0, len(invoked), BATCH_SIZE)]
+    for batch_num, batch in enumerate(batches, 1):
+        ids = [r["id"] for r in batch]
+        print(f"  Batch {batch_num}/{len(batches)}: {ids} ...", end=" ", flush=True)
+        scored = score_batch(batch, judge_llm, judge_embeddings)
+        all_scored.extend(scored)
+        print("done")
+        if batch_num < len(batches):
+            print(f"  Sleeping {BATCH_DELAY_SEC}s before next batch...")
+            time.sleep(BATCH_DELAY_SEC)
+
+    # Phase 3: write results
+    print_summary(all_scored)
+
+    # Build structured report
+    jurisdictions = sorted({r["jurisdiction"] or "MULTI" for r in all_scored})
+    query_types = sorted({r["query_type"] for r in all_scored})
+
+    report = {
+        "run_at": datetime.now(timezone.utc).isoformat(),
+        "dataset_size": len(all_scored),
+        "batch_size": BATCH_SIZE,
+        "batch_delay_sec": BATCH_DELAY_SEC,
+        "pass_threshold": PASS_THRESHOLD,
+        "overall": compute_group_stats(all_scored, "overall"),
+        "by_jurisdiction": {
+            jur: compute_group_stats(
+                [r for r in all_scored if (r["jurisdiction"] or "MULTI") == jur], jur
+            )
+            for jur in jurisdictions
+        },
+        "by_query_type": {
+            qt: compute_group_stats(
+                [r for r in all_scored if r["query_type"] == qt], qt
+            )
+            for qt in query_types
+        },
+        "rows": all_scored,
+    }
+
+    out = Path("eval_results.json")
+    out.write_text(json.dumps(report, indent=2, default=str), encoding="utf-8")
+    print(f"\nFull results → {out}")
+
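+    # CI note (hypothetical usage): because the script exits non-zero below when
+    # any row misses the threshold, a CI step can gate merges by simply running
+    # `uv run python scripts/run_eval.py`.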
+    # Exit 1 if any failures
+    failures = sum(1 for r in all_scored if not r["pass"])
+    if failures:
+        print(f"{failures} row(s) below pass threshold ({PASS_THRESHOLD})")
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()