muthuk1 committed on
Commit
6488963
·
verified ·
1 Parent(s): 06a5155

Add Layer 1: Graph Layer (TigerGraph schema, GSQL queries, vector search, traversal)

Browse files
Files changed (1) hide show
  1. graphrag/layers/graph_layer.py +256 -0
graphrag/layers/graph_layer.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Layer 1: Graph Layer — TigerGraph Schema, Connection, and GSQL Queries
3
+ ======================================================================
4
+ Handles all graph database operations: schema creation, data upsert,
5
+ vector search, and multi-hop graph traversal.
6
+ """
7
+ import hashlib
8
+ import logging
9
+ import math
10
+ from typing import Any, Dict, List, Optional, Tuple
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+ # ── GSQL Schema Definition ───────────────────────────────
15
# GSQL DDL template for the GraphRAG schema. `{graphname}` is substituted with
# str.format() by GraphLayer.create_schema() before the text is sent to GSQL.
#
# Vertex types: Document, Chunk, Entity, Community. Each one exposes its
# primary id as a regular attribute (primary_id_as_attribute="true"); Chunk and
# Entity additionally store an `embedding` as LIST<DOUBLE>.
# Edge types: PART_OF (Chunk -> Document), MENTIONS (Chunk -> Entity),
# RELATED_TO (undirected, Entity - Entity), IN_COMMUNITY (Entity -> Community).
#
# NOTE(review): the template starts with `USE GRAPH {graphname}` and contains
# no CREATE GRAPH statement — presumably the graph must already exist and the
# new types still have to be attached to it; confirm against the TigerGraph
# DDL documentation.
SCHEMA_DDL = """
USE GRAPH {graphname}

CREATE VERTEX Document (PRIMARY_ID doc_id STRING, title STRING, content STRING, source STRING) WITH primary_id_as_attribute="true"
CREATE VERTEX Chunk (PRIMARY_ID chunk_id STRING, text STRING, embedding LIST<DOUBLE>, chunk_index INT, token_count INT, doc_id STRING) WITH primary_id_as_attribute="true"
CREATE VERTEX Entity (PRIMARY_ID entity_id STRING, name STRING, entity_type STRING, description STRING, embedding LIST<DOUBLE>, mention_count INT DEFAULT 1) WITH primary_id_as_attribute="true"
CREATE VERTEX Community (PRIMARY_ID community_id STRING, summary STRING, level INT DEFAULT 0, entity_count INT DEFAULT 0) WITH primary_id_as_attribute="true"

CREATE DIRECTED EDGE PART_OF (FROM Chunk, TO Document, position INT)
CREATE DIRECTED EDGE MENTIONS (FROM Chunk, TO Entity, mention_count INT DEFAULT 1, confidence FLOAT DEFAULT 1.0)
CREATE UNDIRECTED EDGE RELATED_TO (FROM Entity, TO Entity, relation_type STRING, weight FLOAT DEFAULT 1.0, description STRING, keywords STRING)
CREATE DIRECTED EDGE IN_COMMUNITY (FROM Entity, TO Community)
"""
28
+
29
+ # ── GSQL Installed Queries ────────────────────────────────
30
# GSQL template that creates and installs `vectorSearchChunks(queryVec, topK)`.
# `{graphname}` is filled in with str.format(); the doubled braces `{{` / `}}`
# are format-escapes that become the literal GSQL braces.
#
# The query scans every Chunk vertex whose `embedding` is non-empty, computes
# the cosine similarity between `queryVec` and that embedding inside ACCUM, and
# keeps (chunk_id, text, score) tuples in a HeapAccum ordered by score DESC and
# sized by topK. The result is printed under the key `@@topChunks`.
#
# NOTE(review): the FOREACH loop ranges over the stored embedding's size and
# indexes `queryVec` with the same `i` — this assumes the query vector is at
# least as long as every stored embedding; confirm that callers guarantee it.
VECTOR_SEARCH_QUERY = """
CREATE OR REPLACE QUERY vectorSearchChunks(LIST<DOUBLE> queryVec, INT topK) FOR GRAPH {graphname} {{
TYPEDEF TUPLE<STRING chunk_id, STRING text, DOUBLE score> ChunkScore;
HeapAccum<ChunkScore>(topK, score DESC) @@topChunks;
allChunks = {{Chunk.*}};
allChunks = SELECT c FROM allChunks:c WHERE c.embedding.size() > 0
ACCUM
DOUBLE dotProduct = 0.0, DOUBLE normA = 0.0, DOUBLE normB = 0.0,
FOREACH i IN RANGE[0, c.embedding.size() - 1] DO
dotProduct = dotProduct + queryVec.get(i) * c.embedding.get(i),
normA = normA + queryVec.get(i) * queryVec.get(i),
normB = normB + c.embedding.get(i) * c.embedding.get(i)
END,
DOUBLE sim = CASE WHEN sqrt(normA) * sqrt(normB) > 0 THEN dotProduct / (sqrt(normA) * sqrt(normB)) ELSE 0.0 END,
@@topChunks += ChunkScore(c.chunk_id, c.text, sim);
PRINT @@topChunks;
}}
INSTALL QUERY vectorSearchChunks
"""
49
+
50
# GSQL template that creates and installs `vectorSearchEntities(queryVec, topK)`.
# `{graphname}` is filled in with str.format(); `{{` / `}}` are format-escapes
# for literal GSQL braces.
#
# Same shape as the chunk search, but over Entity vertices: every Entity with a
# non-empty `embedding` is scored by cosine similarity against `queryVec`, and
# (entity_id, name, entity_type, description, score) tuples are kept in a
# HeapAccum ordered by score DESC and sized by topK. The result is printed
# under the key `@@topEntities`.
#
# NOTE(review): `queryVec` is indexed over the stored embedding's length, so
# the query vector is assumed to be at least that long — confirm with callers.
ENTITY_VECTOR_SEARCH_QUERY = """
CREATE OR REPLACE QUERY vectorSearchEntities(LIST<DOUBLE> queryVec, INT topK) FOR GRAPH {graphname} {{
TYPEDEF TUPLE<STRING entity_id, STRING name, STRING entity_type, STRING description, DOUBLE score> EntityScore;
HeapAccum<EntityScore>(topK, score DESC) @@topEntities;
allEntities = {{Entity.*}};
allEntities = SELECT e FROM allEntities:e WHERE e.embedding.size() > 0
ACCUM
DOUBLE dotProduct = 0.0, DOUBLE normA = 0.0, DOUBLE normB = 0.0,
FOREACH i IN RANGE[0, e.embedding.size() - 1] DO
dotProduct = dotProduct + queryVec.get(i) * e.embedding.get(i),
normA = normA + queryVec.get(i) * queryVec.get(i),
normB = normB + e.embedding.get(i) * e.embedding.get(i)
END,
DOUBLE sim = CASE WHEN sqrt(normA) * sqrt(normB) > 0 THEN dotProduct / (sqrt(normA) * sqrt(normB)) ELSE 0.0 END,
@@topEntities += EntityScore(e.entity_id, e.name, e.entity_type, e.description, sim);
PRINT @@topEntities;
}}
INSTALL QUERY vectorSearchEntities
"""
69
+
70
# GSQL template that creates and installs `graphRAGTraverse(seedEntityIds, hops)`.
# `{graphname}` is filled in with str.format(); `{{` / `}}` are format-escapes
# for literal GSQL braces.
#
# Steps performed by the query:
#   1. Select the Entity vertices whose entity_id is in `seedEntityIds` and
#      record them in @@visitedEntityIds.
#   2. For each hop in 1..hops, follow RELATED_TO edges from the current
#      frontier to entities not yet visited, recording their ids and a
#      "src -[type]-> dst: description" string per traversed edge.
#   3. Re-select all visited entities, then collect chunk ids and chunk texts
#      reached from them through MENTIONS.
#   4. Print the four accumulators plus name/entity_type/description of the
#      visited entities (printed under the key `AllVisited`).
#
# NOTE(review): SCHEMA_DDL declares MENTIONS as a directed edge FROM Chunk TO
# Entity, but step 3 walks `-(MENTIONS>:m)-` starting from Entity vertices —
# the direction looks reversed and may match nothing; confirm against the GSQL
# pattern-matching documentation.
# NOTE(review): @@chunkTexts is a ListAccum while @@relevantChunkIds is a
# SetAccum, so texts are presumably not de-duplicated and not aligned with the
# ids — verify whether consumers rely on that.
GRAPHRAG_TRAVERSE_QUERY = """
CREATE OR REPLACE QUERY graphRAGTraverse(SET<STRING> seedEntityIds, INT hops) FOR GRAPH {graphname} {{
SetAccum<STRING> @@visitedEntityIds;
SetAccum<STRING> @@relevantChunkIds;
ListAccum<STRING> @@chunkTexts;
SetAccum<STRING> @@relationDescriptions;

Seeds = {{Entity.*}};
Seeds = SELECT e FROM Seeds:e WHERE e.entity_id IN seedEntityIds
ACCUM @@visitedEntityIds += e.entity_id;

FOREACH hop IN RANGE[1, hops] DO
Seeds = SELECT nbr FROM Seeds:e -(RELATED_TO:rel)- Entity:nbr
WHERE nbr.entity_id NOT IN @@visitedEntityIds
ACCUM @@visitedEntityIds += nbr.entity_id,
@@relationDescriptions += (e.name + " -[" + rel.relation_type + "]-> " + nbr.name + ": " + rel.description);
END;

AllVisited = {{Entity.*}};
AllVisited = SELECT e FROM AllVisited:e WHERE e.entity_id IN @@visitedEntityIds;

Chunks = SELECT c FROM AllVisited:e -(MENTIONS>:m)- Chunk:c
ACCUM @@relevantChunkIds += c.chunk_id, @@chunkTexts += c.text;

PRINT @@visitedEntityIds;
PRINT @@relevantChunkIds;
PRINT @@chunkTexts;
PRINT @@relationDescriptions;
PRINT AllVisited [AllVisited.name, AllVisited.entity_type, AllVisited.description];
}}
INSTALL QUERY graphRAGTraverse
"""
102
+
103
+
104
class GraphLayer:
    """Layer 1: TigerGraph Graph Layer — connection, schema, upserts, retrieval.

    Thin wrapper around a ``pyTigerGraph`` connection stored in ``self.conn``.
    ``self.conn`` is ``None`` until :meth:`connect` assigns it, so every other
    method expects :meth:`connect` to have been called first.
    """

    # (key printed by graphRAGTraverse, key in the parsed result, wrap in list()?)
    _TRAVERSAL_FIELDS = (
        ("@@visitedEntityIds", "entity_ids", True),
        ("@@relevantChunkIds", "chunk_ids", True),
        ("@@chunkTexts", "chunk_texts", False),
        ("@@relationDescriptions", "relations", True),
        ("AllVisited", "entities", False),
    )

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Store the configuration; no connection is opened here.

        Args:
            config: Optional mapping read with ``.get()``. Keys used by this
                class: ``host``, ``graphname``, ``username``, ``password``
                and ``token``.
        """
        self.config = config
        self.conn = None
        self._connected = False

    def _graphname(self) -> str:
        """Return the configured graph name, defaulting to ``"GraphRAG"``."""
        return (self.config or {}).get("graphname", "GraphRAG")

    @staticmethod
    def _empty_traversal() -> Dict[str, List[Any]]:
        """Return a fresh, empty traversal result."""
        return {"entity_ids": [], "chunk_ids": [], "chunk_texts": [],
                "relations": [], "entities": []}

    def connect(self) -> bool:
        """Establish connection to TigerGraph Cloud.

        Uses ``config["token"]`` as the API token when present; otherwise
        creates a secret and requests a token with it.

        Returns:
            ``True`` on success, ``False`` if anything raised (the error is
            logged, not propagated).
        """
        # Clear the flag up front so a failed *re*-connect cannot leave
        # is_connected reporting a stale True from an earlier success.
        self._connected = False
        try:
            # Imported lazily so the module can be loaded without pyTigerGraph.
            import pyTigerGraph as tg
            cfg = self.config or {}
            self.conn = tg.TigerGraphConnection(
                host=cfg.get("host", ""),
                graphname=cfg.get("graphname", "GraphRAG"),
                username=cfg.get("username", "tigergraph"),
                password=cfg.get("password", ""),
            )
            if cfg.get("token"):
                self.conn.apiToken = cfg["token"]
            else:
                secret = self.conn.createSecret()
                self.conn.getToken(secret)
            self._connected = True
            logger.info("Connected to TigerGraph Cloud successfully.")
            return True
        except Exception as e:
            logger.error("TigerGraph connection failed: %s", e)
            return False

    def create_schema(self) -> str:
        """Send ``SCHEMA_DDL`` (formatted with the graph name) to GSQL.

        Returns:
            Whatever ``self.conn.gsql`` returns.
        """
        return self.conn.gsql(SCHEMA_DDL.format(graphname=self._graphname()))

    def install_queries(self) -> Dict[str, str]:
        """Create and install the three GSQL queries.

        Returns:
            Mapping of query name to the GSQL output, or to ``str(exception)``
            when submitting that query raised. One failure does not stop the
            remaining queries from being submitted.
        """
        gn = self._graphname()
        templates = (
            ("vectorSearchChunks", VECTOR_SEARCH_QUERY),
            ("vectorSearchEntities", ENTITY_VECTOR_SEARCH_QUERY),
            ("graphRAGTraverse", GRAPHRAG_TRAVERSE_QUERY),
        )
        results = {}
        for name, template in templates:
            try:
                results[name] = self.conn.gsql(template.format(graphname=gn))
            except Exception as e:
                results[name] = str(e)
        return results

    # ── Data Upsert ───────────────────────────────────────

    def upsert_document(self, doc_id: str, title: str, content: str, source: str = "") -> None:
        """Upsert a ``Document`` vertex keyed by ``doc_id``."""
        self.conn.upsertVertex(
            "Document", doc_id,
            {"title": title, "content": content, "source": source},
        )

    def upsert_chunk(self, chunk_id: str, text: str, embedding: List[float],
                     chunk_index: int, token_count: int, doc_id: str) -> None:
        """Upsert a ``Chunk`` vertex and its ``PART_OF`` edge to the document.

        ``chunk_index`` is stored both on the vertex and as the edge's
        ``position`` attribute.
        """
        self.conn.upsertVertex(
            "Chunk", chunk_id,
            {"text": text, "embedding": embedding, "chunk_index": chunk_index,
             "token_count": token_count, "doc_id": doc_id},
        )
        self.conn.upsertEdge("Chunk", chunk_id, "PART_OF", "Document", doc_id,
                             {"position": chunk_index})

    def upsert_entity(self, entity_id: str, name: str, entity_type: str,
                      description: str, embedding: List[float]) -> None:
        """Upsert an ``Entity`` vertex keyed by ``entity_id``."""
        self.conn.upsertVertex(
            "Entity", entity_id,
            {"name": name, "entity_type": entity_type,
             "description": description, "embedding": embedding},
        )

    def upsert_mention(self, chunk_id: str, entity_id: str, count: int = 1,
                       confidence: float = 1.0) -> None:
        """Upsert a ``MENTIONS`` edge from a chunk to an entity."""
        self.conn.upsertEdge("Chunk", chunk_id, "MENTIONS", "Entity", entity_id,
                             {"mention_count": count, "confidence": confidence})

    def upsert_relation(self, src_id: str, tgt_id: str, rtype: str, desc: str = "",
                        weight: float = 1.0, keywords: str = "") -> None:
        """Upsert a ``RELATED_TO`` edge between two entities."""
        self.conn.upsertEdge(
            "Entity", src_id, "RELATED_TO", "Entity", tgt_id,
            {"relation_type": rtype, "description": desc,
             "weight": weight, "keywords": keywords},
        )

    # ── Retrieval ─────────────────────────────────────────

    def vector_search_chunks(self, query_embedding: List[float], top_k: int = 5):
        """Run the installed ``vectorSearchChunks`` query.

        Returns:
            The ``@@topChunks`` entry of the first result item, or ``[]`` when
            the result is empty or the call raised (the error is logged).
        """
        try:
            result = self.conn.runInstalledQuery(
                "vectorSearchChunks",
                params={"queryVec": query_embedding, "topK": top_k},
            )
            return result[0].get("@@topChunks", []) if result else []
        except Exception as e:
            logger.error("Vector search failed: %s", e)
            return []

    def vector_search_entities(self, query_embedding: List[float], top_k: int = 5):
        """Run the installed ``vectorSearchEntities`` query.

        Returns:
            The ``@@topEntities`` entry of the first result item, or ``[]``
            when the result is empty or the call raised (the error is logged).
        """
        try:
            result = self.conn.runInstalledQuery(
                "vectorSearchEntities",
                params={"queryVec": query_embedding, "topK": top_k},
            )
            return result[0].get("@@topEntities", []) if result else []
        except Exception as e:
            logger.error("Entity search failed: %s", e)
            return []

    def graph_traverse(self, seed_entity_ids, hops: int = 2) -> Dict[str, Any]:
        """Run the installed ``graphRAGTraverse`` query and flatten its output.

        Args:
            seed_entity_ids: Entity ids passed as ``seedEntityIds``.
            hops: Passed through as the query's ``hops`` parameter.

        Returns:
            Dict with keys ``entity_ids``, ``chunk_ids``, ``chunk_texts``,
            ``relations`` and ``entities``. All values are empty lists when the
            query returned nothing or anything raised (the error is logged).
        """
        try:
            result = self.conn.runInstalledQuery(
                "graphRAGTraverse",
                params={"seedEntityIds": seed_entity_ids, "hops": hops},
            )
            parsed = self._empty_traversal()
            # Each PRINT statement arrives as its own item in the result list.
            for item in result or []:
                for printed_key, parsed_key, as_list in self._TRAVERSAL_FIELDS:
                    if printed_key in item:
                        value = item[printed_key]
                        parsed[parsed_key] = list(value) if as_list else value
            return parsed
        except Exception as e:
            logger.error("Traversal failed: %s", e)
            return self._empty_traversal()

    @property
    def is_connected(self) -> bool:
        """Whether the most recent :meth:`connect` call succeeded."""
        return self._connected
213
+
214
+
215
+ # ── Utility Functions ──────────────────────────────────────
216
+
217
def generate_entity_id(name: str, entity_type: str) -> str:
    """Return a deterministic 12-character hex ID for an entity.

    Both inputs are lower-cased and stripped, joined as ``"name:type"``, and
    MD5-hashed; the first 12 hex digits of the digest are returned. Inputs that
    differ only in case or surrounding whitespace therefore share one ID,
    which is what makes the ID usable for deduplication.
    """
    normalized_name = name.lower().strip()
    normalized_type = entity_type.lower().strip()
    key = ":".join((normalized_name, normalized_type))
    digest = hashlib.md5(key.encode()).hexdigest()
    return digest[:12]
221
+
222
def generate_chunk_id(doc_id: str, chunk_index: int) -> str:
    """Build a chunk ID of the form ``<doc_id>_chunk_<index>``.

    The index is zero-padded to at least four digits.
    """
    return "{}_chunk_{:04d}".format(doc_id, chunk_index)
224
+
225
def chunk_text(text: str, chunk_size: int = 1000, overlap: int = 100) -> List[str]:
    """Split text into overlapping chunks with sentence boundary detection.

    Each chunk spans at most ``chunk_size`` characters. When a chunk does not
    reach the end of the text, the cut is moved back to just after the last
    occurrence of the first separator (tried in the order ``". "``, ``".\\n"``,
    blank line, newline, space) found past the midpoint of ``chunk_size``.
    Chunks are stripped of surrounding whitespace and empty ones are dropped.

    Args:
        text: Text to split. Empty input yields ``[]``.
        chunk_size: Maximum characters per chunk; must be positive.
        overlap: Characters re-used from the end of one chunk at the start of
            the next. If it is so large that the next chunk would not start
            after the current one, the next chunk starts where the current one
            ended instead.

    Returns:
        The list of chunk strings, in order.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if not text:
        return []
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")

    separators = ('. ', '.\n', '\n\n', '\n', ' ')
    min_cut = chunk_size * 0.5
    text_len = len(text)
    chunks: List[str] = []
    start = 0
    while start < text_len:
        end = min(start + chunk_size, text_len)
        if end < text_len:
            window = text[start:end]
            for sep in separators:
                cut = window.rfind(sep)
                # Only accept a boundary in the second half of the window so
                # chunks do not become too short.
                if cut > min_cut:
                    end = start + cut + len(sep)
                    break
        piece = text[start:end].strip()
        if piece:
            chunks.append(piece)
        # The tail has been emitted. Stepping back by `overlap` from here would
        # re-emit the same tail forever, so stop.
        if end >= text_len:
            break
        next_start = end - overlap
        # Guarantee forward progress even when overlap >= the chunk's length.
        start = next_start if next_start > start else end
    return chunks
246
+
247
def cosine_similarity(vec_a: List[float], vec_b: List[float]) -> float:
    """Compute cosine similarity between two vectors.

    Returns ``0.0`` when the vectors differ in length or when either one has
    zero magnitude (which includes two empty vectors); otherwise returns the
    dot product divided by the product of the Euclidean norms.
    """
    if len(vec_a) != len(vec_b):
        return 0.0

    norm_a = math.sqrt(sum(x * x for x in vec_a))
    norm_b = math.sqrt(sum(y * y for y in vec_b))
    if norm_a == 0 or norm_b == 0:
        return 0.0

    dot_product = sum(x * y for x, y in zip(vec_a, vec_b))
    return dot_product / (norm_a * norm_b)