muthuk1 committed
Commit da5b779 · verified · 1 Parent(s): a23a620

Add benchmark runner

Files changed (1)
  1. graphrag/benchmark.py +107 -0
graphrag/benchmark.py ADDED
@@ -0,0 +1,107 @@
+ """
+ Benchmark Runner — Runs both pipelines on HotpotQA and evaluates
+ =================================================================
+ """
+ import json
+ import logging
+ from typing import Dict, List
+ from .layers.orchestration_layer import InferenceOrchestrator
+ from .layers.evaluation_layer import EvaluationLayer, EvalSample
+
+ logger = logging.getLogger(__name__)
+
+
+ class BenchmarkRunner:
+     """Runs benchmarks on HotpotQA and generates comparison metrics."""
+
+     def __init__(self, orchestrator, evaluator):
+         self.orchestrator = orchestrator
+         self.evaluator = evaluator
+         self.benchmark_results = []
+
+     def run_hotpotqa_benchmark(self, num_samples=100, split="validation",
+                                top_k=5, hops=2, progress_callback=None):
+         """Run both pipelines on HotpotQA and evaluate."""
+         from datasets import load_dataset
+         logger.info(f"Loading HotpotQA ({split}, n={num_samples})...")
+         ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split=split)
+
+         results = []
+         for idx in range(min(num_samples, len(ds))):
+             row = ds[idx]
+             query, gold = row["question"], row["answer"]
+             qtype = row.get("type", "unknown")
+             level = row.get("level", "unknown")
+
+             # Build passages from context
+             passages = [f"{t}: {' '.join(s)}"
+                         for t, s in zip(row["context"]["title"], row["context"]["sentences"])]
+
+             # Extract supporting facts for context hit rate
+             sf = []
+             for t, si in zip(row["supporting_facts"]["title"], row["supporting_facts"]["sent_id"]):
+                 for ct, cs in zip(row["context"]["title"], row["context"]["sentences"]):
+                     if ct == t and si < len(cs):
+                         sf.append(cs[si])
+
+             try:
+                 comp = self.orchestrator.run_comparison(query, passages, top_k, hops)
+
+                 sample = EvalSample(
+                     query=query, reference_answer=gold,
+                     baseline_answer=comp.baseline.answer,
+                     graphrag_answer=comp.graphrag.answer,
+                     baseline_contexts=comp.baseline.contexts,
+                     graphrag_contexts=comp.graphrag.contexts,
+                     question_type=qtype, difficulty=str(level),
+                     supporting_facts=sf)
+
+                 er = self.evaluator.evaluate_sample(
+                     sample,
+                     comp.baseline.total_tokens, comp.graphrag.total_tokens,
+                     comp.baseline.cost_usd, comp.graphrag.cost_usd,
+                     comp.baseline.latency_ms, comp.graphrag.latency_ms)
+
+                 rd = {
+                     "idx": idx, "query": query, "gold_answer": gold,
+                     "question_type": qtype, "level": level,
+                     "baseline_answer": comp.baseline.answer,
+                     "graphrag_answer": comp.graphrag.answer,
+                     "baseline_f1": er.baseline_f1, "graphrag_f1": er.graphrag_f1,
+                     "baseline_em": er.baseline_em, "graphrag_em": er.graphrag_em,
+                     "baseline_tokens": comp.baseline.total_tokens,
+                     "graphrag_tokens": comp.graphrag.total_tokens,
+                     "baseline_cost": comp.baseline.cost_usd,
+                     "graphrag_cost": comp.graphrag.cost_usd,
+                     "baseline_latency": comp.baseline.latency_ms,
+                     "graphrag_latency": comp.graphrag.latency_ms,
+                     "baseline_context_hit": er.baseline_context_hit,
+                     "graphrag_context_hit": er.graphrag_context_hit,
+                     "entities_found": len(comp.graphrag.entities_found),
+                     "relations_traversed": len(comp.graphrag.relations_traversed),
+                 }
+                 results.append(rd)
+                 self.benchmark_results.append(rd)
+
+                 if progress_callback:
+                     progress_callback(idx + 1, num_samples, rd)
+                 if (idx + 1) % 10 == 0:
+                     logger.info(f"Processed {idx + 1}/{num_samples} queries...")
+
+             except Exception as e:
+                 logger.error(f"Error on query {idx}: {e}")
+
+         aggregate = self.evaluator.compute_aggregate_metrics()
+         report = self.evaluator.generate_report()
+         return {"results": results, "aggregate": aggregate, "report": report,
+                 "num_completed": len(results), "num_requested": num_samples}
+
+     def get_results_dataframe(self):
+         import pandas as pd
+         return pd.DataFrame(self.benchmark_results) if self.benchmark_results else pd.DataFrame()
+
+     def save_results(self, filepath):
+         with open(filepath, 'w') as f:
+             json.dump({"results": self.benchmark_results,
+                        "aggregate": self.evaluator.compute_aggregate_metrics()},
+                       f, indent=2, default=str)
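
Below is a minimal usage sketch for the new runner. Only BenchmarkRunner's own interface (run_hotpotqa_benchmark, get_results_dataframe, save_results) and the result keys come from this file; the zero-argument constructors for InferenceOrchestrator and EvaluationLayer are assumptions, since their signatures are not part of this commit.

# Usage sketch. Assumptions: default construction of the orchestrator and
# evaluator; everything else follows the BenchmarkRunner added in this commit.
from graphrag.benchmark import BenchmarkRunner
from graphrag.layers.orchestration_layer import InferenceOrchestrator
from graphrag.layers.evaluation_layer import EvaluationLayer

orchestrator = InferenceOrchestrator()  # assumed default construction
evaluator = EvaluationLayer()           # assumed default construction
runner = BenchmarkRunner(orchestrator, evaluator)

def on_progress(done, total, record):
    # record is the per-sample dict built inside run_hotpotqa_benchmark
    print(f"{done}/{total}  baseline_f1={record['baseline_f1']:.3f}  "
          f"graphrag_f1={record['graphrag_f1']:.3f}")

summary = runner.run_hotpotqa_benchmark(num_samples=20, split="validation",
                                        top_k=5, hops=2,
                                        progress_callback=on_progress)
print(summary["report"])
print(f"Completed {summary['num_completed']}/{summary['num_requested']} queries")

runner.save_results("hotpotqa_benchmark.json")
df = runner.get_results_dataframe()  # one row per evaluated sample
print(df[["baseline_f1", "graphrag_f1", "baseline_em", "graphrag_em"]].mean())

The sketch keeps num_samples small for a smoke test; the default of 100 samples runs both pipelines per query, so token cost and latency scale accordingly.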