Add Layer 4: Evaluation Layer (RAGAS + custom F1/EM metrics + benchmarking)
graphrag/layers/evaluation_layer.py
ADDED
@@ -0,0 +1,268 @@
"""
Layer 4: Evaluation Layer - RAGAS + Custom Metrics + Benchmarking
=================================================================
Computes faithfulness, answer relevancy, context precision/recall,
F1, exact match, and cost efficiency metrics.
"""
import logging
import re
import string
from collections import Counter
from dataclasses import dataclass, field
from typing import Any, Dict, List

logger = logging.getLogger(__name__)


# ── Custom Metrics (No LLM dependency) ────────────────────

def normalize_answer(s: str) -> str:
    """SQuAD/HotpotQA standard answer normalization."""
    def remove_articles(t): return re.sub(r'\b(a|an|the)\b', ' ', t)
    def white_space_fix(t): return ' '.join(t.split())
    def remove_punc(t): return ''.join(ch for ch in t if ch not in string.punctuation)
    return white_space_fix(remove_articles(remove_punc(s.lower())))


def compute_exact_match(prediction: str, ground_truth: str) -> float:
    """Exact match after normalization."""
    return float(normalize_answer(prediction) == normalize_answer(ground_truth))


def compute_f1(prediction: str, ground_truth: str) -> float:
    """Token-level F1 score (SQuAD/HotpotQA standard)."""
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()
    if not pred_tokens and not gold_tokens: return 1.0
    if not pred_tokens or not gold_tokens: return 0.0
    common = Counter(pred_tokens) & Counter(gold_tokens)
    num_same = sum(common.values())
    if num_same == 0: return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return (2 * precision * recall) / (precision + recall)


def compute_context_hit_rate(contexts: List[str], facts: List[str]) -> float:
    """Fraction of supporting facts found in retrieved contexts."""
    if not facts: return 0.0
    combined = " ".join(contexts).lower()
    return sum(1 for f in facts if f.lower() in combined) / len(facts)


def compute_token_efficiency(baseline_tokens: int, graphrag_tokens: int) -> float:
    """Token efficiency ratio: <1 means GraphRAG uses fewer tokens."""
    return graphrag_tokens / baseline_tokens if baseline_tokens > 0 else 0.0

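# Worked example (illustrative, not part of the committed file): for
# compute_f1("the Eiffel Tower", "Eiffel Tower in Paris"), normalization
# yields prediction tokens [eiffel, tower] and gold tokens [eiffel, tower,
# in, paris]; the overlap is 2, so precision = 2/2 = 1.0, recall = 2/4 = 0.5,
# and F1 = 2*(1.0*0.5)/(1.0+0.5) ≈ 0.667. compute_exact_match on the same
# pair returns 0.0, since the normalized strings differ.
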
# ── Data Structures ───────────────────────────────────────

@dataclass
class EvalSample:
    """Single evaluation sample."""
    query: str = ""
    reference_answer: str = ""
    baseline_answer: str = ""
    graphrag_answer: str = ""
    baseline_contexts: List[str] = field(default_factory=list)
    graphrag_contexts: List[str] = field(default_factory=list)
    question_type: str = ""
    difficulty: str = ""
    supporting_facts: List[str] = field(default_factory=list)


@dataclass
class EvalResult:
    """Evaluation result for a single sample."""
    query: str = ""
    baseline_f1: float = 0.0
    graphrag_f1: float = 0.0
    baseline_em: float = 0.0
    graphrag_em: float = 0.0
    baseline_context_hit: float = 0.0
    graphrag_context_hit: float = 0.0
    baseline_faithfulness: float = 0.0
    graphrag_faithfulness: float = 0.0
    baseline_relevancy: float = 0.0
    graphrag_relevancy: float = 0.0
    baseline_context_precision: float = 0.0
    graphrag_context_precision: float = 0.0
    baseline_context_recall: float = 0.0
    graphrag_context_recall: float = 0.0
    baseline_tokens: int = 0
    graphrag_tokens: int = 0
    baseline_cost: float = 0.0
    graphrag_cost: float = 0.0
    baseline_latency: float = 0.0
    graphrag_latency: float = 0.0
    question_type: str = ""
    difficulty: str = ""

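# NOTE: the *_faithfulness / *_relevancy / *_context_precision /
# *_context_recall fields default to 0.0; evaluate_sample() below fills only
# the LLM-free metrics, while RAGAS scores come back separately as an
# aggregate dict from evaluate_batch_ragas().
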
# ── Evaluation Layer ──────────────────────────────────────

class EvaluationLayer:
    """
    Layer 4: Evaluation Layer.
    Computes all metrics and generates benchmark reports.
    """

    def __init__(self, eval_llm_model="gpt-4o-mini", api_key=""):
        self.eval_llm_model = eval_llm_model
        self._api_key = api_key
        self._ragas_available = False
        self.results: List[EvalResult] = []

    def initialize(self):
        """Initialize RAGAS components if available."""
        try:
            from ragas import evaluate, EvaluationDataset, SingleTurnSample
            from ragas.metrics import Faithfulness, AnswerRelevancy
            self._ragas_available = True
            logger.info("RAGAS evaluation available.")
        except ImportError:
            logger.warning("RAGAS not installed - using custom metrics only.")

    def evaluate_sample(self, sample: EvalSample,
                        baseline_tokens=0, graphrag_tokens=0,
                        baseline_cost=0.0, graphrag_cost=0.0,
                        baseline_latency=0.0, graphrag_latency=0.0) -> EvalResult:
        """Evaluate a single sample with all metrics."""
        r = EvalResult(
            query=sample.query,
            question_type=sample.question_type,
            difficulty=sample.difficulty,
            baseline_f1=compute_f1(sample.baseline_answer, sample.reference_answer),
            graphrag_f1=compute_f1(sample.graphrag_answer, sample.reference_answer),
            baseline_em=compute_exact_match(sample.baseline_answer, sample.reference_answer),
            graphrag_em=compute_exact_match(sample.graphrag_answer, sample.reference_answer),
            baseline_context_hit=compute_context_hit_rate(
                sample.baseline_contexts, sample.supporting_facts),
            graphrag_context_hit=compute_context_hit_rate(
                sample.graphrag_contexts, sample.supporting_facts),
            baseline_tokens=baseline_tokens, graphrag_tokens=graphrag_tokens,
            baseline_cost=baseline_cost, graphrag_cost=graphrag_cost,
            baseline_latency=baseline_latency, graphrag_latency=graphrag_latency,
        )
        self.results.append(r)
        return r

    def evaluate_batch_ragas(self, samples: List[EvalSample], pipeline="baseline") -> Dict[str, float]:
        """Run RAGAS evaluation on a batch (requires RAGAS + OpenAI key)."""
        if not self._ragas_available:
            return {}
        try:
            from ragas import evaluate, EvaluationDataset, SingleTurnSample
            from ragas.metrics import (Faithfulness, AnswerRelevancy,
                                       LLMContextPrecisionWithReference, LLMContextRecall)
            from ragas.llms import LangchainLLMWrapper
            from ragas.embeddings import LangchainEmbeddingsWrapper
            from langchain_openai import ChatOpenAI, OpenAIEmbeddings
            import os

            key = self._api_key or os.getenv("OPENAI_API_KEY", "")
            llm = LangchainLLMWrapper(ChatOpenAI(model=self.eval_llm_model, api_key=key))
            emb = LangchainEmbeddingsWrapper(OpenAIEmbeddings(api_key=key))

            ragas_samples = []
            for s in samples:
                answer = s.baseline_answer if pipeline == "baseline" else s.graphrag_answer
                ctxs = s.baseline_contexts if pipeline == "baseline" else s.graphrag_contexts
                if answer and ctxs:
                    ragas_samples.append(SingleTurnSample(
                        user_input=s.query, response=answer,
                        retrieved_contexts=ctxs, reference=s.reference_answer))

            if not ragas_samples: return {}
            dataset = EvaluationDataset(samples=ragas_samples)
            metrics = [Faithfulness(llm=llm), AnswerRelevancy(llm=llm, embeddings=emb),
                       LLMContextPrecisionWithReference(llm=llm), LLMContextRecall(llm=llm)]
            return dict(evaluate(dataset=dataset, metrics=metrics))
        except Exception as e:
            logger.error(f"RAGAS evaluation failed: {e}")
            return {}

    def compute_aggregate_metrics(self) -> Dict[str, Any]:
        """Compute aggregate metrics across all evaluated samples."""
        if not self.results: return {"message": "No results"}
        n = len(self.results)
        avg = lambda vals: sum(vals) / len(vals) if vals else 0.0

        b = {
            "avg_f1": round(avg([r.baseline_f1 for r in self.results]), 4),
            "avg_em": round(avg([r.baseline_em for r in self.results]), 4),
            "avg_context_hit": round(avg([r.baseline_context_hit for r in self.results]), 4),
            "avg_tokens": round(avg([r.baseline_tokens for r in self.results]), 1),
            "avg_cost": round(avg([r.baseline_cost for r in self.results]), 6),
            "avg_latency_ms": round(avg([r.baseline_latency for r in self.results]), 1),
            "total_tokens": sum(r.baseline_tokens for r in self.results),
            "total_cost": round(sum(r.baseline_cost for r in self.results), 6),
        }
        g = {
            "avg_f1": round(avg([r.graphrag_f1 for r in self.results]), 4),
            "avg_em": round(avg([r.graphrag_em for r in self.results]), 4),
            "avg_context_hit": round(avg([r.graphrag_context_hit for r in self.results]), 4),
            "avg_tokens": round(avg([r.graphrag_tokens for r in self.results]), 1),
            "avg_cost": round(avg([r.graphrag_cost for r in self.results]), 6),
            "avg_latency_ms": round(avg([r.graphrag_latency for r in self.results]), 1),
            "total_tokens": sum(r.graphrag_tokens for r in self.results),
            "total_cost": round(sum(r.graphrag_cost for r in self.results), 6),
        }

        win_rate = sum(1 for r in self.results if r.graphrag_f1 > r.baseline_f1) / n

        by_type = {}
        for r in self.results:
            qt = r.question_type or "unknown"
            by_type.setdefault(qt, {"baseline_f1": [], "graphrag_f1": [], "count": 0})
            by_type[qt]["baseline_f1"].append(r.baseline_f1)
            by_type[qt]["graphrag_f1"].append(r.graphrag_f1)
            by_type[qt]["count"] += 1

        return {
            "num_samples": n, "baseline": b, "graphrag": g,
            "graphrag_f1_win_rate": round(win_rate, 4),
            "token_ratio": round(g["total_tokens"] / max(b["total_tokens"], 1), 3),
            "by_question_type": {
                qt: {"count": d["count"],
                     "baseline_avg_f1": round(avg(d["baseline_f1"]), 4),
                     "graphrag_avg_f1": round(avg(d["graphrag_f1"]), 4)}
                for qt, d in by_type.items()
            }
        }

    def generate_report(self) -> str:
        """Generate a text benchmark report."""
        m = self.compute_aggregate_metrics()
        if "message" in m: return m["message"]
        lines = [
            "=" * 60, "GRAPHRAG INFERENCE BENCHMARK REPORT", "=" * 60,
            f"\nTotal Samples Evaluated: {m['num_samples']}",
            f"\n{'Metric':<25} {'Baseline':>12} {'GraphRAG':>12} {'Winner':>12}",
            "-" * 65
        ]
        b, g = m["baseline"], m["graphrag"]
        for name, key in [("Avg F1 Score", "avg_f1"), ("Avg Exact Match", "avg_em"),
                          ("Avg Context Hit Rate", "avg_context_hit")]:
            bv, gv = b[key], g[key]
            winner = "GraphRAG" if gv > bv else ("Baseline" if bv > gv else "Tie")
            lines.append(f"{name:<25} {bv:>12.4f} {gv:>12.4f} {winner:>12}")

        lines.append(f"\n{'Metric':<25} {'Baseline':>12} {'GraphRAG':>12} {'Ratio':>12}")
        lines.append("-" * 65)
        for name, key in [("Avg Tokens/Query", "avg_tokens"), ("Avg Cost ($)", "avg_cost"),
                          ("Avg Latency (ms)", "avg_latency_ms")]:
            bv, gv = b[key], g[key]
            ratio = gv / bv if bv > 0 else 0
            lines.append(f"{name:<25} {bv:>12.4f} {gv:>12.4f} {ratio:>11.2f}x")

        lines.append(f"\nGraphRAG F1 Win Rate: {m['graphrag_f1_win_rate']:.1%}")
        lines.append(f"Token Ratio (G/B): {m['token_ratio']:.2f}x")

        if m.get("by_question_type"):
            lines.extend(["\n--- By Question Type ---",
                          f"{'Type':<20} {'Count':>6} {'Base F1':>10} {'Graph F1':>10}", "-" * 50])
            for qt, d in m["by_question_type"].items():
                lines.append(f"{qt:<20} {d['count']:>6} {d['baseline_avg_f1']:>10.4f} {d['graphrag_avg_f1']:>10.4f}")
        lines.append("\n" + "=" * 60)
        return "\n".join(lines)
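
A minimal usage sketch of the layer (the sample text and the token/latency figures are illustrative; the LLM-free path below needs no API key, while evaluate_batch_ragas additionally requires the ragas and langchain-openai packages plus an OpenAI key):

from graphrag.layers.evaluation_layer import EvalSample, EvaluationLayer

layer = EvaluationLayer()
layer.initialize()  # probes for RAGAS; custom metrics work either way

# Hypothetical sample comparing the two pipelines on one question.
sample = EvalSample(
    query="Who designed the Eiffel Tower?",
    reference_answer="Gustave Eiffel",
    baseline_answer="It was designed by Gustave Eiffel.",
    graphrag_answer="Gustave Eiffel",
    baseline_contexts=["The tower was designed by Gustave Eiffel's company."],
    graphrag_contexts=["Gustave Eiffel designed the tower, completed in 1889."],
    question_type="bridge",
    supporting_facts=["Gustave Eiffel"],
)
layer.evaluate_sample(sample, baseline_tokens=900, graphrag_tokens=600,
                      baseline_latency=420.0, graphrag_latency=510.0)
print(layer.generate_report())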