Add benchmark runner
graphrag/benchmark.py  ADDED  (+107 -0)
@@ -0,0 +1,107 @@
"""
Benchmark Runner — Runs both pipelines on HotpotQA and evaluates
=================================================================
"""
import json
import logging
from typing import Dict, List

from .layers.orchestration_layer import InferenceOrchestrator
from .layers.evaluation_layer import EvaluationLayer, EvalSample

logger = logging.getLogger(__name__)


class BenchmarkRunner:
    """Runs benchmarks on HotpotQA and generates comparison metrics."""

    def __init__(self, orchestrator, evaluator):
        self.orchestrator = orchestrator
        self.evaluator = evaluator
        self.benchmark_results = []

    def run_hotpotqa_benchmark(self, num_samples=100, split="validation",
                               top_k=5, hops=2, progress_callback=None):
        """Run both pipelines on HotpotQA and evaluate."""
        from datasets import load_dataset
        logger.info(f"Loading HotpotQA ({split}, n={num_samples})...")
        ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split=split)

        results = []
        for idx in range(min(num_samples, len(ds))):
            row = ds[idx]
            query, gold = row["question"], row["answer"]
            qtype = row.get("type", "unknown")
            level = row.get("level", "unknown")

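            # Row layout assumed below (Hugging Face "hotpot_qa", "distractor" config):
            # row["context"] holds parallel lists,
            #   {"title": [t1, t2, ...], "sentences": [[s1a, s1b, ...], ...]},
            # and row["supporting_facts"] pairs gold titles with sentence indices,
            #   {"title": [...], "sent_id": [...]}.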
            # Build passages from context
            passages = [f"{t}: {' '.join(s)}"
                        for t, s in zip(row["context"]["title"], row["context"]["sentences"])]

            # Extract supporting facts for context hit rate
            sf = []
            for t, si in zip(row["supporting_facts"]["title"], row["supporting_facts"]["sent_id"]):
                for ct, cs in zip(row["context"]["title"], row["context"]["sentences"]):
                    if ct == t and si < len(cs):
                        sf.append(cs[si])

            try:
                comp = self.orchestrator.run_comparison(query, passages, top_k, hops)

                sample = EvalSample(
                    query=query, reference_answer=gold,
                    baseline_answer=comp.baseline.answer,
                    graphrag_answer=comp.graphrag.answer,
                    baseline_contexts=comp.baseline.contexts,
                    graphrag_contexts=comp.graphrag.contexts,
                    question_type=qtype, difficulty=str(level),
                    supporting_facts=sf)

                er = self.evaluator.evaluate_sample(
                    sample,
                    comp.baseline.total_tokens, comp.graphrag.total_tokens,
                    comp.baseline.cost_usd, comp.graphrag.cost_usd,
                    comp.baseline.latency_ms, comp.graphrag.latency_ms)

                rd = {
                    "idx": idx, "query": query, "gold_answer": gold,
                    "question_type": qtype, "level": level,
                    "baseline_answer": comp.baseline.answer,
                    "graphrag_answer": comp.graphrag.answer,
                    "baseline_f1": er.baseline_f1, "graphrag_f1": er.graphrag_f1,
                    "baseline_em": er.baseline_em, "graphrag_em": er.graphrag_em,
                    "baseline_tokens": comp.baseline.total_tokens,
                    "graphrag_tokens": comp.graphrag.total_tokens,
                    "baseline_cost": comp.baseline.cost_usd,
                    "graphrag_cost": comp.graphrag.cost_usd,
                    "baseline_latency": comp.baseline.latency_ms,
                    "graphrag_latency": comp.graphrag.latency_ms,
                    "baseline_context_hit": er.baseline_context_hit,
                    "graphrag_context_hit": er.graphrag_context_hit,
                    "entities_found": len(comp.graphrag.entities_found),
                    "relations_traversed": len(comp.graphrag.relations_traversed),
                }
                results.append(rd)
                self.benchmark_results.append(rd)

                if progress_callback:
                    progress_callback(idx + 1, num_samples, rd)
                if (idx + 1) % 10 == 0:
                    logger.info(f"Processed {idx + 1}/{num_samples} queries...")

            except Exception as e:
                logger.error(f"Error on query {idx}: {e}")

        aggregate = self.evaluator.compute_aggregate_metrics()
        report = self.evaluator.generate_report()
        return {"results": results, "aggregate": aggregate, "report": report,
                "num_completed": len(results), "num_requested": num_samples}

    def get_results_dataframe(self):
        import pandas as pd
        return pd.DataFrame(self.benchmark_results) if self.benchmark_results else pd.DataFrame()

    def save_results(self, filepath):
        with open(filepath, 'w') as f:
            json.dump({"results": self.benchmark_results,
                       "aggregate": self.evaluator.compute_aggregate_metrics()},
                      f, indent=2, default=str)
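
For context, a minimal driver sketch for the new runner. The BenchmarkRunner methods are the ones added above; the zero-argument constructors for InferenceOrchestrator and EvaluationLayer are assumptions, since their signatures are not part of this diff.

    from graphrag.benchmark import BenchmarkRunner
    from graphrag.layers.orchestration_layer import InferenceOrchestrator
    from graphrag.layers.evaluation_layer import EvaluationLayer

    # Hypothetical wiring: the layer constructors may need arguments in the real repo.
    runner = BenchmarkRunner(InferenceOrchestrator(), EvaluationLayer())

    summary = runner.run_hotpotqa_benchmark(
        num_samples=20, split="validation", top_k=5, hops=2,
        progress_callback=lambda done, total, rd: print(f"{done}/{total}"),
    )

    print(summary["aggregate"])            # aggregate metrics computed by the evaluator
    runner.save_results("hotpotqa_run.json")
    df = runner.get_results_dataframe()    # one row per benchmarked query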