muthuk1 commited on
Commit
06a5155
Β·
verified Β·
1 Parent(s): 4921b5b

Add Layer 3: LLM Layer (generation, entity extraction, keyword extraction, complexity analysis)

Browse files
Files changed (1) hide show
  1. graphrag/layers/llm_layer.py +195 -0
graphrag/layers/llm_layer.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Layer 3: LLM Layer β€” All LLM interactions with token tracking
3
+ =============================================================
4
+ Handles generation, entity extraction, keyword extraction,
5
+ query complexity analysis, and graph reasoning explanation.
6
+ """
7
+ import json
8
+ import logging
9
+ import time
10
+ from dataclasses import dataclass, field
11
+ from typing import Any, Dict, List
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ @dataclass
17
+ class LLMResponse:
18
+ """Container for LLM response with usage metadata."""
19
+ content: str = ""
20
+ input_tokens: int = 0
21
+ output_tokens: int = 0
22
+ total_tokens: int = 0
23
+ latency_ms: float = 0.0
24
+ cost_usd: float = 0.0
25
+ model: str = ""
26
+
27
+
28
+ @dataclass
29
+ class TokenTracker:
30
+ """Tracks cumulative token usage and costs."""
31
+ total_input_tokens: int = 0
32
+ total_output_tokens: int = 0
33
+ total_cost: float = 0.0
34
+ call_count: int = 0
35
+ calls: List[Dict] = field(default_factory=list)
36
+
37
+ def record(self, resp: LLMResponse, label: str = ""):
38
+ self.total_input_tokens += resp.input_tokens
39
+ self.total_output_tokens += resp.output_tokens
40
+ self.total_cost += resp.cost_usd
41
+ self.call_count += 1
42
+ self.calls.append({
43
+ "label": label, "input_tokens": resp.input_tokens,
44
+ "output_tokens": resp.output_tokens, "cost_usd": resp.cost_usd,
45
+ "latency_ms": resp.latency_ms
46
+ })
47
+
48
+ def summary(self):
49
+ return {
50
+ "total_input_tokens": self.total_input_tokens,
51
+ "total_output_tokens": self.total_output_tokens,
52
+ "total_cost_usd": round(self.total_cost, 6),
53
+ "call_count": self.call_count
54
+ }
55
+
56
+
57
+ class LLMLayer:
58
+ """
59
+ Layer 3: Handles all LLM interactions.
60
+ Supports OpenAI API with mock fallback for testing.
61
+ """
62
+
63
+ def __init__(self, api_key="", model="gpt-4o-mini",
64
+ cost_per_1k_input=0.00015, cost_per_1k_output=0.0006):
65
+ self.model = model
66
+ self.cost_in = cost_per_1k_input
67
+ self.cost_out = cost_per_1k_output
68
+ self.client = None
69
+ self._api_key = api_key
70
+
71
+ def initialize(self):
72
+ """Initialize the OpenAI client."""
73
+ try:
74
+ from openai import OpenAI
75
+ import os
76
+ key = self._api_key or os.getenv("OPENAI_API_KEY", "")
77
+ if key:
78
+ self.client = OpenAI(api_key=key)
79
+ logger.info(f"LLM initialized: {self.model}")
80
+ else:
81
+ logger.warning("No API key β€” using mock mode")
82
+ except ImportError:
83
+ logger.warning("openai not installed β€” mock mode")
84
+
85
+ def _cost(self, inp, out):
86
+ return inp / 1000 * self.cost_in + out / 1000 * self.cost_out
87
+
88
+ def generate(self, messages, temperature=0.0, max_tokens=1024, json_mode=False):
89
+ """Generate a response from the LLM."""
90
+ start = time.perf_counter()
91
+
92
+ if self.client is None:
93
+ return LLMResponse(
94
+ content="[Mock response β€” no API key configured]",
95
+ input_tokens=50, output_tokens=20, total_tokens=70,
96
+ latency_ms=100.0, cost_usd=self._cost(50, 20), model=self.model
97
+ )
98
+
99
+ try:
100
+ kwargs = {"model": self.model, "messages": messages,
101
+ "temperature": temperature, "max_tokens": max_tokens}
102
+ if json_mode:
103
+ kwargs["response_format"] = {"type": "json_object"}
104
+ resp = self.client.chat.completions.create(**kwargs)
105
+ elapsed = (time.perf_counter() - start) * 1000
106
+ u = resp.usage
107
+ return LLMResponse(
108
+ content=resp.choices[0].message.content,
109
+ input_tokens=u.prompt_tokens, output_tokens=u.completion_tokens,
110
+ total_tokens=u.prompt_tokens + u.completion_tokens,
111
+ latency_ms=elapsed,
112
+ cost_usd=self._cost(u.prompt_tokens, u.completion_tokens),
113
+ model=self.model
114
+ )
115
+ except Exception as e:
116
+ elapsed = (time.perf_counter() - start) * 1000
117
+ logger.error(f"LLM error: {e}")
118
+ return LLMResponse(content=f"[Error: {e}]", latency_ms=elapsed, model=self.model)
119
+
120
+ # ── Specialized Functions ─────────────────────────────
121
+
122
+ def generate_answer(self, query, context, system_prompt=None):
123
+ """Generate an answer given query and context."""
124
+ if not system_prompt:
125
+ system_prompt = (
126
+ "You are a helpful assistant. Answer accurately using ONLY the provided context. "
127
+ "If the context doesn't contain enough info, say so. Be concise and precise."
128
+ )
129
+ return self.generate([
130
+ {"role": "system", "content": system_prompt},
131
+ {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}\n\nAnswer:"}
132
+ ], max_tokens=512)
133
+
134
+ def extract_entities(self, text, entity_types=None, relation_types=None):
135
+ """Extract entities and relationships using schema-bounded extraction (novelty)."""
136
+ if not entity_types:
137
+ entity_types = ["PERSON", "ORGANIZATION", "LOCATION", "EVENT",
138
+ "DATE", "CONCEPT", "WORK", "PRODUCT", "TECHNOLOGY"]
139
+ if not relation_types:
140
+ relation_types = ["WORKS_FOR", "LOCATED_IN", "FOUNDED_BY", "PART_OF",
141
+ "RELATED_TO", "CREATED_BY", "HAPPENED_IN", "MEMBER_OF",
142
+ "COLLABORATES_WITH", "INFLUENCES"]
143
+
144
+ prompt = f"""Extract all entities and relationships from the text.
145
+ ALLOWED ENTITY TYPES: {json.dumps(entity_types)}
146
+ ALLOWED RELATION TYPES: {json.dumps(relation_types)}
147
+
148
+ Return JSON:
149
+ {{"entities": [{{"name": "exact name", "type": "one of allowed types", "description": "brief 1-sentence"}}],
150
+ "relations": [{{"source": "source entity name", "target": "target entity name", "type": "one of allowed types", "description": "brief"}}]}}
151
+
152
+ Text: {text}"""
153
+ return self.generate([{"role": "user", "content": prompt}], max_tokens=2048, json_mode=True)
154
+
155
+ def extract_keywords(self, query):
156
+ """Extract dual-level keywords for GraphRAG retrieval (novelty: LightRAG-inspired)."""
157
+ prompt = """Extract search keywords from this question. Return JSON:
158
+ {"high_level": ["abstract themes/topics"], "low_level": ["specific entities/names/dates"]}
159
+
160
+ Question: """ + query
161
+ return self.generate([{"role": "user", "content": prompt}], max_tokens=256, json_mode=True)
162
+
163
+ def analyze_query_complexity(self, query):
164
+ """Analyze query complexity for adaptive routing (novelty)."""
165
+ prompt = """Rate this question's complexity from 0.0 to 1.0. Return JSON:
166
+ {"complexity_score": 0.0-1.0, "reasoning": "brief", "query_type": "factoid|comparison|bridge|multi_hop", "estimated_hops": 1-4}
167
+
168
+ Score guide: 0.0-0.3 simple factoid, 0.3-0.6 moderate, 0.6-0.8 complex multi-entity, 0.8-1.0 multi-hop reasoning
169
+
170
+ Question: """ + query
171
+ return self.generate([{"role": "user", "content": prompt}], max_tokens=256, json_mode=True)
172
+
173
+ def generate_graph_explanation(self, query, entities, relations, answer):
174
+ """Generate natural language explanation of graph reasoning path (novelty)."""
175
+ ent_str = "\n".join([f"- {e.get('name','?')} ({e.get('entity_type','?')}): {e.get('description','')}"
176
+ for e in entities[:10]])
177
+ rel_str = "\n".join([f"- {r}" for r in relations[:15]])
178
+ prompt = f"""Explain the graph reasoning path for this answer step-by-step.
179
+
180
+ Question: {query}
181
+
182
+ Entities Found:
183
+ {ent_str}
184
+
185
+ Relationships Traversed:
186
+ {rel_str}
187
+
188
+ Generated Answer: {answer}
189
+
190
+ Format as:
191
+ 1. **Entry Points**: [which entities were found first]
192
+ 2. **Traversal**: [which relationships were followed, use A β†’ B β†’ C notation]
193
+ 3. **Evidence**: [which facts support the answer]
194
+ 4. **Conclusion**: [how the answer was derived]"""
195
+ return self.generate([{"role": "user", "content": prompt}], max_tokens=512)