""" Layer 1: Graph Layer — TigerGraph Schema, Connection, and GSQL Queries ====================================================================== Handles all graph database operations: schema creation, data upsert, vector search, and multi-hop graph traversal. """ import hashlib import logging import math from typing import Any, Dict, List, Optional, Tuple logger = logging.getLogger(__name__) # ── GSQL Schema Definition ─────────────────────────────── SCHEMA_DDL_GLOBAL = """ USE GLOBAL CREATE VERTEX Document (PRIMARY_ID doc_id STRING, title STRING, content STRING, source STRING) WITH primary_id_as_attribute="true" CREATE VERTEX Chunk (PRIMARY_ID chunk_id STRING, text STRING, embedding LIST, chunk_index INT, token_count INT, doc_id STRING) WITH primary_id_as_attribute="true" CREATE VERTEX Entity (PRIMARY_ID entity_id STRING, name STRING, entity_type STRING, description STRING, embedding LIST, mention_count INT DEFAULT 1) WITH primary_id_as_attribute="true" CREATE VERTEX Community (PRIMARY_ID community_id STRING, summary STRING, level INT DEFAULT 0, entity_count INT DEFAULT 0) WITH primary_id_as_attribute="true" CREATE DIRECTED EDGE PART_OF (FROM Chunk, TO Document, position INT) CREATE DIRECTED EDGE MENTIONS (FROM Chunk, TO Entity, mention_count INT DEFAULT 1, confidence FLOAT DEFAULT 1.0) CREATE UNDIRECTED EDGE RELATED_TO (FROM Entity, TO Entity, relation_type STRING, weight FLOAT DEFAULT 1.0, description STRING, keywords STRING) CREATE DIRECTED EDGE IN_COMMUNITY (FROM Entity, TO Community) """ SCHEMA_DDL_GRAPH = """ CREATE GRAPH {graphname}(Document, Chunk, Entity, Community, PART_OF, MENTIONS, RELATED_TO, IN_COMMUNITY) """ SCHEMA_DDL_DROP_GRAPH = """ DROP GRAPH {graphname} """ # ── GSQL Installed Queries ──────────────────────────────── VECTOR_SEARCH_QUERY = """ CREATE OR REPLACE QUERY vectorSearchChunks(LIST queryVec, INT topK) FOR GRAPH {graphname} {{ TYPEDEF TUPLE ChunkScore; HeapAccum(topK, score DESC) @@topChunks; allChunks = {{Chunk.*}}; allChunks = SELECT 
c FROM allChunks:c WHERE c.embedding.size() > 0 ACCUM DOUBLE dotProduct = 0.0, DOUBLE normA = 0.0, DOUBLE normB = 0.0, FOREACH i IN RANGE[0, c.embedding.size() - 1] DO dotProduct = dotProduct + queryVec.get(i) * c.embedding.get(i), normA = normA + queryVec.get(i) * queryVec.get(i), normB = normB + c.embedding.get(i) * c.embedding.get(i) END, DOUBLE sim = CASE WHEN sqrt(normA) * sqrt(normB) > 0 THEN dotProduct / (sqrt(normA) * sqrt(normB)) ELSE 0.0 END, @@topChunks += ChunkScore(c.chunk_id, c.text, sim); PRINT @@topChunks; }} INSTALL QUERY vectorSearchChunks """ ENTITY_VECTOR_SEARCH_QUERY = """ CREATE OR REPLACE QUERY vectorSearchEntities(LIST queryVec, INT topK) FOR GRAPH {graphname} {{ TYPEDEF TUPLE EntityScore; HeapAccum(topK, score DESC) @@topEntities; allEntities = {{Entity.*}}; allEntities = SELECT e FROM allEntities:e WHERE e.embedding.size() > 0 ACCUM DOUBLE dotProduct = 0.0, DOUBLE normA = 0.0, DOUBLE normB = 0.0, FOREACH i IN RANGE[0, e.embedding.size() - 1] DO dotProduct = dotProduct + queryVec.get(i) * e.embedding.get(i), normA = normA + queryVec.get(i) * queryVec.get(i), normB = normB + e.embedding.get(i) * e.embedding.get(i) END, DOUBLE sim = CASE WHEN sqrt(normA) * sqrt(normB) > 0 THEN dotProduct / (sqrt(normA) * sqrt(normB)) ELSE 0.0 END, @@topEntities += EntityScore(e.entity_id, e.name, e.entity_type, e.description, sim); PRINT @@topEntities; }} INSTALL QUERY vectorSearchEntities """ GRAPHRAG_TRAVERSE_QUERY = """ CREATE OR REPLACE QUERY graphRAGTraverse(SET seedEntityIds, INT hops) FOR GRAPH {graphname} {{ SetAccum @@visitedEntityIds; SetAccum @@relevantChunkIds; ListAccum @@chunkTexts; SetAccum @@relationDescriptions; Seeds = {{Entity.*}}; Seeds = SELECT e FROM Seeds:e WHERE e.entity_id IN seedEntityIds ACCUM @@visitedEntityIds += e.entity_id; FOREACH hop IN RANGE[1, hops] DO Seeds = SELECT nbr FROM Seeds:e -(RELATED_TO:rel)- Entity:nbr WHERE nbr.entity_id NOT IN @@visitedEntityIds ACCUM @@visitedEntityIds += nbr.entity_id, @@relationDescriptions 
+= (e.name + " -[" + rel.relation_type + "]-> " + nbr.name + ": " + rel.description); END; AllVisited = {{Entity.*}}; AllVisited = SELECT e FROM AllVisited:e WHERE e.entity_id IN @@visitedEntityIds; Chunks = SELECT c FROM AllVisited:e -(MENTIONS>:m)- Chunk:c ACCUM @@relevantChunkIds += c.chunk_id, @@chunkTexts += c.text; PRINT @@visitedEntityIds; PRINT @@relevantChunkIds; PRINT @@chunkTexts; PRINT @@relationDescriptions; PRINT AllVisited [AllVisited.name, AllVisited.entity_type, AllVisited.description]; }} INSTALL QUERY graphRAGTraverse """ class GraphLayer: """Layer 1: TigerGraph Graph Layer — connection, schema, upserts, retrieval.""" def __init__(self, config=None): self.config = config self.conn = None self._connected = False def connect(self) -> bool: """Establish connection to TigerGraph Cloud.""" try: import pyTigerGraph as tg cfg = self.config or {} import requests as _req host = cfg.get("host", "").rstrip("/") secret = cfg.get("token", "") graphname = cfg.get("graphname", "GraphRAG") # Try TG 4.x then 3.x token endpoints api_token = "" for endpoint, payload in [ ("/gsql/v1/tokens", {"secret": secret}), ("/restpp/requesttoken", {"secret": secret, "lifetime": 2592000}), ]: try: r = _req.post(f"{host}{endpoint}", json=payload, timeout=15) logger.info(f"[{endpoint}] status={r.status_code} body={r.text[:300]}") if r.status_code == 200: data = r.json() api_token = (data.get("token") or data.get("results", {}).get("token", "") or data.get("data", {}).get("token", "")) if api_token: logger.info(f"Token obtained via {endpoint}") break except Exception as ex: logger.info(f"[{endpoint}] exception: {ex}") continue if not api_token: raise RuntimeError("Could not obtain token from any endpoint") self.conn = tg.TigerGraphConnection( host=host, graphname=graphname, apiToken=api_token, ) self._connected = True logger.info("Connected to TigerGraph Cloud successfully.") return True except Exception as e: logger.error(f"TigerGraph connection failed: {e}") return False def 
create_schema(self) -> str: gn = (self.config or {}).get("graphname", "GraphRAG") try: existing = self.conn.getVertexTypes() if "Document" not in existing: r1 = self.conn.gsql(SCHEMA_DDL_GLOBAL) logger.info(f"Global schema: {str(r1)[:300]}") else: logger.info("Global vertex types already exist, skipping.") except Exception as e: logger.warning(f"Global schema check: {e}") try: r2 = self.conn.gsql(SCHEMA_DDL_GRAPH.format(graphname=gn)) if "could not be created" in str(r2) or "conflicts" in str(r2): logger.info(f"Graph '{gn}' already exists, skipping.") return "exists" logger.info(f"Graph schema: {str(r2)[:300]}") return r2 except Exception as e: if "conflict" in str(e).lower() or "already" in str(e).lower(): logger.info(f"Graph '{gn}' already exists, skipping.") return "exists" raise def install_queries(self) -> Dict[str, str]: gn = (self.config or {}).get("graphname", "GraphRAG") results = {} for name, q in [("vectorSearchChunks", VECTOR_SEARCH_QUERY), ("vectorSearchEntities", ENTITY_VECTOR_SEARCH_QUERY), ("graphRAGTraverse", GRAPHRAG_TRAVERSE_QUERY)]: try: results[name] = self.conn.gsql(q.format(graphname=gn)) except Exception as e: results[name] = str(e) return results # ── Data Upsert ─────────────────────────────────────── def upsert_document(self, doc_id, title, content, source=""): self.conn.upsertVertex("Document", doc_id, {"title": title, "content": content, "source": source}) def upsert_chunk(self, chunk_id, text, embedding, chunk_index, token_count, doc_id): self.conn.upsertVertex("Chunk", chunk_id, {"text": text, "embedding": embedding, "chunk_index": chunk_index, "token_count": token_count, "doc_id": doc_id}) self.conn.upsertEdge("Chunk", chunk_id, "PART_OF", "Document", doc_id, {"position": chunk_index}) def upsert_entity(self, entity_id, name, entity_type, description, embedding): self.conn.upsertVertex("Entity", entity_id, {"name": name, "entity_type": entity_type, "description": description, "embedding": embedding}) def upsert_mention(self, 
chunk_id, entity_id, count=1, confidence=1.0): self.conn.upsertEdge("Chunk", chunk_id, "MENTIONS", "Entity", entity_id, {"mention_count": count, "confidence": confidence}) def upsert_relation(self, src_id, tgt_id, rtype, desc="", weight=1.0, keywords=""): self.conn.upsertEdge("Entity", src_id, "RELATED_TO", "Entity", tgt_id, {"relation_type": rtype, "description": desc, "weight": weight, "keywords": keywords}) # ── Retrieval ───────────────────────────────────────── def vector_search_chunks(self, query_embedding, top_k=5): try: result = self.conn.runInstalledQuery("vectorSearchChunks", params={"queryVec": query_embedding, "topK": top_k}) return result[0].get("@@topChunks", []) if result else [] except Exception as e: logger.error(f"Vector search failed: {e}") return [] def vector_search_entities(self, query_embedding, top_k=5): try: result = self.conn.runInstalledQuery("vectorSearchEntities", params={"queryVec": query_embedding, "topK": top_k}) return result[0].get("@@topEntities", []) if result else [] except Exception as e: logger.error(f"Entity search failed: {e}") return [] def graph_traverse(self, seed_entity_ids, hops=2): try: result = self.conn.runInstalledQuery("graphRAGTraverse", params={"seedEntityIds": seed_entity_ids, "hops": hops}) parsed = {"entity_ids": [], "chunk_ids": [], "chunk_texts": [], "relations": [], "entities": []} if result: for r in result: if "@@visitedEntityIds" in r: parsed["entity_ids"] = list(r["@@visitedEntityIds"]) if "@@relevantChunkIds" in r: parsed["chunk_ids"] = list(r["@@relevantChunkIds"]) if "@@chunkTexts" in r: parsed["chunk_texts"] = r["@@chunkTexts"] if "@@relationDescriptions" in r: parsed["relations"] = list(r["@@relationDescriptions"]) if "AllVisited" in r: parsed["entities"] = r["AllVisited"] return parsed except Exception as e: logger.error(f"Traversal failed: {e}") return {"entity_ids": [], "chunk_ids": [], "chunk_texts": [], "relations": [], "entities": []} @property def is_connected(self): return self._connected # 
# ── Utility Functions ──────────────────────────────────────

def generate_entity_id(name: str, entity_type: str) -> str:
    """Generate a deterministic 12-hex-char entity ID for deduplication.

    Case- and surrounding-whitespace-insensitive: ("Apple", "ORG") and
    (" apple ", "org") map to the same ID.  MD5 is used here only as a
    fast, stable hash — not for security.
    """
    raw = f"{name.lower().strip()}:{entity_type.lower().strip()}"
    return hashlib.md5(raw.encode()).hexdigest()[:12]


def generate_chunk_id(doc_id: str, chunk_index: int) -> str:
    """Return a stable chunk ID like 'doc1_chunk_0003' (zero-padded index)."""
    return f"{doc_id}_chunk_{chunk_index:04d}"


def chunk_text(text: str, chunk_size: int = 1000, overlap: int = 100) -> List[str]:
    """Split text into overlapping chunks, preferring sentence boundaries.

    Args:
        text: Source text; empty input yields [].
        chunk_size: Target maximum characters per chunk (must be > 0).
        overlap: Characters shared between consecutive chunks
            (must be < chunk_size).

    Returns:
        List of non-empty, stripped chunk strings.

    Raises:
        ValueError: If chunk_size <= 0 or overlap >= chunk_size.
            FIX(review): the original looped forever in these cases because
            the advance `start = end - overlap` never made progress.
    """
    if not text:
        return []
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    chunks = []
    start = 0
    while start < len(text):
        end = min(start + chunk_size, len(text))
        if end < len(text):
            # Prefer cutting at a sentence/paragraph/word boundary, but only
            # if that still leaves the chunk at least half the target size.
            for sep in ['. ', '.\n', '\n\n', '\n', ' ']:
                last_sep = text[start:end].rfind(sep)
                if last_sep > chunk_size * 0.5:
                    end = start + last_sep + len(sep)
                    break
        chunk = text[start:end].strip()
        if chunk:
            chunks.append(chunk)
        if end >= len(text):
            break
        start = end - overlap  # step back so consecutive chunks overlap
    return chunks


def cosine_similarity(vec_a: List[float], vec_b: List[float]) -> float:
    """Compute cosine similarity between two vectors.

    Returns 0.0 (rather than raising) for mismatched lengths or zero-norm
    inputs, so callers can treat "no signal" uniformly.
    """
    if len(vec_a) != len(vec_b):
        return 0.0
    dot = sum(a * b for a, b in zip(vec_a, vec_b))
    na = math.sqrt(sum(a * a for a in vec_a))
    nb = math.sqrt(sum(b * b for b in vec_b))
    if na == 0 or nb == 0:
        return 0.0
    return dot / (na * nb)