File size: 14,477 Bytes
6488963
 
 
 
 
 
 
 
 
 
 
 
 
 
577adc4
 
6488963
 
 
 
 
 
 
 
 
 
 
 
577adc4
 
 
 
 
 
 
 
6488963
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
577adc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6488963
577adc4
 
 
6488963
 
 
 
 
 
 
 
 
 
577adc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6488963
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
577adc4
6488963
577adc4
6488963
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
"""
Layer 1: Graph Layer — TigerGraph Schema, Connection, and GSQL Queries
======================================================================
Handles all graph database operations: schema creation, data upsert,
vector search, and multi-hop graph traversal.
"""
import hashlib
import logging
import math
from typing import Any, Dict, List, Optional, Tuple

logger = logging.getLogger(__name__)

# ── GSQL Schema Definition ───────────────────────────────
# Global vertex and edge types. GraphLayer.create_schema() runs this DDL only
# when the connection does not already report a "Document" vertex type.
# NOTE: MENTIONS is directed Chunk -> Entity and declares no reverse edge, so
# traversals over it presumably have to start from the Chunk side (TigerGraph
# stores a directed edge at its source unless a reverse edge is defined).
SCHEMA_DDL_GLOBAL = """
USE GLOBAL

CREATE VERTEX Document (PRIMARY_ID doc_id STRING, title STRING, content STRING, source STRING) WITH primary_id_as_attribute="true"
CREATE VERTEX Chunk (PRIMARY_ID chunk_id STRING, text STRING, embedding LIST<DOUBLE>, chunk_index INT, token_count INT, doc_id STRING) WITH primary_id_as_attribute="true"
CREATE VERTEX Entity (PRIMARY_ID entity_id STRING, name STRING, entity_type STRING, description STRING, embedding LIST<DOUBLE>, mention_count INT DEFAULT 1) WITH primary_id_as_attribute="true"
CREATE VERTEX Community (PRIMARY_ID community_id STRING, summary STRING, level INT DEFAULT 0, entity_count INT DEFAULT 0) WITH primary_id_as_attribute="true"

CREATE DIRECTED EDGE PART_OF (FROM Chunk, TO Document, position INT)
CREATE DIRECTED EDGE MENTIONS (FROM Chunk, TO Entity, mention_count INT DEFAULT 1, confidence FLOAT DEFAULT 1.0)
CREATE UNDIRECTED EDGE RELATED_TO (FROM Entity, TO Entity, relation_type STRING, weight FLOAT DEFAULT 1.0, description STRING, keywords STRING)
CREATE DIRECTED EDGE IN_COMMUNITY (FROM Entity, TO Community)
"""

# Graph definition grouping the global types above. `{graphname}` is filled in
# with str.format by GraphLayer.create_schema().
SCHEMA_DDL_GRAPH = """
CREATE GRAPH {graphname}(Document, Chunk, Entity, Community, PART_OF, MENTIONS, RELATED_TO, IN_COMMUNITY)
"""

# Drops the graph named by `{graphname}`. Not referenced by the code visible in
# this module — NOTE(review): confirm whether callers elsewhere still use it.
SCHEMA_DDL_DROP_GRAPH = """
DROP GRAPH {graphname}
"""

# ── GSQL Installed Queries ────────────────────────────────
# Each query string is passed through str.format(graphname=...) by
# GraphLayer.install_queries(), hence the doubled `{{ }}` braces.

# Brute-force cosine-similarity search over every Chunk; keeps the topK best
# (chunk_id, text, score) tuples in a HeapAccum. Chunks whose embedding length
# differs from queryVec are skipped (matching cosine_similarity() in this
# module), so an embedding produced by a different model can no longer make
# the loop read past the end of queryVec.
VECTOR_SEARCH_QUERY = """
CREATE OR REPLACE QUERY vectorSearchChunks(LIST<DOUBLE> queryVec, INT topK) FOR GRAPH {graphname} {{
    TYPEDEF TUPLE<STRING chunk_id, STRING text, DOUBLE score> ChunkScore;
    HeapAccum<ChunkScore>(topK, score DESC) @@topChunks;
    allChunks = {{Chunk.*}};
    allChunks = SELECT c FROM allChunks:c
        WHERE c.embedding.size() > 0 AND c.embedding.size() == queryVec.size()
        ACCUM
            DOUBLE dotProduct = 0.0, DOUBLE normA = 0.0, DOUBLE normB = 0.0,
            FOREACH i IN RANGE[0, c.embedding.size() - 1] DO
                dotProduct = dotProduct + queryVec.get(i) * c.embedding.get(i),
                normA = normA + queryVec.get(i) * queryVec.get(i),
                normB = normB + c.embedding.get(i) * c.embedding.get(i)
            END,
            DOUBLE sim = CASE WHEN sqrt(normA) * sqrt(normB) > 0 THEN dotProduct / (sqrt(normA) * sqrt(normB)) ELSE 0.0 END,
            @@topChunks += ChunkScore(c.chunk_id, c.text, sim);
    PRINT @@topChunks;
}}
INSTALL QUERY vectorSearchChunks
"""

# Brute-force cosine-similarity search over every Entity; keeps the topK best
# (entity_id, name, entity_type, description, score) tuples. Entities whose
# embedding length differs from queryVec are skipped (matching
# cosine_similarity() in this module), so a mismatched embedding can no longer
# make the loop read past the end of queryVec.
ENTITY_VECTOR_SEARCH_QUERY = """
CREATE OR REPLACE QUERY vectorSearchEntities(LIST<DOUBLE> queryVec, INT topK) FOR GRAPH {graphname} {{
    TYPEDEF TUPLE<STRING entity_id, STRING name, STRING entity_type, STRING description, DOUBLE score> EntityScore;
    HeapAccum<EntityScore>(topK, score DESC) @@topEntities;
    allEntities = {{Entity.*}};
    allEntities = SELECT e FROM allEntities:e
        WHERE e.embedding.size() > 0 AND e.embedding.size() == queryVec.size()
        ACCUM
            DOUBLE dotProduct = 0.0, DOUBLE normA = 0.0, DOUBLE normB = 0.0,
            FOREACH i IN RANGE[0, e.embedding.size() - 1] DO
                dotProduct = dotProduct + queryVec.get(i) * e.embedding.get(i),
                normA = normA + queryVec.get(i) * queryVec.get(i),
                normB = normB + e.embedding.get(i) * e.embedding.get(i)
            END,
            DOUBLE sim = CASE WHEN sqrt(normA) * sqrt(normB) > 0 THEN dotProduct / (sqrt(normA) * sqrt(normB)) ELSE 0.0 END,
            @@topEntities += EntityScore(e.entity_id, e.name, e.entity_type, e.description, sim);
    PRINT @@topEntities;
}}
INSTALL QUERY vectorSearchEntities
"""

# Multi-hop expansion: starting from seedEntityIds, follows RELATED_TO edges
# for `hops` rounds (skipping already-visited entities), then collects every
# Chunk that MENTIONS a visited entity. Prints visited entity ids, chunk ids
# and texts, readable relation strings, and the visited Entity vertices.
# MENTIONS is a directed Chunk -> Entity edge with no reverse edge in the
# schema, so chunks are found by scanning Chunk -> Entity and filtering on the
# visited ids; walking Entity -> Chunk would require a reverse edge.
# POST-ACCUM records each matching chunk once even if it mentions several
# visited entities.
GRAPHRAG_TRAVERSE_QUERY = """
CREATE OR REPLACE QUERY graphRAGTraverse(SET<STRING> seedEntityIds, INT hops) FOR GRAPH {graphname} {{
    SetAccum<STRING> @@visitedEntityIds;
    SetAccum<STRING> @@relevantChunkIds;
    ListAccum<STRING> @@chunkTexts;
    SetAccum<STRING> @@relationDescriptions;

    Seeds = {{Entity.*}};
    Seeds = SELECT e FROM Seeds:e WHERE e.entity_id IN seedEntityIds
        ACCUM @@visitedEntityIds += e.entity_id;

    FOREACH hop IN RANGE[1, hops] DO
        Seeds = SELECT nbr FROM Seeds:e -(RELATED_TO:rel)- Entity:nbr
            WHERE nbr.entity_id NOT IN @@visitedEntityIds
            ACCUM @@visitedEntityIds += nbr.entity_id,
                  @@relationDescriptions += (e.name + " -[" + rel.relation_type + "]-> " + nbr.name + ": " + rel.description);
    END;

    AllVisited = {{Entity.*}};
    AllVisited = SELECT e FROM AllVisited:e WHERE e.entity_id IN @@visitedEntityIds;

    AllChunks = {{Chunk.*}};
    Chunks = SELECT c FROM AllChunks:c -(MENTIONS>:m)- Entity:e
        WHERE e.entity_id IN @@visitedEntityIds
        POST-ACCUM @@relevantChunkIds += c.chunk_id, @@chunkTexts += c.text;

    PRINT @@visitedEntityIds;
    PRINT @@relevantChunkIds;
    PRINT @@chunkTexts;
    PRINT @@relationDescriptions;
    PRINT AllVisited [AllVisited.name, AllVisited.entity_type, AllVisited.description];
}}
INSTALL QUERY graphRAGTraverse
"""


class GraphLayer:
    """Layer 1: TigerGraph Graph Layer — connection, schema, upserts, retrieval.

    ``config`` is a mapping read via ``.get`` for ``host``, ``token`` (the
    secret exchanged for an API token) and ``graphname`` (default
    ``"GraphRAG"``). ``conn`` stays ``None`` until :meth:`connect` succeeds.
    Retrieval methods log failures and return empty results instead of raising.
    """

    def __init__(self, config=None):
        self.config = config
        self.conn = None
        self._connected = False

    def connect(self) -> bool:
        """Establish connection to TigerGraph Cloud.

        Exchanges the configured secret for an API token (TigerGraph 4.x
        endpoint first, then 3.x) and opens a ``TigerGraphConnection`` with it.

        Returns:
            True on success. On any failure (missing package, no token issued,
            connection error) the error is logged and False is returned.
        """
        try:
            import pyTigerGraph as tg
            import requests as _req
            cfg = self.config or {}
            host = cfg.get("host", "").rstrip("/")
            graphname = cfg.get("graphname", "GraphRAG")
            api_token = self._request_api_token(_req, host, cfg.get("token", ""))
            if not api_token:
                raise RuntimeError("Could not obtain token from any endpoint")
            self.conn = tg.TigerGraphConnection(
                host=host,
                graphname=graphname,
                apiToken=api_token,
            )
            self._connected = True
            logger.info("Connected to TigerGraph Cloud successfully.")
            return True
        except Exception as e:
            logger.error("TigerGraph connection failed: %s", e)
            return False

    @staticmethod
    def _request_api_token(http, host: str, secret: str) -> str:
        """Exchange ``secret`` for an API token; return "" if no endpoint issues one.

        ``http`` is the ``requests`` module, passed in so the import stays lazy.
        Endpoints are tried in order: TigerGraph 4.x, then 3.x.
        """
        endpoints = (
            ("/gsql/v1/tokens", {"secret": secret}),
            # lifetime is in seconds: 2592000 = 30 days.
            ("/restpp/requesttoken", {"secret": secret, "lifetime": 2592000}),
        )
        for endpoint, payload in endpoints:
            try:
                r = http.post(f"{host}{endpoint}", json=payload, timeout=15)
                if r.status_code != 200:
                    logger.info("[%s] status=%s body=%s", endpoint, r.status_code, r.text[:300])
                    continue
                # A successful response body contains the token itself: never log it.
                logger.info("[%s] status=%s", endpoint, r.status_code)
                data = r.json()
                # The token may sit at the top level or under "results" / "data".
                api_token = (data.get("token")
                             or (data.get("results") or {}).get("token", "")
                             or (data.get("data") or {}).get("token", ""))
                if api_token:
                    logger.info("Token obtained via %s", endpoint)
                    return api_token
            except Exception as ex:
                # Best effort: fall through to the next endpoint.
                logger.info("[%s] exception: %s", endpoint, ex)
        return ""

    def create_schema(self) -> str:
        """Create the global vertex/edge types if missing, then the graph.

        A failure while checking or creating the global types is logged as a
        warning and the CREATE GRAPH step is still attempted.

        Returns:
            ``"exists"`` when GSQL reports the graph already exists or
            conflicts, otherwise the raw output of the CREATE GRAPH statement.

        Raises:
            Any exception from the CREATE GRAPH call whose message does not
            mention a conflict or an already-existing graph.
        """
        gn = (self.config or {}).get("graphname", "GraphRAG")
        try:
            existing = self.conn.getVertexTypes()
            if "Document" not in existing:
                r1 = self.conn.gsql(SCHEMA_DDL_GLOBAL)
                logger.info("Global schema: %s", str(r1)[:300])
            else:
                logger.info("Global vertex types already exist, skipping.")
        except Exception as e:
            logger.warning("Global schema check: %s", e)
        try:
            r2 = self.conn.gsql(SCHEMA_DDL_GRAPH.format(graphname=gn))
            if "could not be created" in str(r2) or "conflicts" in str(r2):
                logger.info("Graph '%s' already exists, skipping.", gn)
                return "exists"
            logger.info("Graph schema: %s", str(r2)[:300])
            return r2
        except Exception as e:
            if "conflict" in str(e).lower() or "already" in str(e).lower():
                logger.info("Graph '%s' already exists, skipping.", gn)
                return "exists"
            raise

    def install_queries(self) -> Dict[str, str]:
        """Create and install the three GSQL queries for the configured graph.

        Returns:
            Mapping of query name to the GSQL output, or to the exception text
            if that query's installation raised.
        """
        gn = (self.config or {}).get("graphname", "GraphRAG")
        results = {}
        for name, q in [("vectorSearchChunks", VECTOR_SEARCH_QUERY),
                         ("vectorSearchEntities", ENTITY_VECTOR_SEARCH_QUERY),
                         ("graphRAGTraverse", GRAPHRAG_TRAVERSE_QUERY)]:
            try:
                results[name] = self.conn.gsql(q.format(graphname=gn))
            except Exception as e:
                results[name] = str(e)
        return results

    # ── Data Upsert ───────────────────────────────────────

    def upsert_document(self, doc_id, title, content, source=""):
        """Upsert a Document vertex keyed by ``doc_id``."""
        self.conn.upsertVertex("Document", doc_id, {"title": title, "content": content, "source": source})

    def upsert_chunk(self, chunk_id, text, embedding, chunk_index, token_count, doc_id):
        """Upsert a Chunk vertex and its PART_OF edge to Document ``doc_id``."""
        self.conn.upsertVertex("Chunk", chunk_id, {"text": text, "embedding": embedding,
                                                     "chunk_index": chunk_index, "token_count": token_count, "doc_id": doc_id})
        self.conn.upsertEdge("Chunk", chunk_id, "PART_OF", "Document", doc_id, {"position": chunk_index})

    def upsert_entity(self, entity_id, name, entity_type, description, embedding):
        """Upsert an Entity vertex keyed by ``entity_id``."""
        self.conn.upsertVertex("Entity", entity_id, {"name": name, "entity_type": entity_type,
                                                       "description": description, "embedding": embedding})

    def upsert_mention(self, chunk_id, entity_id, count=1, confidence=1.0):
        """Upsert a MENTIONS edge from Chunk ``chunk_id`` to Entity ``entity_id``."""
        self.conn.upsertEdge("Chunk", chunk_id, "MENTIONS", "Entity", entity_id,
                             {"mention_count": count, "confidence": confidence})

    def upsert_relation(self, src_id, tgt_id, rtype, desc="", weight=1.0, keywords=""):
        """Upsert a RELATED_TO edge between two Entity vertices."""
        self.conn.upsertEdge("Entity", src_id, "RELATED_TO", "Entity", tgt_id,
                             {"relation_type": rtype, "description": desc, "weight": weight, "keywords": keywords})

    # ── Retrieval ─────────────────────────────────────────

    def vector_search_chunks(self, query_embedding, top_k=5):
        """Return up to ``top_k`` chunk score records for ``query_embedding``.

        The embedding is sent in a POST body: as GET query parameters a
        realistic embedding exceeds the RESTPP URL-length limit.
        Returns [] (after logging) on any error.
        """
        try:
            result = self.conn.runInstalledQuery("vectorSearchChunks",
                                                 params={"queryVec": query_embedding, "topK": top_k},
                                                 usePost=True)
            return result[0].get("@@topChunks", []) if result else []
        except Exception as e:
            logger.error("Vector search failed: %s", e)
            return []

    def vector_search_entities(self, query_embedding, top_k=5):
        """Return up to ``top_k`` entity score records for ``query_embedding``.

        Sent via POST for the same URL-length reason as
        :meth:`vector_search_chunks`. Returns [] (after logging) on any error.
        """
        try:
            result = self.conn.runInstalledQuery("vectorSearchEntities",
                                                 params={"queryVec": query_embedding, "topK": top_k},
                                                 usePost=True)
            return result[0].get("@@topEntities", []) if result else []
        except Exception as e:
            logger.error("Entity search failed: %s", e)
            return []

    @staticmethod
    def _empty_traversal():
        """Return a fresh, empty graph_traverse() result dict."""
        return {"entity_ids": [], "chunk_ids": [], "chunk_texts": [], "relations": [], "entities": []}

    def graph_traverse(self, seed_entity_ids, hops=2):
        """Expand ``seed_entity_ids`` by ``hops`` RELATED_TO hops via graphRAGTraverse.

        Returns:
            Dict with keys ``entity_ids``, ``chunk_ids``, ``chunk_texts``,
            ``relations`` and ``entities``, each filled from the matching PRINT
            output when present. All values are empty lists (after logging) if
            the query call or parsing fails.
        """
        parsed = self._empty_traversal()
        try:
            result = self.conn.runInstalledQuery("graphRAGTraverse",
                                                 params={"seedEntityIds": seed_entity_ids, "hops": hops})
            # Each PRINT statement arrives as its own dict in the result list.
            for r in result or []:
                if "@@visitedEntityIds" in r:
                    parsed["entity_ids"] = list(r["@@visitedEntityIds"])
                if "@@relevantChunkIds" in r:
                    parsed["chunk_ids"] = list(r["@@relevantChunkIds"])
                if "@@chunkTexts" in r:
                    parsed["chunk_texts"] = r["@@chunkTexts"]
                if "@@relationDescriptions" in r:
                    parsed["relations"] = list(r["@@relationDescriptions"])
                if "AllVisited" in r:
                    parsed["entities"] = r["AllVisited"]
            return parsed
        except Exception as e:
            logger.error("Traversal failed: %s", e)
            return self._empty_traversal()

    @property
    def is_connected(self):
        """True once :meth:`connect` has succeeded."""
        return self._connected


# ── Utility Functions ──────────────────────────────────────

def generate_entity_id(name: str, entity_type: str) -> str:
    """Return a deterministic 12-hex-character ID for an entity (used for deduplication).

    The name and type are each lower-cased and trimmed before hashing, so
    differences in case or surrounding whitespace map to the same ID. MD5
    serves only as a stable fingerprint here, not for security.
    """
    normalized_name = name.lower().strip()
    normalized_type = entity_type.lower().strip()
    key = normalized_name + ":" + normalized_type
    digest = hashlib.md5(key.encode()).hexdigest()
    return digest[:12]

def generate_chunk_id(doc_id: str, chunk_index: int) -> str:
    """Build a chunk ID of the form ``<doc_id>_chunk_<index>``.

    The index is zero-padded to at least four digits, so for indices 0-9999
    the IDs of one document sort lexicographically in chunk order.
    """
    return "{}_chunk_{:04d}".format(doc_id, chunk_index)

def chunk_text(text: str, chunk_size: int = 1000, overlap: int = 100) -> List[str]:
    """Split text into overlapping chunks with sentence boundary detection.

    Each window holds at most ``chunk_size`` characters. When a window does not
    reach the end of the text, it is cut just after the last separator found
    beyond the window's midpoint, trying in order: period+space,
    period+newline, blank line, newline, space. The next window starts
    ``overlap`` characters before the previous one ended, but always at least
    one character further than the previous start. Chunks are whitespace-
    stripped and empty ones are dropped.

    Args:
        text: Text to split. Empty text yields ``[]``.
        chunk_size: Maximum characters per window; must be positive.
        overlap: Characters shared between consecutive windows.

    Returns:
        The list of chunk strings.

    Raises:
        ValueError: If ``text`` is non-empty and ``chunk_size`` is not positive.
    """
    if not text:
        return []
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    separators = ('. ', '.\n', '\n\n', '\n', ' ')
    chunks = []
    start = 0
    while start < len(text):
        end = min(start + chunk_size, len(text))
        if end < len(text):
            for sep in separators:
                last_sep = text[start:end].rfind(sep)
                # Only accept a break in the second half, so chunks stay reasonably large.
                if last_sep > chunk_size * 0.5:
                    end = start + last_sep + len(sep)
                    break
        chunk = text[start:end].strip()
        if chunk:
            chunks.append(chunk)
        if end >= len(text):
            break
        # Guarantee progress: when overlap >= the window just emitted (overlap >=
        # chunk_size, or a sentence break shortened the window), `end - overlap`
        # would not move past `start` and the loop would never terminate.
        start = max(end - overlap, start + 1)
    return chunks

def cosine_similarity(vec_a: List[float], vec_b: List[float]) -> float:
    """Return the cosine of the angle between two vectors.

    Yields 0.0 rather than raising when the vectors differ in length or when
    either has zero magnitude.
    """
    if len(vec_a) != len(vec_b):
        return 0.0
    dot_product = sum(x * y for x, y in zip(vec_a, vec_b))
    magnitude_a, magnitude_b = (math.sqrt(sum(v * v for v in vec)) for vec in (vec_a, vec_b))
    if 0.0 in (magnitude_a, magnitude_b):
        return 0.0
    return dot_product / (magnitude_a * magnitude_b)