"""
TigerGraph GraphRAG Client — Integration with the Official tigergraph/graphrag Repo
====================================================================================
This module integrates with the official TigerGraph GraphRAG service
(https://github.com/tigergraph/graphrag) deployed via Docker.

The official repo exposes REST APIs for graph-powered Q&A with three retrievers:
  - Hybrid Search: vector similarity + graph traversal combined
  - Community: hierarchical community summaries (Leiden algorithm)
  - Sibling: sibling/neighbor node traversal from seed entities

This client calls those APIs. When the official service is not available,
it falls back to our custom pyTigerGraph-based GraphLayer implementation.

Usage:
    client = TGGraphRAGClient(service_url="http://localhost:8000", ...)
    if client.connect():
        result = client.retrieve(query, retriever="hybrid", top_k=5, num_hops=2)
        answer = client.query(question, retriever="hybrid")
"""
|
|
import json
import logging
import os
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

# Module-level logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
@dataclass
class RetrievalResult:
    """
    Result from a TG GraphRAG retrieval call.

    Attributes:
        content: Retrieved text joined into a single string.
        chunks: One dict per retrieved passage (e.g. ``text`` and ``score``).
        entities: Entity records reported by the retriever.
        relations: Relation descriptions, as strings.
        community_summaries: Community summary texts.
        retriever_used: Name of the retriever that produced the result.
        score: Overall score; defaults to 0.0.
        latency_ms: Wall-clock retrieval time in milliseconds.
        metadata: Free-form extra information about the call.
    """

    content: str = field(default="")
    chunks: List[Dict[str, Any]] = field(default_factory=list)
    entities: List[Dict[str, Any]] = field(default_factory=list)
    relations: List[str] = field(default_factory=list)
    community_summaries: List[str] = field(default_factory=list)
    retriever_used: str = field(default="")
    score: float = field(default=0.0)
    latency_ms: float = field(default=0.0)
    metadata: Dict[str, Any] = field(default_factory=dict)
|
|
|
|
@dataclass
class GraphRAGAnswer:
    """
    Full answer from the TG GraphRAG service.

    Attributes:
        answer: Final answer text.
        retrieval: The retrieval result the answer was built from.
        total_tokens: Total LLM tokens used (defaults to 0).
        input_tokens: LLM prompt tokens (defaults to 0).
        output_tokens: LLM completion tokens (defaults to 0).
        latency_ms: End-to-end wall-clock time in milliseconds.
        cost_usd: Cost in US dollars (defaults to 0.0).
    """

    answer: str = field(default="")
    retrieval: RetrievalResult = field(default_factory=RetrievalResult)
    total_tokens: int = field(default=0)
    input_tokens: int = field(default=0)
    output_tokens: int = field(default=0)
    latency_ms: float = field(default=0.0)
    cost_usd: float = field(default=0.0)
|
|
|
|
class TGGraphRAGClient:
    """
    Client for the official TigerGraph GraphRAG service.

    Supports two modes:
      1. REST API mode: calls the deployed tigergraph/graphrag Docker service.
      2. Direct mode: uses the pyTigerGraph SDK with our custom GSQL queries
         (fallback).

    The hackathon allows both Path A (use as-is) and Path B (customize).
    This client implements Path A (REST API) with a Path B fallback
    (direct GSQL).

    Failed retrievals return a placeholder ``content`` string and set
    ``metadata["retrieval_failed"] = True`` so callers (and ``query()``)
    can tell a placeholder apart from real context.
    """

    # Paths probed, in order, by _check_service().
    _HEALTH_PATHS = ("/health", "/api/health", "/", "/docs", "/openapi.json")
    # Ingestion endpoints tried, in order, by _ingest_via_api().
    _INGEST_PATHS = ("/ingest", "/api/ingest", "/documents", "/api/documents")

    def __init__(
        self,
        service_url: str = "",
        tg_host: str = "",
        tg_graph: str = "",
        tg_username: str = "",
        tg_password: str = "",
        tg_token: str = "",
    ):
        """
        Store connection settings; no network I/O happens until ``connect()``.

        Each empty argument falls back to an environment variable:
        ``GRAPHRAG_SERVICE_URL`` then ``TG_GRAPHRAG_URL``, ``TG_HOST``,
        ``TG_GRAPH`` (else ``"GraphRAG"``), ``TG_USERNAME`` (else
        ``"tigergraph"``), ``TG_PASSWORD`` and ``TG_TOKEN``.
        """
        # Trailing slashes are stripped because paths are appended as f"{url}{path}".
        self.service_url = (
            service_url
            or os.getenv("GRAPHRAG_SERVICE_URL", "")
            or os.getenv("TG_GRAPHRAG_URL", "")
        ).rstrip("/")
        self.tg_host = tg_host or os.getenv("TG_HOST", "")
        # tg_graph / tg_username default to "" so the environment variables are
        # actually consulted; the literal fallbacks apply only when both are unset.
        self.tg_graph = tg_graph or os.getenv("TG_GRAPH") or "GraphRAG"
        self.tg_username = tg_username or os.getenv("TG_USERNAME") or "tigergraph"
        self.tg_password = tg_password or os.getenv("TG_PASSWORD", "")
        self.tg_token = tg_token or os.getenv("TG_TOKEN", "")

        self._service_available = False
        self._direct_available = False
        # Set once a direct connection has been tried, so a failing host is not
        # retried on every request by _ensure_direct().
        self._direct_attempted = False
        self._conn = None  # pyTigerGraph connection (direct mode only)
        # Bearer token sent to the REST service when non-empty. Nothing in this
        # class assigns it. NOTE(review): confirm whether the service needs auth.
        self._api_token = ""
        self._openapi_spec: Dict = {}
        self._embedder = None  # lazily created EmbeddingManager (direct mode)

    # ------------------------------------------------------------------
    # Connection management
    # ------------------------------------------------------------------

    def connect(self) -> bool:
        """
        Connect to the TG GraphRAG service.

        Tries the REST API first (when ``service_url`` is set), then a direct
        pyTigerGraph connection (when ``tg_host`` is set).

        Returns:
            True if either mode is available, False when running offline.
        """
        if self.service_url:
            self._service_available = self._check_service()
            if self._service_available:
                logger.info("Connected to TG GraphRAG service at %s", self.service_url)
                self._discover_endpoints()
                return True

        if self.tg_host:
            self._direct_attempted = True
            self._direct_available = self._connect_direct()
            if self._direct_available:
                logger.info("Connected to TigerGraph directly at %s", self.tg_host)
                return True

        logger.warning("No TG GraphRAG connection available. Running in offline mode.")
        return False

    def _build_request(self, path: str, payload: Optional[Dict[str, Any]] = None):
        """
        Build a request for ``service_url + path``.

        A GET request is built when ``payload`` is None; otherwise a POST with a
        JSON body. The bearer token header is added whenever ``_api_token`` is
        set, so every REST call authenticates the same way.
        """
        import urllib.request

        headers: Dict[str, str] = {}
        data = None
        if payload is not None:
            data = json.dumps(payload).encode("utf-8")
            headers["Content-Type"] = "application/json"
        if self._api_token:
            headers["Authorization"] = f"Bearer {self._api_token}"
        return urllib.request.Request(
            f"{self.service_url}{path}",
            data=data,
            headers=headers,
            method="GET" if payload is None else "POST",
        )

    def _fetch_json(
        self, path: str, payload: Optional[Dict[str, Any]] = None, timeout: float = 30
    ) -> Any:
        """
        Send a request (see ``_build_request``) and decode the JSON response.

        Errors from urllib / json (HTTPError, URLError, OSError, ValueError,
        http.client.HTTPException) propagate to the caller.
        """
        import urllib.request

        with urllib.request.urlopen(self._build_request(path, payload), timeout=timeout) as resp:
            return json.loads(resp.read())

    def _check_service(self) -> bool:
        """Return True if any well-known health/docs path answers HTTP 200."""
        import http.client
        import urllib.request

        for path in self._HEALTH_PATHS:
            try:
                with urllib.request.urlopen(self._build_request(path), timeout=5) as resp:
                    if resp.status == 200:
                        logger.info("TG GraphRAG service healthy at %s%s", self.service_url, path)
                        return True
            except (OSError, http.client.HTTPException, ValueError) as e:
                # URLError/HTTPError subclass OSError; ValueError covers a
                # malformed service_url (e.g. a missing "http://" scheme).
                logger.debug("Health check on %s failed: %s", path, e)
        return False

    def _discover_endpoints(self) -> None:
        """Best-effort: cache the service's OpenAPI spec and log its paths."""
        try:
            spec = self._fetch_json("/openapi.json", timeout=5)
        except Exception as e:  # informational only; must never break connect()
            logger.debug("Could not discover endpoints: %s", e)
            return
        if not isinstance(spec, dict):
            # A non-object spec would break status(), so it is not cached.
            logger.debug("Ignoring non-object OpenAPI spec from %s", self.service_url)
            return
        self._openapi_spec = spec
        paths = self._endpoint_paths()
        logger.info("Discovered %d API endpoints: %s", len(paths), paths[:10])

    def _endpoint_paths(self) -> List[str]:
        """Return the path keys of the cached OpenAPI spec ([] if unusable)."""
        paths = self._openapi_spec.get("paths", {})
        return list(paths) if isinstance(paths, dict) else []

    def _connect_direct(self) -> bool:
        """
        Open and authenticate a pyTigerGraph connection.

        Uses ``tg_token`` when provided; otherwise creates a secret and requests
        a token. Returns False (after logging) on any failure, including
        pyTigerGraph not being installed.
        """
        try:
            import pyTigerGraph as tg

            # The token is passed to the constructor, which builds the auth
            # headers from it; assigning ``apiToken`` afterwards may not be
            # picked up by every pyTigerGraph version.
            conn = tg.TigerGraphConnection(
                host=self.tg_host,
                graphname=self.tg_graph,
                username=self.tg_username,
                password=self.tg_password,
                apiToken=self.tg_token,
            )
            if not self.tg_token:
                secret = conn.createSecret()
                conn.getToken(secret)
        except Exception as e:
            logger.error("Direct TigerGraph connection failed: %s", e)
            self._conn = None
            return False
        self._conn = conn
        return True

    def _ensure_direct(self) -> bool:
        """
        Return True if direct mode is usable, connecting lazily (once) if needed.

        ``connect()`` stops after a healthy REST check, so this is what lets a
        REST failure at request time actually fall back to direct GSQL.
        """
        if not self._direct_available and self.tg_host and not self._direct_attempted:
            self._direct_attempted = True
            self._direct_available = self._connect_direct()
        return self._direct_available

    @property
    def is_connected(self) -> bool:
        """True when either the REST service or direct mode is available."""
        return self._service_available or self._direct_available

    @property
    def mode(self) -> str:
        """Active mode: ``"rest_api"``, ``"direct"`` or ``"offline"``."""
        if self._service_available:
            return "rest_api"
        elif self._direct_available:
            return "direct"
        return "offline"

    # ------------------------------------------------------------------
    # Retrieval
    # ------------------------------------------------------------------

    def retrieve(
        self,
        query: str,
        retriever: str = "hybrid",
        top_k: int = 5,
        num_hops: int = 2,
        community_level: int = 1,
    ) -> "RetrievalResult":
        """
        Retrieve context for a query using the specified retriever.

        Args:
            query: The question to retrieve context for.
            retriever: One of "hybrid", "community", "sibling".
            top_k: Number of top results to return.
            num_hops: Graph traversal depth (for hybrid/sibling).
            community_level: Leiden hierarchy level (for community).

        Returns:
            RetrievalResult with chunks, entities and metadata. On failure,
            ``content`` holds a placeholder and
            ``metadata["retrieval_failed"]`` is True.
        """
        start = time.perf_counter()

        if self._service_available:
            result = self._retrieve_via_api(query, retriever, top_k, num_hops, community_level)
        elif self._direct_available:
            result = self._retrieve_via_direct(query, retriever, top_k, num_hops, community_level)
        else:
            result = self._failed_result("[No TG GraphRAG connection — offline mode]", retriever)

        result.latency_ms = (time.perf_counter() - start) * 1000
        return result

    @staticmethod
    def _failed_result(content: str, retriever_used: str) -> "RetrievalResult":
        """Build a placeholder result flagged with ``metadata["retrieval_failed"]``."""
        result = RetrievalResult(content=content, retriever_used=retriever_used)
        result.metadata["retrieval_failed"] = True
        return result

    def _retrieve_via_api(
        self, query: str, retriever: str, top_k: int, num_hops: int, community_level: int
    ) -> "RetrievalResult":
        """
        Call the official TG GraphRAG REST API for retrieval.

        Because the exact route is not pinned down, several endpoint patterns
        are tried in order; the first one that returns JSON wins. If all fail,
        direct GSQL is used when a direct connection can be made.
        """
        import http.client
        import urllib.error

        payload: Dict[str, Any] = {"query": query, "top_k": top_k}
        if retriever in ("hybrid", "sibling"):
            payload["num_hops"] = num_hops
        if retriever == "community":
            payload["community_level"] = community_level

        endpoint_patterns = [
            f"/retrieve/{retriever}",
            f"/api/retrieve/{retriever}",
            f"/graphrag/retrieve/{retriever}",
            f"/api/v1/retrieve/{retriever}",
            "/retrieve",
            "/api/retrieve",
            "/query",
            "/api/query",
        ]

        # Generic routes (no retriever in the path) are told which retriever to
        # use; both key spellings are sent because the expected one is unknown.
        payload_with_type = {**payload, "retriever": retriever, "retriever_type": retriever}

        for path in endpoint_patterns:
            body = payload if "/retrieve/" in path else payload_with_type
            try:
                data = self._fetch_json(path, body, timeout=30)
            except urllib.error.HTTPError as e:
                if e.code != 404:  # 404 just means "not this route"
                    logger.error("API error on %s: %s %s", path, e.code, e.reason)
                continue
            except (OSError, http.client.HTTPException, ValueError) as e:
                # ValueError includes JSON/Unicode decode errors and bad URLs.
                logger.debug("Endpoint %s failed: %s", path, e)
                continue
            return self._parse_api_response(data, retriever)

        if self._ensure_direct():
            logger.warning("All REST API endpoint patterns failed. Falling back to direct mode.")
            return self._retrieve_via_direct(query, retriever, top_k, num_hops, community_level)
        logger.warning("All REST API endpoint patterns failed and no direct connection is available.")
        return self._failed_result("[API retrieval failed]", retriever)

    @staticmethod
    def _text_of(value: Any) -> str:
        """Coerce a passage field to ``str`` ("" for None) so joins never fail."""
        return "" if value is None else str(value)

    @staticmethod
    def _join_texts(chunks: List[Dict[str, Any]]) -> str:
        """Join the non-empty ``text`` fields of ``chunks`` with blank lines."""
        return "\n\n".join(c["text"] for c in chunks if c.get("text"))

    def _parse_api_response(self, data: Any, retriever: str) -> "RetrievalResult":
        """
        Parse the response from the TG GraphRAG API into a RetrievalResult.

        Several alternative key names are accepted because the response schema
        is not fixed here: ``results``/``chunks``/``documents`` for passages,
        ``entities``/``nodes``, ``relations``/``edges``/``relationships`` and
        ``community_summaries``/``summaries``. A bare JSON list is treated as a
        list of passages. A service-provided ``answer`` is kept in
        ``metadata["service_answer"]``.
        """
        result = RetrievalResult(retriever_used=retriever)

        if isinstance(data, dict):
            results = data.get("results", data.get("chunks", data.get("documents", [])))
            if isinstance(results, list):
                for item in results:
                    if isinstance(item, dict):
                        text = item.get("content", item.get("text", item.get("chunk_text", "")))
                        result.chunks.append({
                            "text": self._text_of(text),
                            "score": item.get("score", item.get("similarity", 0.0)),
                            "source": item.get("source", item.get("doc_id", "")),
                            "chunk_id": item.get("chunk_id", item.get("id", "")),
                        })
                    elif isinstance(item, str):
                        result.chunks.append({"text": item, "score": 0.0})

            entities = data.get("entities", data.get("nodes", []))
            if isinstance(entities, list):
                result.entities = entities

            relations = data.get("relations", data.get("edges", data.get("relationships", [])))
            if isinstance(relations, list):
                result.relations = [str(r) for r in relations]

            summaries = data.get("community_summaries", data.get("summaries", []))
            if isinstance(summaries, list):
                result.community_summaries = [str(s) for s in summaries]

            # Community summaries are placed before the passage texts.
            texts = [c["text"] for c in result.chunks if c.get("text")]
            result.content = "\n\n".join(result.community_summaries + texts)

            if "answer" in data:
                result.metadata["service_answer"] = data["answer"]
            result.metadata["raw_response_keys"] = list(data.keys())

        elif isinstance(data, list):
            for item in data:
                if isinstance(item, dict):
                    text = item.get("text", item.get("content", str(item)))
                else:
                    text = item
                result.chunks.append({"text": self._text_of(text), "score": 0.0})
            result.content = self._join_texts(result.chunks)

        return result

    def _get_embedder(self):
        """Create and initialize the EmbeddingManager once, then reuse it."""
        if self._embedder is None:
            from .orchestration_layer import EmbeddingManager

            embedder = EmbeddingManager()
            embedder.initialize()
            self._embedder = embedder  # cached only after a successful initialize()
        return self._embedder

    def _vector_chunks(self, query_emb: Any, top_k: int) -> List[Dict[str, Any]]:
        """Return the ``@@topChunks`` rows of the vectorSearchChunks query."""
        rows = self._run_query("vectorSearchChunks", {"queryVec": query_emb, "topK": top_k})
        return rows[0].get("@@topChunks", []) if rows else []

    def _seed_entity_ids(self, query_emb: Any, top_k: int) -> List[str]:
        """Return entity ids from the vectorSearchEntities query."""
        rows = self._run_query("vectorSearchEntities", {"queryVec": query_emb, "topK": top_k})
        entities = rows[0].get("@@topEntities", []) if rows else []
        # Entries without an entity_id are skipped rather than sent as "".
        return [e["entity_id"] for e in entities if e.get("entity_id")]

    def _add_traversal(self, result: "RetrievalResult", seed_ids: List[str], num_hops: int) -> None:
        """Append graphRAGTraverse chunk texts and relations for ``seed_ids``."""
        if not seed_ids:
            return
        rows = self._run_query("graphRAGTraverse", {"seedEntityIds": seed_ids, "hops": num_hops})
        for row in rows:
            for text in row.get("@@chunkTexts") or []:
                result.chunks.append({"text": text, "score": 0.0})
            if "@@relationDescriptions" in row:
                # A later row replaces relations from an earlier one.
                result.relations = list(row["@@relationDescriptions"])

    def _retrieve_via_direct(
        self, query: str, retriever: str, top_k: int, num_hops: int, community_level: int
    ) -> "RetrievalResult":
        """
        Fallback: use pyTigerGraph direct GSQL queries.

        Maps official retriever names to our custom GSQL queries:
          - hybrid: entity vector search -> graph traversal, plus chunk vector search
          - community: chunk vector search only (``community_level`` is unused here)
          - sibling: entity vector search -> graph traversal
        """
        result = RetrievalResult(retriever_used=f"{retriever}_direct")
        if not self._conn:
            return result

        try:
            query_emb = self._get_embedder().embed_single(query)

            if retriever == "hybrid":
                self._add_traversal(result, self._seed_entity_ids(query_emb, top_k), num_hops)
                for c in self._vector_chunks(query_emb, top_k):
                    result.chunks.append({
                        "text": c.get("text", c.get("chunk_id", "")),
                        "score": c.get("score", 0.0),
                    })
                # NOTE(review): traversal chunks come first, so when they fill
                # top_k the vector-search chunks are cut from ``content``.
                result.content = self._join_texts(result.chunks[:top_k])

            elif retriever == "community":
                for c in self._vector_chunks(query_emb, top_k):
                    result.chunks.append({"text": c.get("text", ""), "score": c.get("score", 0.0)})
                result.content = self._join_texts(result.chunks)

            elif retriever == "sibling":
                self._add_traversal(result, self._seed_entity_ids(query_emb, top_k), num_hops)
                result.content = self._join_texts(result.chunks[:top_k])

            else:
                logger.warning("Unknown retriever %r; returning an empty result", retriever)

        except Exception as e:
            logger.error("Direct retrieval failed: %s", e)
            result.content = f"[Retrieval error: {e}]"
            result.metadata["retrieval_failed"] = True

        return result

    def _run_query(self, query_name: str, params: Dict) -> List[Dict]:
        """Run an installed GSQL query; log and return [] on any failure."""
        try:
            return self._conn.runInstalledQuery(query_name, params=params) or []
        except Exception as e:
            logger.error("GSQL query %s failed: %s", query_name, e)
            return []

    # ------------------------------------------------------------------
    # Question answering
    # ------------------------------------------------------------------

    def query(
        self,
        question: str,
        retriever: str = "hybrid",
        top_k: int = 5,
        num_hops: int = 2,
        community_level: int = 1,
        llm_layer=None,
    ) -> "GraphRAGAnswer":
        """
        Full GraphRAG Q&A: retrieve context, then generate an answer.

        If the TG GraphRAG service returned its own answer, that is used.
        Otherwise the retrieved context is passed to ``llm_layer.generate_answer``.
        Failure placeholders from retrieval are never sent to the LLM.
        """
        start = time.perf_counter()
        retrieval = self.retrieve(
            query=question,
            retriever=retriever,
            top_k=top_k,
            num_hops=num_hops,
            community_level=community_level,
        )
        answer_obj = GraphRAGAnswer(retrieval=retrieval)

        service_answer = retrieval.metadata.get("service_answer", "")
        has_context = bool(retrieval.content) and not retrieval.metadata.get("retrieval_failed")

        if service_answer:
            answer_obj.answer = service_answer
        elif llm_layer and has_context:
            resp = llm_layer.generate_answer(
                question,
                retrieval.content,
                system_prompt=(
                    "You are a knowledgeable assistant with access to a knowledge graph. "
                    "Use the structured context including entities, relationships, and passages "
                    "to answer accurately. Follow relationship chains for multi-hop reasoning. "
                    "Be concise and precise."
                ),
            )
            answer_obj.answer = resp.content
            answer_obj.input_tokens = resp.input_tokens
            answer_obj.output_tokens = resp.output_tokens
            answer_obj.total_tokens = resp.total_tokens
            answer_obj.cost_usd = resp.cost_usd
        else:
            answer_obj.answer = "[No context retrieved and no LLM available]"

        answer_obj.latency_ms = (time.perf_counter() - start) * 1000
        return answer_obj

    # ------------------------------------------------------------------
    # Ingestion
    # ------------------------------------------------------------------

    def ingest_document(
        self,
        doc_id: str,
        title: str,
        content: str,
        source: str = "",
    ) -> Dict[str, Any]:
        """
        Ingest a document via the TG GraphRAG service API.

        Falls back to direct pyTigerGraph if the service is unavailable.
        Returns the service's JSON response, or a ``{"status": ...}`` dict.
        """
        if self._service_available:
            return self._ingest_via_api(doc_id, title, content, source)
        elif self._direct_available:
            return self._ingest_via_direct(doc_id, title, content, source)
        return {"status": "error", "message": "No connection available"}

    def _ingest_via_api(self, doc_id, title, content, source) -> Dict:
        """POST the document to the first ingestion endpoint that accepts it."""
        import http.client

        payload = {"doc_id": doc_id, "title": title, "content": content, "source": source}
        for path in self._INGEST_PATHS:
            try:
                return self._fetch_json(path, payload, timeout=60)
            except (OSError, http.client.HTTPException, ValueError) as e:
                # OSError covers URLError/HTTPError; failures are logged, not hidden.
                logger.debug("Ingest endpoint %s failed: %s", path, e)
        return {"status": "error", "message": "All ingest endpoints failed"}

    def _ingest_via_direct(self, doc_id, title, content, source) -> Dict:
        """Upsert a ``Document`` vertex directly via pyTigerGraph."""
        try:
            self._conn.upsertVertex("Document", doc_id, {
                "title": title, "content": content, "source": source})
            return {"status": "ok", "doc_id": doc_id}
        except Exception as e:
            return {"status": "error", "message": str(e)}

    # ------------------------------------------------------------------
    # Introspection
    # ------------------------------------------------------------------

    def status(self) -> Dict[str, Any]:
        """Return connection status and available features."""
        return {
            "mode": self.mode,
            "service_url": self.service_url if self._service_available else None,
            "tg_host": self.tg_host if self._direct_available else None,
            "tg_graph": self.tg_graph,
            "service_available": self._service_available,
            "direct_available": self._direct_available,
            "available_retrievers": ["hybrid", "community", "sibling"],
            "openapi_endpoints": self._endpoint_paths()[:20],
        }
|
|