File size: 8,133 Bytes
06a5155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
577adc4
 
06a5155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
577adc4
06a5155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
"""
Layer 3: LLM Layer β€” All LLM interactions with token tracking
=============================================================
Handles generation, entity extraction, keyword extraction,
query complexity analysis, and graph reasoning explanation.
"""
import json
import logging
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List

logger = logging.getLogger(__name__)


@dataclass
class LLMResponse:
    """Container for LLM response with usage metadata.

    Produced by LLMLayer.generate(). Zeroed defaults allow error and
    mock responses to be constructed with only the relevant fields.
    """
    content: str = ""        # model output text (or "[Error: ...]" / mock marker)
    input_tokens: int = 0    # prompt tokens reported by the API usage object
    output_tokens: int = 0   # completion tokens reported by the API usage object
    total_tokens: int = 0    # input + output; set explicitly by the caller, not derived
    latency_ms: float = 0.0  # wall-clock latency of the API call in milliseconds
    cost_usd: float = 0.0    # estimated cost computed from LLMLayer's per-1K pricing
    model: str = ""          # model identifier used for the call


@dataclass
class TokenTracker:
    """Tracks cumulative token usage and costs."""
    total_input_tokens: int = 0
    total_output_tokens: int = 0
    total_cost: float = 0.0
    call_count: int = 0
    calls: List[Dict] = field(default_factory=list)

    def record(self, resp: LLMResponse, label: str = ""):
        """Fold one response's usage into the running totals and log the call."""
        entry = {
            "label": label,
            "input_tokens": resp.input_tokens,
            "output_tokens": resp.output_tokens,
            "cost_usd": resp.cost_usd,
            "latency_ms": resp.latency_ms,
        }
        self.total_input_tokens += entry["input_tokens"]
        self.total_output_tokens += entry["output_tokens"]
        self.total_cost += entry["cost_usd"]
        self.call_count += 1
        self.calls.append(entry)

    def summary(self):
        """Return the aggregate usage figures as a plain dict."""
        return dict(
            total_input_tokens=self.total_input_tokens,
            total_output_tokens=self.total_output_tokens,
            total_cost_usd=round(self.total_cost, 6),
            call_count=self.call_count,
        )


class LLMLayer:
    """
    Layer 3: Handles all LLM interactions.
    Supports OpenAI API with mock fallback for testing.

    The client is created lazily by initialize(); while it is None (no
    API key, or openai not installed) every call returns a deterministic
    mock LLMResponse so the pipeline stays runnable offline.
    """

    def __init__(self, api_key: str = "", model: str = "gpt-4o-mini",
                 cost_per_1k_input: float = 0.00015,
                 cost_per_1k_output: float = 0.0006):
        """Store model and pricing config; no network access happens here."""
        self.model = model
        self.cost_in = cost_per_1k_input    # USD per 1K prompt tokens
        self.cost_out = cost_per_1k_output  # USD per 1K completion tokens
        self.client = None                  # None => mock mode until initialize()
        self._api_key = api_key

    def initialize(self):
        """Initialize the OpenAI client.

        Stays in mock mode (self.client remains None) when the openai
        package is not installed, or when no key is available from the
        constructor argument or the OPENAI_API_KEY environment variable.
        """
        try:
            from openai import OpenAI
            import os
            key = self._api_key or os.getenv("OPENAI_API_KEY", "")
            if key:
                # OPENAI_BASE_URL lets callers target a compatible proxy/server.
                base_url = os.getenv("OPENAI_BASE_URL", "")
                self.client = OpenAI(api_key=key, base_url=base_url) if base_url else OpenAI(api_key=key)
                logger.info(f"LLM initialized: {self.model}")
            else:
                logger.warning("No API key β€” using mock mode")
        except ImportError:
            logger.warning("openai not installed β€” mock mode")

    def _cost(self, inp: int, out: int) -> float:
        """Return the estimated USD cost for inp input and out output tokens."""
        return inp / 1000 * self.cost_in + out / 1000 * self.cost_out

    def generate(self, messages, temperature=0.0, max_tokens=1024, json_mode=False):
        """Generate a response from the LLM.

        Args:
            messages: Chat messages in OpenAI format ({"role": ..., "content": ...}).
            temperature: Sampling temperature.
            max_tokens: Cap on completion tokens.
            json_mode: If True, request response_format json_object (the
                prompt must mention "JSON" per the OpenAI API contract).

        Returns:
            LLMResponse with content plus usage/latency/cost metadata.
            API errors are not raised: they are logged and returned as a
            response whose content is "[Error: ...]".
        """
        start = time.perf_counter()

        if self.client is None:
            # Mock mode: fixed token counts keep cost accounting deterministic.
            return LLMResponse(
                content="[Mock response β€” no API key configured]",
                input_tokens=50, output_tokens=20, total_tokens=70,
                latency_ms=100.0, cost_usd=self._cost(50, 20), model=self.model
            )

        try:
            kwargs = {"model": self.model, "messages": messages,
                      "temperature": temperature, "max_tokens": max_tokens}
            if json_mode:
                kwargs["response_format"] = {"type": "json_object"}
            resp = self.client.chat.completions.create(**kwargs)
            elapsed = (time.perf_counter() - start) * 1000
            u = resp.usage
            return LLMResponse(
                content=resp.choices[0].message.content,
                input_tokens=u.prompt_tokens, output_tokens=u.completion_tokens,
                total_tokens=u.prompt_tokens + u.completion_tokens,
                latency_ms=elapsed,
                cost_usd=self._cost(u.prompt_tokens, u.completion_tokens),
                model=self.model
            )
        except Exception as e:
            elapsed = (time.perf_counter() - start) * 1000
            logger.error(f"LLM error: {e}")
            return LLMResponse(content=f"[Error: {e}]", latency_ms=elapsed, model=self.model)

    # ── Specialized Functions ─────────────────────────────

    def generate_answer(self, query, context, system_prompt=None):
        """Generate an answer given query and context.

        A default grounding system prompt is used unless the caller
        supplies one.
        """
        if not system_prompt:
            system_prompt = (
                "You are a helpful assistant. Answer accurately using ONLY the provided context. "
                "If the context doesn't contain enough info, say so. Be concise and precise."
            )
        return self.generate([
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}\n\nAnswer:"}
        ], max_tokens=512)

    def extract_entities(self, text, entity_types=None, relation_types=None):
        """Extract entities and relationships using schema-bounded extraction (novelty).

        entity_types / relation_types default to a general-purpose schema
        when not provided.
        """
        if not entity_types:
            entity_types = ["PERSON", "ORGANIZATION", "LOCATION", "EVENT",
                            "DATE", "CONCEPT", "WORK", "PRODUCT", "TECHNOLOGY"]
        if not relation_types:
            relation_types = ["WORKS_FOR", "LOCATED_IN", "FOUNDED_BY", "PART_OF",
                              "RELATED_TO", "CREATED_BY", "HAPPENED_IN", "MEMBER_OF",
                              "COLLABORATES_WITH", "INFLUENCES"]

        prompt = f"""Extract all entities and relationships from the text.
ALLOWED ENTITY TYPES: {json.dumps(entity_types)}
ALLOWED RELATION TYPES: {json.dumps(relation_types)}

Return JSON:
{{"entities": [{{"name": "exact name", "type": "one of allowed types", "description": "brief 1-sentence"}}],
 "relations": [{{"source": "source entity name", "target": "target entity name", "type": "one of allowed types", "description": "brief"}}]}}

Text: {text}"""
        # Fix: json_mode=True enforces valid JSON output. The prompt already
        # says "Return JSON" (required by the API for json_object mode), and
        # this matches the sibling JSON-returning helpers below.
        return self.generate([{"role": "user", "content": prompt}], max_tokens=4096, json_mode=True)

    def extract_keywords(self, query):
        """Extract dual-level keywords for GraphRAG retrieval (novelty: LightRAG-inspired)."""
        prompt = """Extract search keywords from this question. Return JSON:
{"high_level": ["abstract themes/topics"], "low_level": ["specific entities/names/dates"]}

Question: """ + query
        return self.generate([{"role": "user", "content": prompt}], max_tokens=256, json_mode=True)

    def analyze_query_complexity(self, query):
        """Analyze query complexity for adaptive routing (novelty).

        Returns an LLMResponse whose content is JSON with complexity_score,
        reasoning, query_type, and estimated_hops fields.
        """
        prompt = """Rate this question's complexity from 0.0 to 1.0. Return JSON:
{"complexity_score": 0.0-1.0, "reasoning": "brief", "query_type": "factoid|comparison|bridge|multi_hop", "estimated_hops": 1-4}

Score guide: 0.0-0.3 simple factoid, 0.3-0.6 moderate, 0.6-0.8 complex multi-entity, 0.8-1.0 multi-hop reasoning

Question: """ + query
        return self.generate([{"role": "user", "content": prompt}], max_tokens=256, json_mode=True)

    def generate_graph_explanation(self, query, entities, relations, answer):
        """Generate natural language explanation of graph reasoning path (novelty).

        Only the first 10 entities and 15 relations are included to bound
        the prompt size.
        """
        ent_str = "\n".join([f"- {e.get('name','?')} ({e.get('entity_type','?')}): {e.get('description','')}"
                             for e in entities[:10]])
        rel_str = "\n".join([f"- {r}" for r in relations[:15]])
        prompt = f"""Explain the graph reasoning path for this answer step-by-step.

Question: {query}

Entities Found:
{ent_str}

Relationships Traversed:
{rel_str}

Generated Answer: {answer}

Format as:
1. **Entry Points**: [which entities were found first]
2. **Traversal**: [which relationships were followed, use A β†’ B β†’ C notation]
3. **Evidence**: [which facts support the answer]
4. **Conclusion**: [how the answer was derived]"""
        return self.generate([{"role": "user", "content": prompt}], max_tokens=512)