muthuk1 committed
Commit 0117576 · verified · 1 Parent(s): 6488963

Add Layer 4: Evaluation Layer (RAGAS + custom F1/EM metrics + benchmarking)

Files changed (1)
  1. graphrag/layers/evaluation_layer.py +268 -0
graphrag/layers/evaluation_layer.py ADDED
@@ -0,0 +1,268 @@
+ """
+ Layer 4: Evaluation Layer - RAGAS + Custom Metrics + Benchmarking
+ =================================================================
+ Computes faithfulness, answer relevancy, context precision/recall,
+ F1, exact match, and cost efficiency metrics.
+ """
+ import logging
+ import re
+ import string
+ from collections import Counter
+ from dataclasses import dataclass, field
+ from typing import Any, Dict, List
+
+ logger = logging.getLogger(__name__)
+
+
+ # ── Custom Metrics (No LLM dependency) ────────────────────
+
+ def normalize_answer(s: str) -> str:
+     """SQuAD/HotpotQA standard answer normalization."""
+     def remove_articles(t): return re.sub(r'\b(a|an|the)\b', ' ', t)
+     def white_space_fix(t): return ' '.join(t.split())
+     def remove_punc(t): return ''.join(ch for ch in t if ch not in string.punctuation)
+     return white_space_fix(remove_articles(remove_punc(s.lower())))
+
+
+ def compute_exact_match(prediction: str, ground_truth: str) -> float:
+     """Exact match after normalization."""
+     return float(normalize_answer(prediction) == normalize_answer(ground_truth))
+
+
+ def compute_f1(prediction: str, ground_truth: str) -> float:
+     """Token-level F1 score (SQuAD/HotpotQA standard)."""
+     pred_tokens = normalize_answer(prediction).split()
+     gold_tokens = normalize_answer(ground_truth).split()
+     if not pred_tokens and not gold_tokens: return 1.0
+     if not pred_tokens or not gold_tokens: return 0.0
+     common = Counter(pred_tokens) & Counter(gold_tokens)
+     num_same = sum(common.values())
+     if num_same == 0: return 0.0
+     precision = num_same / len(pred_tokens)
+     recall = num_same / len(gold_tokens)
+     return (2 * precision * recall) / (precision + recall)
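+ # Worked example (illustrative): compute_f1("the quick brown fox", "a quick fox")
+ # normalizes both sides to "quick brown fox" vs "quick fox", giving
+ # precision = 2/3 and recall = 2/2, so F1 = 0.8; compute_exact_match on the
+ # same pair returns 0.0 because the normalized strings differ.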
+
+
+ def compute_context_hit_rate(contexts: List[str], facts: List[str]) -> float:
+     """Fraction of supporting facts found in retrieved contexts."""
+     if not facts: return 0.0
+     combined = " ".join(contexts).lower()
+     return sum(1 for f in facts if f.lower() in combined) / len(facts)
+
+
+ def compute_token_efficiency(baseline_tokens: int, graphrag_tokens: int) -> float:
+     """Token efficiency ratio: <1 means GraphRAG uses fewer tokens."""
+     return graphrag_tokens / baseline_tokens if baseline_tokens > 0 else 0.0
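+ # Worked example (illustrative): compute_token_efficiency(1200, 900) -> 0.75,
+ # i.e. GraphRAG used 25% fewer tokens than the baseline on that query.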
+
+
+ # ── Data Structures ───────────────────────────────────────
+
+ @dataclass
+ class EvalSample:
+     """Single evaluation sample."""
+     query: str = ""
+     reference_answer: str = ""
+     baseline_answer: str = ""
+     graphrag_answer: str = ""
+     baseline_contexts: List[str] = field(default_factory=list)
+     graphrag_contexts: List[str] = field(default_factory=list)
+     question_type: str = ""
+     difficulty: str = ""
+     supporting_facts: List[str] = field(default_factory=list)
+
+
+ @dataclass
+ class EvalResult:
+     """Evaluation result for a single sample."""
+     query: str = ""
+     baseline_f1: float = 0.0
+     graphrag_f1: float = 0.0
+     baseline_em: float = 0.0
+     graphrag_em: float = 0.0
+     baseline_context_hit: float = 0.0
+     graphrag_context_hit: float = 0.0
+     baseline_faithfulness: float = 0.0
+     graphrag_faithfulness: float = 0.0
+     baseline_relevancy: float = 0.0
+     graphrag_relevancy: float = 0.0
+     baseline_context_precision: float = 0.0
+     graphrag_context_precision: float = 0.0
+     baseline_context_recall: float = 0.0
+     graphrag_context_recall: float = 0.0
+     baseline_tokens: int = 0
+     graphrag_tokens: int = 0
+     baseline_cost: float = 0.0
+     graphrag_cost: float = 0.0
+     baseline_latency: float = 0.0
+     graphrag_latency: float = 0.0
+     question_type: str = ""
+     difficulty: str = ""
+
+
+ # ── Evaluation Layer ──────────────────────────────────────
+
+ class EvaluationLayer:
+     """
+     Layer 4: Evaluation Layer.
+     Computes all metrics and generates benchmark reports.
+     """
+
+     def __init__(self, eval_llm_model="gpt-4o-mini", api_key=""):
+         self.eval_llm_model = eval_llm_model
+         self._api_key = api_key
+         self._ragas_available = False
+         self.results: List[EvalResult] = []
+
+     def initialize(self):
+         """Initialize RAGAS components if available."""
+         try:
+             from ragas import evaluate, EvaluationDataset, SingleTurnSample
+             from ragas.metrics import Faithfulness, AnswerRelevancy
+             self._ragas_available = True
+             logger.info("RAGAS evaluation available.")
+         except ImportError:
+             logger.warning("RAGAS not installed; using custom metrics only.")
+
+     def evaluate_sample(self, sample: EvalSample,
+                         baseline_tokens=0, graphrag_tokens=0,
+                         baseline_cost=0.0, graphrag_cost=0.0,
+                         baseline_latency=0.0, graphrag_latency=0.0) -> EvalResult:
+         """Evaluate a single sample with all metrics."""
+         r = EvalResult(
+             query=sample.query,
+             question_type=sample.question_type,
+             difficulty=sample.difficulty,
+             baseline_f1=compute_f1(sample.baseline_answer, sample.reference_answer),
+             graphrag_f1=compute_f1(sample.graphrag_answer, sample.reference_answer),
+             baseline_em=compute_exact_match(sample.baseline_answer, sample.reference_answer),
+             graphrag_em=compute_exact_match(sample.graphrag_answer, sample.reference_answer),
+             baseline_context_hit=compute_context_hit_rate(
+                 sample.baseline_contexts, sample.supporting_facts),
+             graphrag_context_hit=compute_context_hit_rate(
+                 sample.graphrag_contexts, sample.supporting_facts),
+             baseline_tokens=baseline_tokens, graphrag_tokens=graphrag_tokens,
+             baseline_cost=baseline_cost, graphrag_cost=graphrag_cost,
+             baseline_latency=baseline_latency, graphrag_latency=graphrag_latency,
+         )
+         self.results.append(r)
+         return r
+
+     def evaluate_batch_ragas(self, samples: List[EvalSample], pipeline="baseline") -> Dict[str, float]:
+         """Run RAGAS evaluation on a batch (requires RAGAS + OpenAI key)."""
+         if not self._ragas_available:
+             return {}
+         try:
+             from ragas import evaluate, EvaluationDataset, SingleTurnSample
+             from ragas.metrics import (Faithfulness, AnswerRelevancy,
+                                        LLMContextPrecisionWithReference, LLMContextRecall)
+             from ragas.llms import LangchainLLMWrapper
+             from ragas.embeddings import LangchainEmbeddingsWrapper
+             from langchain_openai import ChatOpenAI, OpenAIEmbeddings
+             import os
+
+             key = self._api_key or os.getenv("OPENAI_API_KEY", "")
+             llm = LangchainLLMWrapper(ChatOpenAI(model=self.eval_llm_model, api_key=key))
+             emb = LangchainEmbeddingsWrapper(OpenAIEmbeddings(api_key=key))
+
+             ragas_samples = []
+             for s in samples:
+                 answer = s.baseline_answer if pipeline == "baseline" else s.graphrag_answer
+                 ctxs = s.baseline_contexts if pipeline == "baseline" else s.graphrag_contexts
+                 if answer and ctxs:
+                     ragas_samples.append(SingleTurnSample(
+                         user_input=s.query, response=answer,
+                         retrieved_contexts=ctxs, reference=s.reference_answer))
+
+             if not ragas_samples: return {}
+             dataset = EvaluationDataset(samples=ragas_samples)
+             metrics = [Faithfulness(llm=llm), AnswerRelevancy(llm=llm, embeddings=emb),
+                        LLMContextPrecisionWithReference(llm=llm), LLMContextRecall(llm=llm)]
+             return dict(evaluate(dataset=dataset, metrics=metrics))
+         except Exception as e:
+             logger.error(f"RAGAS evaluation failed: {e}")
+             return {}
+
+     def compute_aggregate_metrics(self) -> Dict[str, Any]:
+         """Compute aggregate metrics across all evaluated samples."""
+         if not self.results: return {"message": "No results"}
+         n = len(self.results)
+         avg = lambda vals: sum(vals) / len(vals) if vals else 0.0
+
+         b = {
+             "avg_f1": round(avg([r.baseline_f1 for r in self.results]), 4),
+             "avg_em": round(avg([r.baseline_em for r in self.results]), 4),
+             "avg_context_hit": round(avg([r.baseline_context_hit for r in self.results]), 4),
+             "avg_tokens": round(avg([r.baseline_tokens for r in self.results]), 1),
+             "avg_cost": round(avg([r.baseline_cost for r in self.results]), 6),
+             "avg_latency_ms": round(avg([r.baseline_latency for r in self.results]), 1),
+             "total_tokens": sum(r.baseline_tokens for r in self.results),
+             "total_cost": round(sum(r.baseline_cost for r in self.results), 6),
+         }
+         g = {
+             "avg_f1": round(avg([r.graphrag_f1 for r in self.results]), 4),
+             "avg_em": round(avg([r.graphrag_em for r in self.results]), 4),
+             "avg_context_hit": round(avg([r.graphrag_context_hit for r in self.results]), 4),
+             "avg_tokens": round(avg([r.graphrag_tokens for r in self.results]), 1),
+             "avg_cost": round(avg([r.graphrag_cost for r in self.results]), 6),
+             "avg_latency_ms": round(avg([r.graphrag_latency for r in self.results]), 1),
+             "total_tokens": sum(r.graphrag_tokens for r in self.results),
+             "total_cost": round(sum(r.graphrag_cost for r in self.results), 6),
+         }
+
+         win_rate = sum(1 for r in self.results if r.graphrag_f1 > r.baseline_f1) / n
+
+         by_type = {}
+         for r in self.results:
+             qt = r.question_type or "unknown"
+             by_type.setdefault(qt, {"baseline_f1": [], "graphrag_f1": [], "count": 0})
+             by_type[qt]["baseline_f1"].append(r.baseline_f1)
+             by_type[qt]["graphrag_f1"].append(r.graphrag_f1)
+             by_type[qt]["count"] += 1
+
+         return {
+             "num_samples": n, "baseline": b, "graphrag": g,
+             "graphrag_f1_win_rate": round(win_rate, 4),
+             "token_ratio": round(g["total_tokens"] / max(b["total_tokens"], 1), 3),
+             "by_question_type": {
+                 qt: {"count": d["count"],
+                      "baseline_avg_f1": round(avg(d["baseline_f1"]), 4),
+                      "graphrag_avg_f1": round(avg(d["graphrag_f1"]), 4)}
+                 for qt, d in by_type.items()
+             }
+         }
+
+     def generate_report(self) -> str:
+         """Generate a text benchmark report."""
+         m = self.compute_aggregate_metrics()
+         if "message" in m: return m["message"]
+         lines = [
+             "=" * 60, "GRAPHRAG INFERENCE BENCHMARK REPORT", "=" * 60,
+             f"\nTotal Samples Evaluated: {m['num_samples']}",
+             f"\n{'Metric':<25} {'Baseline':>12} {'GraphRAG':>12} {'Winner':>12}",
+             "-" * 65
+         ]
+         b, g = m["baseline"], m["graphrag"]
+         for name, key in [("Avg F1 Score", "avg_f1"), ("Avg Exact Match", "avg_em"),
+                           ("Avg Context Hit Rate", "avg_context_hit")]:
+             bv, gv = b[key], g[key]
+             winner = "GraphRAG" if gv > bv else ("Baseline" if bv > gv else "Tie")
+             lines.append(f"{name:<25} {bv:>12.4f} {gv:>12.4f} {winner:>12}")
+
+         lines.append(f"\n{'Metric':<25} {'Baseline':>12} {'GraphRAG':>12} {'Ratio':>12}")
+         lines.append("-" * 65)
+         for name, key in [("Avg Tokens/Query", "avg_tokens"), ("Avg Cost ($)", "avg_cost"),
+                           ("Avg Latency (ms)", "avg_latency_ms")]:
+             bv, gv = b[key], g[key]
+             ratio = gv / bv if bv > 0 else 0
+             lines.append(f"{name:<25} {bv:>12.4f} {gv:>12.4f} {ratio:>11.2f}x")
+
+         lines.append(f"\nGraphRAG F1 Win Rate: {m['graphrag_f1_win_rate']:.1%}")
+         lines.append(f"Token Ratio (G/B): {m['token_ratio']:.2f}x")
+
+         if m.get("by_question_type"):
+             lines.extend(["\n--- By Question Type ---",
+                           f"{'Type':<20} {'Count':>6} {'Base F1':>10} {'Graph F1':>10}", "-" * 50])
+             for qt, d in m["by_question_type"].items():
+                 lines.append(f"{qt:<20} {d['count']:>6} {d['baseline_avg_f1']:>10.4f} {d['graphrag_avg_f1']:>10.4f}")
+         lines.append("\n" + "=" * 60)
+         return "\n".join(lines)
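
Below is a minimal usage sketch (illustrative, not part of the commit). It assumes the module is importable as graphrag.layers.evaluation_layer and that each pipeline's answers, contexts, and token/latency counts are produced upstream; the sample values are invented:

from graphrag.layers.evaluation_layer import EvalSample, EvaluationLayer

evaluator = EvaluationLayer()
evaluator.initialize()  # enables RAGAS metrics if the package is installed

# One hypothetical HotpotQA-style sample with answers from both pipelines.
sample = EvalSample(
    query="Who founded the company that makes the iPhone?",
    reference_answer="Steve Jobs, Steve Wozniak, and Ronald Wayne",
    baseline_answer="Steve Jobs",
    graphrag_answer="Steve Jobs, Steve Wozniak, and Ronald Wayne",
    baseline_contexts=["Apple Inc. designs and sells the iPhone."],
    graphrag_contexts=["Apple Inc. designs and sells the iPhone.",
                       "Apple was founded by Steve Jobs, Steve Wozniak, and Ronald Wayne."],
    question_type="bridge",
    supporting_facts=["founded by Steve Jobs"],
)

# F1/EM/context-hit are computed locally; token and latency figures are passed in.
evaluator.evaluate_sample(sample, baseline_tokens=850, graphrag_tokens=640,
                          baseline_latency=420.0, graphrag_latency=510.0)
print(evaluator.generate_report())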