muthuk1 committed on
Commit
0a8419b
·
verified ·
1 Parent(s): 1ae3c99

Add main entry point (CLI: dashboard, benchmark, ingest, demo)

Browse files
Files changed (1) hide show
  1. graphrag/main.py +143 -0
graphrag/main.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Main Entry Point — GraphRAG Inference Hackathon
3
+ ================================================
4
+ Run: python -m graphrag.main {dashboard|benchmark|ingest|demo}
5
+ """
6
+ import argparse
7
+ import logging
8
+ import os
9
+ import sys
10
+
11
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
def main(argv=None):
    """Parse command-line arguments and dispatch to the selected command.

    Args:
        argv: Optional list of argument strings (without the program name).
            When ``None`` (the default) argparse falls back to
            ``sys.argv[1:]``, so a plain ``main()`` call behaves as before.

    Raises:
        SystemExit: Raised by argparse on ``--help`` or on invalid arguments.
    """
    parser = argparse.ArgumentParser(description="GraphRAG Inference Hackathon — Dual Pipeline System")
    parser.add_argument("command", choices=["dashboard", "benchmark", "ingest", "demo"],
                        help="Command to run")
    parser.add_argument("--port", type=int, default=7860, help="Dashboard port")
    parser.add_argument("--samples", type=int, default=50, help="Number of samples")
    parser.add_argument("--top-k", type=int, default=5, help="Top-K retrieval")
    parser.add_argument("--hops", type=int, default=2, help="Graph traversal hops")
    parser.add_argument("--share", action="store_true", help="Create Gradio share link")
    parser.add_argument("--output", type=str, default="results.json", help="Output file")
    args = parser.parse_args(argv)

    if args.command == "dashboard":
        # Imported only when the dashboard is actually requested.
        from graphrag.dashboard import build_dashboard
        demo = build_dashboard()
        demo.launch(server_port=args.port, share=args.share, show_error=True)

    elif args.command == "benchmark":
        run_benchmark(args)
    elif args.command == "ingest":
        run_ingestion(args)
    elif args.command == "demo":
        run_demo(args)
38
+
39
+
40
def run_benchmark(args):
    """Run the HotpotQA benchmark and write the results to ``args.output``.

    Reads ``OPENAI_API_KEY`` and ``LLM_MODEL`` (default ``gpt-4o-mini``) from
    the environment, wires up the LLM, embedder, graph, orchestrator and
    evaluator, then runs the benchmark with ``args.samples``, ``args.top_k``
    and ``args.hops``. The textual report is printed to stdout.

    Args:
        args: Parsed CLI namespace; uses ``samples``, ``top_k``, ``hops`` and
            ``output``.
    """
    from graphrag.layers.graph_layer import GraphLayer
    from graphrag.layers.llm_layer import LLMLayer
    from graphrag.layers.orchestration_layer import InferenceOrchestrator, EmbeddingManager
    from graphrag.layers.evaluation_layer import EvaluationLayer
    from graphrag.benchmark import BenchmarkRunner

    # The same credentials/model are shared by the generator, embedder and judge.
    openai_key = os.getenv("OPENAI_API_KEY", "")
    model_name = os.getenv("LLM_MODEL", "gpt-4o-mini")

    llm = LLMLayer(api_key=openai_key, model=model_name)
    llm.initialize()

    embedder = EmbeddingManager(
        provider="openai",
        model="text-embedding-3-small",
        api_key=openai_key,
    )
    embedder.initialize()

    orchestrator = InferenceOrchestrator(
        graph_layer=GraphLayer(),
        llm_layer=llm,
        embedder=embedder,
    )
    orchestrator.initialize()

    evaluator = EvaluationLayer(eval_llm_model=model_name, api_key=openai_key)
    evaluator.initialize()

    runner = BenchmarkRunner(orchestrator, evaluator)

    logger.info(f"Running benchmark with {args.samples} samples...")
    results = runner.run_hotpotqa_benchmark(
        num_samples=args.samples,
        top_k=args.top_k,
        hops=args.hops,
    )
    print("\n" + results["report"])

    runner.save_results(args.output)
    logger.info(f"Results saved to {args.output}")
65
+
66
+
67
def run_ingestion(args):
    """Connect to TigerGraph, set up the schema/queries and ingest HotpotQA.

    Connection settings come from ``TG_HOST``, ``TG_GRAPH`` (default
    ``GraphRAG``), ``TG_USERNAME`` (default ``tigergraph``) and
    ``TG_PASSWORD``. The LLM and embedder use ``OPENAI_API_KEY``; the LLM
    model honours ``LLM_MODEL`` (default ``gpt-4o-mini``), matching
    ``run_benchmark``.

    Args:
        args: Parsed CLI namespace; ``args.samples`` is passed as the
            ``max_docs`` limit of the ingestion pipeline.

    Raises:
        SystemExit: With status 1 when the graph connection fails.
    """
    from graphrag.layers.graph_layer import GraphLayer
    from graphrag.layers.llm_layer import LLMLayer
    from graphrag.layers.orchestration_layer import EmbeddingManager
    from graphrag.ingestion import IngestionPipeline

    graph = GraphLayer(config={"host": os.getenv("TG_HOST", ""), "graphname": os.getenv("TG_GRAPH", "GraphRAG"),
                               "username": os.getenv("TG_USERNAME", "tigergraph"),
                               "password": os.getenv("TG_PASSWORD", "")})
    if not graph.connect():
        logger.error("Failed to connect to TigerGraph. Set TG_HOST, TG_PASSWORD env vars.")
        sys.exit(1)
    graph.create_schema()
    graph.install_queries()

    api_key = os.getenv("OPENAI_API_KEY", "")
    # Honour LLM_MODEL like run_benchmark does; the default keeps prior behaviour.
    llm = LLMLayer(api_key=api_key, model=os.getenv("LLM_MODEL", "gpt-4o-mini"))
    llm.initialize()
    # Pass the API key explicitly, as run_benchmark does for its embedder.
    embedder = EmbeddingManager(provider="openai", model="text-embedding-3-small",
                                api_key=api_key)
    embedder.initialize()
    pipeline = IngestionPipeline(graph, llm, embedder)
    stats = pipeline.ingest_hotpotqa(max_docs=args.samples)
    logger.info(f"Ingestion complete: {stats}")
89
+
90
+
91
def _find_hotpotqa_rows(questions):
    """Return HotpotQA validation rows keyed by lower-cased question text.

    Streams the ``distractor`` validation split once, keeps the first row
    whose question matches each entry of *questions* (case-insensitively),
    and stops as soon as every question has been matched.
    """
    from datasets import load_dataset

    pending = {q.lower() for q in questions}
    found = {}
    if not pending:
        return found

    ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split="validation", streaming=True)
    for row in ds:
        key = row["question"].lower()
        if key in pending:
            found[key] = row
            pending.discard(key)
            if not pending:
                break
    return found


def run_demo(args):
    """Run a small baseline-vs-GraphRAG comparison on fixed demo questions.

    For each demo question the matching HotpotQA validation row supplies the
    context passages and the gold answer; both pipeline answers, their token
    and cost figures, and F1 against the gold answer are printed. Questions
    without a matching row are reported and skipped.

    Args:
        args: Parsed CLI namespace. Not read by this command; accepted so all
            sub-commands share the same signature.
    """
    from graphrag.layers.llm_layer import LLMLayer
    from graphrag.layers.orchestration_layer import InferenceOrchestrator, EmbeddingManager
    from graphrag.layers.graph_layer import GraphLayer
    from graphrag.layers.evaluation_layer import compute_f1

    print("=" * 60)
    print("🔍 GraphRAG Inference Demo")
    print("=" * 60)

    api_key = os.getenv("OPENAI_API_KEY", "")
    # Same model/credential handling as run_benchmark; defaults are unchanged.
    llm = LLMLayer(api_key=api_key, model=os.getenv("LLM_MODEL", "gpt-4o-mini"))
    llm.initialize()
    embedder = EmbeddingManager(provider="openai", model="text-embedding-3-small",
                                api_key=api_key)
    embedder.initialize()
    graph = GraphLayer()
    orch = InferenceOrchestrator(graph_layer=graph, llm_layer=llm, embedder=embedder)
    orch.initialize()

    queries = [
        "Were Scott Derrickson and Ed Wood of the same nationality?",
        "Which magazine was started first, Arthur's Magazine or First for Women?",
    ]

    # Scan the dataset a single time for all questions instead of reloading
    # and re-streaming it from the start for every query.
    logger.info("Looking up demo questions in the HotpotQA validation split...")
    try:
        rows = _find_hotpotqa_rows(queries)
        load_error = None
    except Exception as e:  # demo boundary: report per query and keep going
        rows = {}
        load_error = e

    for query in queries:
        print(f"\n{'─' * 60}")
        print(f"Query: {query}")

        if load_error is not None:
            print(f"Error: {load_error}")
            continue

        row = rows.get(query.lower())
        if row is None:
            # Previously a miss produced no output at all.
            print("No matching question found in the HotpotQA validation split; skipping.")
            continue

        try:
            passages = [f"{t}: {' '.join(s)}"
                        for t, s in zip(row["context"]["title"], row["context"]["sentences"])]
            comp = orch.run_comparison(query, passages)
            gold = row["answer"]
            print(f"\n🔵 Baseline: {comp.baseline.answer}")
            print(f" Tokens: {comp.baseline.total_tokens} | Cost: ${comp.baseline.cost_usd:.6f}")
            print(f"\n🔴 GraphRAG: {comp.graphrag.answer}")
            print(f" Tokens: {comp.graphrag.total_tokens} | Cost: ${comp.graphrag.cost_usd:.6f}")
            print(f" Entities: {len(comp.graphrag.entities_found)} | Relations: {len(comp.graphrag.relations_traversed)}")
            print(f"\n📋 Gold: {gold}")
            print(f" Baseline F1: {compute_f1(comp.baseline.answer, gold):.4f}")
            print(f" GraphRAG F1: {compute_f1(comp.graphrag.answer, gold):.4f}")
        except Exception as e:
            print(f"Error: {e}")

    print(f"\n{'=' * 60}")
    print("Run 'python -m graphrag.main dashboard' for the full UI!")
140
+
141
+
142
+ if __name__ == "__main__":
143
+ main()