muthuk1 committed
Commit a23a620 · verified · 1 Parent(s): 0117576

Add Layer 2: Inference Orchestration (dual pipeline, adaptive routing, comparison)
graphrag/layers/orchestration_layer.py ADDED
@@ -0,0 +1,335 @@
+ """
+ Layer 2: Inference Orchestration - Dual Pipeline Manager
+ ========================================================
+ Routes queries through Baseline RAG and GraphRAG pipelines,
+ collects metrics, and provides adaptive routing.
+ """
+ import json
+ import logging
+ import time
+ from dataclasses import dataclass, field
+ from typing import Dict, List
+
+ from .graph_layer import GraphLayer, cosine_similarity
+ from .llm_layer import LLMLayer, LLMResponse, TokenTracker
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class PipelineResult:
+     """Result from a single pipeline execution."""
+     answer: str = ""
+     contexts: List[str] = field(default_factory=list)
+     total_tokens: int = 0
+     input_tokens: int = 0
+     output_tokens: int = 0
+     latency_ms: float = 0.0
+     cost_usd: float = 0.0
+     pipeline_type: str = ""
+     entities_found: List[Dict] = field(default_factory=list)
+     relations_traversed: List[str] = field(default_factory=list)
+     hops_used: int = 0
+     complexity_score: float = 0.0
+     query_type: str = ""
+     token_breakdown: Dict = field(default_factory=dict)
+
+
+ @dataclass
+ class ComparisonResult:
+     """Side-by-side comparison of both pipelines."""
+     query: str = ""
+     baseline: PipelineResult = field(default_factory=PipelineResult)
+     graphrag: PipelineResult = field(default_factory=PipelineResult)
+     token_savings_pct: float = 0.0
+     latency_diff_ms: float = 0.0
+     cost_diff_usd: float = 0.0
+     recommended_pipeline: str = ""
+     routing_reason: str = ""
+
+
+ class EmbeddingManager:
+     """Manages embedding generation (OpenAI or local)."""
+
+     def __init__(self, provider="openai", model="text-embedding-3-small",
+                  api_key="", dimension=1536):
+         self.provider = provider
+         self.model = model
+         self._api_key = api_key
+         self.dimension = dimension
+         self._client = None
+         self._local_model = None
+
+     def initialize(self):
+         if self.provider == "openai":
+             try:
+                 from openai import OpenAI
+                 import os
+                 key = self._api_key or os.getenv("OPENAI_API_KEY", "")
+                 if key:
+                     self._client = OpenAI(api_key=key)
+                     logger.info(f"OpenAI embeddings: {self.model}")
+                 else:
+                     self._init_local()
+             except ImportError:
+                 self._init_local()
+         else:
+             self._init_local()
+
+     def _init_local(self):
+         try:
+             from sentence_transformers import SentenceTransformer
+             self._local_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
+             self.dimension = 384
+             self.provider = "local"
+             logger.info("Local embeddings: all-MiniLM-L6-v2")
+         except ImportError:
+             logger.warning("No embedding model available; falling back to zero vectors")
+
+     def embed(self, texts: List[str]) -> List[List[float]]:
+         if not texts: return []
+         if self.provider == "openai" and self._client:
+             try:
+                 resp = self._client.embeddings.create(input=texts, model=self.model)
+                 return [item.embedding for item in resp.data]
+             except Exception as e:
+                 logger.error(f"Embedding error: {e}")
+                 return [[0.0] * self.dimension for _ in texts]
+         elif self._local_model:
+             return [emb.tolist() for emb in self._local_model.encode(texts)]
+         return [[0.0] * self.dimension for _ in texts]
+
+     def embed_single(self, text: str) -> List[float]:
+         r = self.embed([text])
+         return r[0] if r else [0.0] * self.dimension
+
+
+ class InferenceOrchestrator:
+     """
+     Layer 2: Manages both pipelines and routes queries.
+     """
+
+     def __init__(self, graph_layer=None, llm_layer=None, embedder=None, config=None):
+         self.graph = graph_layer or GraphLayer()
+         self.llm = llm_layer or LLMLayer()
+         self.embedder = embedder or EmbeddingManager()
+         self.config = config or {}
+         self.baseline_tracker = TokenTracker()
+         self.graphrag_tracker = TokenTracker()
+         self.comparison_history: List[ComparisonResult] = []
+
+     def initialize(self):
+         self.llm.initialize()
+         self.embedder.initialize()
+         logger.info("Inference Orchestrator initialized.")
+
+     # ── Pipeline A: Baseline RAG ────────────────────────────
+
+     def run_baseline_rag(self, query, passages=None, top_k=5):
+         """
+         Pipeline A: Query → Embed → Vector Search → Top-K Chunks → LLM → Answer
+         """
+         start = time.perf_counter()
+         result = PipelineResult(pipeline_type="baseline")
+         ti = to = cost = 0.0
+
+         if passages:
+             query_emb = self.embedder.embed_single(query)
+             passage_embs = self.embedder.embed(passages)
+             scored = sorted(
+                 [(cosine_similarity(query_emb, emb), p) for p, emb in zip(passages, passage_embs)],
+                 key=lambda t: t[0], reverse=True
+             )
+             result.contexts = [p for _, p in scored[:top_k]]
+         elif self.graph.is_connected:
+             query_emb = self.embedder.embed_single(query)
+             chunks = self.graph.vector_search_chunks(query_emb, top_k)
+             result.contexts = [c.get("text", "") for c in chunks]
+         else:
+             result.contexts = ["[No context available: connect TigerGraph or provide passages]"]
+
+         ctx_text = "\n\n".join(result.contexts[:top_k])
+         resp = self.llm.generate_answer(query, ctx_text)
+         result.answer = resp.content
+         ti += resp.input_tokens; to += resp.output_tokens; cost += resp.cost_usd
+         result.input_tokens = int(ti); result.output_tokens = int(to)
+         result.total_tokens = int(ti + to); result.cost_usd = cost
+         result.latency_ms = (time.perf_counter() - start) * 1000
+         self.baseline_tracker.record(resp, "baseline")
+         return result
+
+     # ── Pipeline B: GraphRAG ────────────────────────────────
+
+     def run_graphrag(self, query, passages=None, seed_entities=5, hops=2, max_ctx=10):
+         """
+         Pipeline B: Query → Keywords → Entity Search → Graph Traverse → Structured Context → LLM
+         Novelties: dual-level keywords, schema-bounded extraction, graph reasoning
+         """
+         start = time.perf_counter()
+         result = PipelineResult(pipeline_type="graphrag")
+         ti = to = cost = 0.0
+
+         # Step 1: Extract dual-level keywords (LightRAG-inspired)
+         kw_resp = self.llm.extract_keywords(query)
+         ti += kw_resp.input_tokens; to += kw_resp.output_tokens; cost += kw_resp.cost_usd
+         self.graphrag_tracker.record(kw_resp, "keywords")
+
+         try:
+             kws = json.loads(kw_resp.content)
+         except json.JSONDecodeError:
+             kws = {"high_level": [], "low_level": [query]}
+
+         low_level = kws.get("low_level", [])
+
+         if self.graph.is_connected:
+             # Step 2: Find seed entities via vector search
+             search_text = " ".join(low_level) if low_level else query
+             query_emb = self.embedder.embed_single(search_text)
+             ents = self.graph.vector_search_entities(query_emb, seed_entities)
+             seed_ids = [e.get("entity_id", "") for e in ents]
+             result.entities_found = [
+                 {"name": e.get("name", ""), "entity_type": e.get("entity_type", ""),
+                  "description": e.get("description", ""), "score": e.get("score", 0)}
+                 for e in ents
+             ]
+             # Step 3: Multi-hop graph traversal
+             if seed_ids:
+                 traversal = self.graph.graph_traverse(seed_ids, hops)
+                 result.contexts = traversal.get("chunk_texts", [])[:max_ctx]
+                 result.relations_traversed = traversal.get("relations", [])
+                 result.hops_used = hops
+         else:
+             # Fallback: simulate GraphRAG with passages + entity extraction
+             if passages:
+                 query_emb = self.embedder.embed_single(query)
+                 passage_embs = self.embedder.embed(passages)
+                 scored = sorted(
+                     [(cosine_similarity(query_emb, emb), p, i)
+                      for i, (p, emb) in enumerate(zip(passages, passage_embs))],
+                     key=lambda t: t[0], reverse=True
+                 )
+
+                 # Extract entities from top passages (simulates graph construction)
+                 top_p = scored[:3]
+                 all_ent_names = set()
+                 for _, passage, _ in top_p:
+                     ext_resp = self.llm.extract_entities(passage)
+                     ti += ext_resp.input_tokens; to += ext_resp.output_tokens; cost += ext_resp.cost_usd
+                     self.graphrag_tracker.record(ext_resp, "entity_extraction")
+                     try:
+                         extracted = json.loads(ext_resp.content)
+                         for ent in extracted.get("entities", []):
+                             all_ent_names.add(ent.get("name", ""))
+                             result.entities_found.append(ent)
+                         for rel in extracted.get("relations", []):
+                             result.relations_traversed.append(
+                                 f"{rel.get('source','?')} -[{rel.get('type','?')}]-> {rel.get('target','?')}: {rel.get('description','')}")
+                     except json.JSONDecodeError:
+                         pass
+
+                 # Multi-hop simulation: expand by entity mentions
+                 expanded = []
+                 for _, passage, _ in scored:
+                     for en in all_ent_names:
+                         if en.lower() in passage.lower():
+                             expanded.append(passage)
+                             break
+                 all_ctx = [p for _, p, _ in top_p]
+                 for ep in expanded:
+                     if ep not in all_ctx: all_ctx.append(ep)
+                 result.contexts = all_ctx[:max_ctx]
+                 result.hops_used = hops
+
+         # Step 4: Build structured context with graph information
+         ctx_parts = []
+         if result.entities_found:
+             ctx_parts.append("### Entities Found:\n" + "\n".join(
+                 [f"- **{e.get('name','?')}** ({e.get('entity_type','?')}): {e.get('description','')}"
+                  for e in result.entities_found[:10]]))
+         if result.relations_traversed:
+             ctx_parts.append("### Relationships:\n" + "\n".join(
+                 [f"- {r}" for r in result.relations_traversed[:15]]))
+         if result.contexts:
+             ctx_parts.append("### Retrieved Passages:\n" + "\n\n".join(
+                 [f"[Passage {i+1}]: {c}" for i, c in enumerate(result.contexts[:max_ctx])]))
+
+         structured = "\n\n".join(ctx_parts)
+         sys_prompt = (
+             "You are a knowledgeable assistant with access to a knowledge graph. "
+             "Use the structured context including entities, relationships, and passages "
+             "to answer accurately. Follow relationship chains for multi-hop reasoning. Be concise."
+         )
+         gen_resp = self.llm.generate_answer(query, structured, sys_prompt)
+         ti += gen_resp.input_tokens; to += gen_resp.output_tokens; cost += gen_resp.cost_usd
+         self.graphrag_tracker.record(gen_resp, "graphrag_gen")
+
+         result.answer = gen_resp.content
+         result.input_tokens = int(ti); result.output_tokens = int(to)
+         result.total_tokens = int(ti + to); result.cost_usd = cost
+         result.latency_ms = (time.perf_counter() - start) * 1000
+         return result
+
+     # ── Adaptive Query Router (Novelty) ─────────────────────
+
+     def analyze_complexity(self, query):
+         """Analyze query complexity for adaptive routing."""
+         resp = self.llm.analyze_query_complexity(query)
+         try:
+             a = json.loads(resp.content)
+             return float(a.get("complexity_score", 0.5)), a.get("query_type", "unknown"), a.get("reasoning", "")
+         except (json.JSONDecodeError, ValueError):
+             return 0.5, "unknown", "Analysis failed"
+
+     def run_comparison(self, query, passages=None, top_k=5, hops=2):
+         """Run both pipelines and compare."""
+         b = self.run_baseline_rag(query, passages, top_k)
+         g = self.run_graphrag(query, passages, hops=hops)
+         comp = ComparisonResult(query=query, baseline=b, graphrag=g)
+         if b.total_tokens > 0:
+             comp.token_savings_pct = (b.total_tokens - g.total_tokens) / b.total_tokens * 100  # >0 when GraphRAG uses fewer tokens
+         comp.latency_diff_ms = g.latency_ms - b.latency_ms
+         comp.cost_diff_usd = g.cost_usd - b.cost_usd
+         self.comparison_history.append(comp)
+         return comp
+
+     def run_adaptive(self, query, passages=None, threshold=0.6):
+         """Adaptive routing: automatically picks the optimal pipeline."""
+         score, qtype, reasoning = self.analyze_complexity(query)
+         comp = self.run_comparison(query, passages)
+         comp.baseline.complexity_score = score
+         comp.baseline.query_type = qtype
+         comp.graphrag.complexity_score = score
+         comp.graphrag.query_type = qtype
+         if score >= threshold:
+             comp.recommended_pipeline = "graphrag"
+             comp.routing_reason = f"Complex query (score={score:.2f}, type={qtype}): {reasoning}"
+         else:
+             comp.recommended_pipeline = "baseline"
+             comp.routing_reason = f"Simple query (score={score:.2f}, type={qtype}): {reasoning}"
+         return comp
+
+     def explain_graphrag_reasoning(self, query, graphrag_result):
+         """Generate a reasoning-path explanation (novelty)."""
+         resp = self.llm.generate_graph_explanation(
+             query, graphrag_result.entities_found,
+             graphrag_result.relations_traversed, graphrag_result.answer)
+         return resp.content
+
+     def get_aggregate_metrics(self):
+         if not self.comparison_history: return {"message": "No comparisons"}
+         n = len(self.comparison_history)
+         return {
+             "total_queries": n,
+             "baseline": {
+                 "total_tokens": sum(c.baseline.total_tokens for c in self.comparison_history),
+                 "avg_tokens": sum(c.baseline.total_tokens for c in self.comparison_history) / n,
+                 "total_cost": sum(c.baseline.cost_usd for c in self.comparison_history),
+                 "avg_latency": sum(c.baseline.latency_ms for c in self.comparison_history) / n,
+             },
+             "graphrag": {
+                 "total_tokens": sum(c.graphrag.total_tokens for c in self.comparison_history),
+                 "avg_tokens": sum(c.graphrag.total_tokens for c in self.comparison_history) / n,
+                 "total_cost": sum(c.graphrag.cost_usd for c in self.comparison_history),
+                 "avg_latency": sum(c.graphrag.latency_ms for c in self.comparison_history) / n,
+             },
+         }
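
For context, a minimal usage sketch of the new layer, assuming the graphrag package is importable, the sibling graph/LLM layers initialize with their defaults (e.g. OPENAI_API_KEY in the environment), and the sample passages below are hypothetical placeholders rather than project data:

# Minimal usage sketch (assumptions as noted above; passages are hypothetical).
from graphrag.layers.orchestration_layer import InferenceOrchestrator

orch = InferenceOrchestrator()
orch.initialize()

passages = [
    "Marie Curie won Nobel Prizes in Physics (1903) and Chemistry (1911).",
    "Pierre Curie shared the 1903 Physics prize with Marie Curie.",
]

# Adaptive mode: scores query complexity, runs both pipelines, recommends one.
comp = orch.run_adaptive("Which prize did Marie Curie share with Pierre Curie?",
                         passages=passages)
print(comp.recommended_pipeline, "-", comp.routing_reason)
print(f"Token savings vs. baseline: {comp.token_savings_pct:.1f}%")

# Aggregate token/cost/latency metrics over all comparisons run so far.
print(orch.get_aggregate_metrics())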