# graphrag-inference-hackathon / graphrag / layers / tg_graphrag_client.py
# Author: muthuk1
# Fix #1: Add TigerGraph GraphRAG integration layer wrapping official repo REST APIs
# Commit: 3101051 (verified)
"""
TigerGraph GraphRAG Client β€” Integration with the Official tigergraph/graphrag Repo
====================================================================================
This module integrates with the official TigerGraph GraphRAG service
(https://github.com/tigergraph/graphrag) deployed via Docker.
The official repo exposes REST APIs for graph-powered Q&A with three retrievers:
- Hybrid Search: vector similarity + graph traversal combined
- Community: hierarchical community summaries (Leiden algorithm)
- Sibling: sibling/neighbor node traversal from seed entities
This client calls those APIs. When the official service is not available,
it falls back to our custom pyTigerGraph-based GraphLayer implementation.
Usage:
client = TGGraphRAGClient(service_url="http://localhost:8000", ...)
if client.connect():
result = client.retrieve(query, retriever="hybrid", top_k=5, num_hops=2)
answer = client.query(question, retriever="hybrid")
"""
import json
import logging
import os
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
@dataclass
class RetrievalResult:
    """Result from a TG GraphRAG retrieval call."""

    # Combined context text (chunk texts joined with blank lines; community
    # summaries are prepended when present — see _parse_api_response).
    content: str = ""
    # Retrieved text chunks: dicts with "text", "score", and (from the REST
    # path) "source"/"chunk_id" keys.
    chunks: List[Dict[str, Any]] = field(default_factory=list)
    # Entity/node dicts as returned by the service (shape is service-defined).
    entities: List[Dict[str, Any]] = field(default_factory=list)
    # Relationship/edge descriptions, stringified.
    relations: List[str] = field(default_factory=list)
    # Leiden community summaries (community retriever), stringified.
    community_summaries: List[str] = field(default_factory=list)
    # Which retriever produced this result ("hybrid", "community", "sibling",
    # or "<name>_direct" for the pyTigerGraph fallback).
    retriever_used: str = ""
    # Aggregate relevance score; 0.0 when the backend supplies none.
    score: float = 0.0
    # Wall-clock retrieval latency, set by TGGraphRAGClient.retrieve().
    latency_ms: float = 0.0
    # Extra backend data, e.g. "service_answer" and "raw_response_keys".
    metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class GraphRAGAnswer:
    """Full answer from the TG GraphRAG service."""

    # Final answer text (from the service itself, or generated by the LLM
    # layer over the retrieved context).
    answer: str = ""
    # The retrieval step's full result, including chunks and metadata.
    retrieval: RetrievalResult = field(default_factory=RetrievalResult)
    # Token accounting — populated only when our own LLM layer generates the
    # answer; stays 0 when the service returns an answer directly.
    total_tokens: int = 0
    input_tokens: int = 0
    output_tokens: int = 0
    # End-to-end latency (retrieval + generation), in milliseconds.
    latency_ms: float = 0.0
    # LLM cost as reported by the LLM layer; 0.0 when no LLM call was made.
    cost_usd: float = 0.0
class TGGraphRAGClient:
    """
    Client for the official TigerGraph GraphRAG service.

    Supports two modes:
      1. REST API mode: calls the deployed tigergraph/graphrag Docker service
      2. Direct mode: uses pyTigerGraph SDK with our custom GSQL queries (fallback)

    The hackathon allows both Path A (use as-is) and Path B (customize).
    This client implements Path A (REST API) with Path B fallback (direct GSQL).
    """

    def __init__(
        self,
        service_url: str = "",
        tg_host: str = "",
        tg_graph: str = "GraphRAG",
        tg_username: str = "tigergraph",
        tg_password: str = "",
        tg_token: str = "",
    ) -> None:
        # Explicit arguments win; environment variables are the fallback.
        # GRAPHRAG_SERVICE_URL takes precedence over TG_GRAPHRAG_URL.
        self.service_url = (
            service_url
            or os.getenv("GRAPHRAG_SERVICE_URL", "")
            or os.getenv("TG_GRAPHRAG_URL", "")
        ).rstrip("/")
        self.tg_host = tg_host or os.getenv("TG_HOST", "")
        self.tg_graph = tg_graph or os.getenv("TG_GRAPH", "GraphRAG")
        self.tg_username = tg_username or os.getenv("TG_USERNAME", "tigergraph")
        self.tg_password = tg_password or os.getenv("TG_PASSWORD", "")
        self.tg_token = tg_token or os.getenv("TG_TOKEN", "")
        # Connection state flags, set by connect().
        self._service_available = False
        self._direct_available = False
        # pyTigerGraph connection object (direct mode only).
        self._conn = None
        # Bearer token for the REST service. NOTE(review): nothing in this
        # class ever assigns a non-empty value, so the Authorization header
        # below is effectively dead code — confirm whether a setter is missing.
        self._api_token = ""
        # Cached OpenAPI spec from _discover_endpoints(), used by status().
        self._openapi_spec: Dict = {}

    # ── Connection ────────────────────────────────────────
    def connect(self) -> bool:
        """
        Connect to the TG GraphRAG service.
        Tries REST API first, then falls back to direct pyTigerGraph.

        Returns:
            True if either mode connected; False leaves the client in
            offline mode (retrieve() then returns a placeholder result).
        """
        # Try REST API service first
        if self.service_url:
            self._service_available = self._check_service()
            if self._service_available:
                logger.info(f"Connected to TG GraphRAG service at {self.service_url}")
                # Best-effort endpoint discovery; failures are non-fatal.
                self._discover_endpoints()
                return True
        # Fall back to direct pyTigerGraph connection
        if self.tg_host:
            self._direct_available = self._connect_direct()
            if self._direct_available:
                logger.info(f"Connected to TigerGraph directly at {self.tg_host}")
                return True
        logger.warning("No TG GraphRAG connection available. Running in offline mode.")
        return False

    def _check_service(self) -> bool:
        """Check if the TG GraphRAG REST service is healthy.

        Probes several common health/root paths; the first 200 response
        marks the service as healthy.
        """
        import urllib.request
        import urllib.error
        # Try common health endpoints
        for path in ["/health", "/api/health", "/", "/docs", "/openapi.json"]:
            try:
                url = f"{self.service_url}{path}"
                req = urllib.request.Request(url, method="GET")
                if self._api_token:
                    req.add_header("Authorization", f"Bearer {self._api_token}")
                with urllib.request.urlopen(req, timeout=5) as resp:
                    if resp.status == 200:
                        logger.info(f"TG GraphRAG service healthy at {url}")
                        return True
            except (urllib.error.URLError, OSError):
                # Unreachable or erroring path — try the next candidate.
                continue
        return False

    def _discover_endpoints(self) -> None:
        """Discover available API endpoints from OpenAPI spec.

        Populates self._openapi_spec for status() reporting; purely
        informational, so any failure is logged at debug and ignored.
        """
        import urllib.request
        try:
            url = f"{self.service_url}/openapi.json"
            req = urllib.request.Request(url, method="GET")
            with urllib.request.urlopen(req, timeout=5) as resp:
                self._openapi_spec = json.loads(resp.read())
            paths = list(self._openapi_spec.get("paths", {}).keys())
            logger.info(f"Discovered {len(paths)} API endpoints: {paths[:10]}")
        except Exception as e:
            logger.debug(f"Could not discover endpoints: {e}")

    def _connect_direct(self) -> bool:
        """Connect directly to TigerGraph via pyTigerGraph.

        Uses an explicit token when configured; otherwise creates a secret
        and requests a fresh token (requires password auth to succeed).
        """
        try:
            # Imported lazily so the module loads without pyTigerGraph installed.
            import pyTigerGraph as tg
            self._conn = tg.TigerGraphConnection(
                host=self.tg_host,
                graphname=self.tg_graph,
                username=self.tg_username,
                password=self.tg_password,
            )
            if self.tg_token:
                self._conn.apiToken = self.tg_token
            else:
                secret = self._conn.createSecret()
                self._conn.getToken(secret)
            return True
        except Exception as e:
            logger.error(f"Direct TigerGraph connection failed: {e}")
            return False

    @property
    def is_connected(self) -> bool:
        # True when either the REST service or the direct SDK path is up.
        return self._service_available or self._direct_available

    @property
    def mode(self) -> str:
        # REST service wins when both are available (mirrors retrieve()).
        if self._service_available:
            return "rest_api"
        elif self._direct_available:
            return "direct"
        return "offline"

    # ── Retrieval (Core API) ──────────────────────────────
    def retrieve(
        self,
        query: str,
        retriever: str = "hybrid",
        top_k: int = 5,
        num_hops: int = 2,
        community_level: int = 1,
    ) -> RetrievalResult:
        """
        Retrieve context for a query using the specified retriever.
        Args:
            query: The question to retrieve context for
            retriever: One of "hybrid", "community", "sibling"
            top_k: Number of top results to return
            num_hops: Graph traversal depth (for hybrid/sibling)
            community_level: Leiden hierarchy level (for community)
        Returns:
            RetrievalResult with chunks, entities, and metadata
        """
        start = time.perf_counter()
        if self._service_available:
            result = self._retrieve_via_api(query, retriever, top_k, num_hops, community_level)
        elif self._direct_available:
            result = self._retrieve_via_direct(query, retriever, top_k, num_hops, community_level)
        else:
            # Offline mode: return a placeholder so callers don't crash.
            result = RetrievalResult(
                content="[No TG GraphRAG connection — offline mode]",
                retriever_used=retriever,
            )
        # Latency is measured here regardless of which backend ran.
        result.latency_ms = (time.perf_counter() - start) * 1000
        return result

    def _retrieve_via_api(
        self, query: str, retriever: str, top_k: int, num_hops: int, community_level: int
    ) -> RetrievalResult:
        """Call the official TG GraphRAG REST API for retrieval.

        Tries a cascade of plausible endpoint paths (the official repo's
        routing has varied); 404s move on to the next pattern, any other
        HTTP error is logged and also skipped. Falls back to direct mode
        if every pattern fails.
        """
        import urllib.request
        import urllib.error
        payload = {
            "query": query,
            "top_k": top_k,
        }
        # Only include parameters the given retriever actually uses.
        if retriever in ("hybrid", "sibling"):
            payload["num_hops"] = num_hops
        if retriever == "community":
            payload["community_level"] = community_level
        # Try multiple endpoint patterns (official repo may use different paths)
        endpoint_patterns = [
            f"/retrieve/{retriever}",
            f"/api/retrieve/{retriever}",
            f"/graphrag/retrieve/{retriever}",
            f"/api/v1/retrieve/{retriever}",
            f"/retrieve",  # with retriever in body
            f"/api/retrieve",  # with retriever in body
            f"/query",  # generic query endpoint
            f"/api/query",
        ]
        # For generic endpoints, include retriever type in payload
        # (both key spellings, since the expected field name is unknown).
        payload_with_type = {**payload, "retriever": retriever, "retriever_type": retriever}
        for path in endpoint_patterns:
            try:
                url = f"{self.service_url}{path}"
                # Paths that encode the retriever in the URL get the bare
                # payload; generic paths get the retriever-tagged payload.
                body = json.dumps(payload_with_type if "/retrieve/" not in path else payload)
                req = urllib.request.Request(
                    url, data=body.encode("utf-8"), method="POST",
                    headers={"Content-Type": "application/json"}
                )
                if self._api_token:
                    req.add_header("Authorization", f"Bearer {self._api_token}")
                with urllib.request.urlopen(req, timeout=30) as resp:
                    data = json.loads(resp.read())
                    # First successful response wins.
                    return self._parse_api_response(data, retriever)
            except urllib.error.HTTPError as e:
                if e.code == 404:
                    continue  # try next endpoint pattern
                logger.error(f"API error on {path}: {e.code} {e.reason}")
                continue
            except (urllib.error.URLError, OSError, json.JSONDecodeError) as e:
                logger.debug(f"Endpoint {path} failed: {e}")
                continue
        logger.warning("All REST API endpoint patterns failed. Falling back to direct mode.")
        if self._direct_available:
            return self._retrieve_via_direct(query, retriever, top_k, num_hops, community_level)
        return RetrievalResult(content="[API retrieval failed]", retriever_used=retriever)

    def _parse_api_response(self, data: Dict, retriever: str) -> RetrievalResult:
        """Parse the response from the TG GraphRAG API into a RetrievalResult.

        Tolerant of several response schemas: tries multiple key aliases
        for chunks/entities/relations/summaries, and accepts either a dict
        or a bare list at the top level.
        """
        result = RetrievalResult(retriever_used=retriever)
        # Handle various response formats the API might return
        if isinstance(data, dict):
            # Standard format: {"results": [...], "answer": "..."}
            results = data.get("results", data.get("chunks", data.get("documents", [])))
            if isinstance(results, list):
                for item in results:
                    if isinstance(item, dict):
                        # Normalize the many possible key spellings into
                        # our canonical chunk dict.
                        result.chunks.append({
                            "text": item.get("content", item.get("text", item.get("chunk_text", ""))),
                            "score": item.get("score", item.get("similarity", 0.0)),
                            "source": item.get("source", item.get("doc_id", "")),
                            "chunk_id": item.get("chunk_id", item.get("id", "")),
                        })
                    elif isinstance(item, str):
                        result.chunks.append({"text": item, "score": 0.0})
            # Extract entities if present
            entities = data.get("entities", data.get("nodes", []))
            if isinstance(entities, list):
                result.entities = entities
            # Extract relations if present
            relations = data.get("relations", data.get("edges", data.get("relationships", [])))
            if isinstance(relations, list):
                result.relations = [str(r) for r in relations]
            # Extract community summaries if present
            summaries = data.get("community_summaries", data.get("summaries", []))
            if isinstance(summaries, list):
                result.community_summaries = [str(s) for s in summaries]
            # Build combined content
            texts = [c.get("text", "") for c in result.chunks if c.get("text")]
            if result.community_summaries:
                # Summaries lead the combined context, chunks follow.
                texts = result.community_summaries + texts
            result.content = "\n\n".join(texts)
            # Answer if provided
            if "answer" in data:
                result.metadata["service_answer"] = data["answer"]
            # Keep the raw keys for debugging unfamiliar response shapes.
            result.metadata["raw_response_keys"] = list(data.keys())
        elif isinstance(data, list):
            # Bare-list response: treat each element as a chunk.
            for item in data:
                text = item.get("text", item.get("content", str(item))) if isinstance(item, dict) else str(item)
                result.chunks.append({"text": text, "score": 0.0})
            result.content = "\n\n".join(c["text"] for c in result.chunks)
        return result

    def _retrieve_via_direct(
        self, query: str, retriever: str, top_k: int, num_hops: int, community_level: int
    ) -> RetrievalResult:
        """
        Fallback: use pyTigerGraph direct GSQL queries.
        Maps official retriever names to our custom GSQL queries.

        NOTE(review): the "@@topChunks" / "@@topEntities" /
        "@@chunkTexts" / "@@relationDescriptions" accumulator names are
        assumed to match the installed GSQL queries — verify against the
        GSQL definitions. The community branch currently ignores
        community_level and just does a vector search over chunks.
        """
        result = RetrievalResult(retriever_used=f"{retriever}_direct")
        if not self._conn:
            # Defensive: shouldn't happen when _direct_available is True.
            return result
        try:
            # Get query embedding for vector search
            from .orchestration_layer import EmbeddingManager
            embedder = EmbeddingManager()
            embedder.initialize()
            query_emb = embedder.embed_single(query)
            if retriever == "hybrid":
                # Hybrid = vector search chunks + entity traversal
                chunks = self._run_query("vectorSearchChunks",
                                         {"queryVec": query_emb, "topK": top_k})
                entity_results = self._run_query("vectorSearchEntities",
                                                 {"queryVec": query_emb, "topK": top_k})
                # Seed entity ids for graph traversal (empty on query failure).
                seed_ids = [e.get("entity_id", "") for e in
                            (entity_results[0].get("@@topEntities", []) if entity_results else [])]
                if seed_ids:
                    traversal = self._run_query("graphRAGTraverse",
                                                {"seedEntityIds": seed_ids, "hops": num_hops})
                    if traversal:
                        for r in traversal:
                            if "@@chunkTexts" in r:
                                for text in r["@@chunkTexts"]:
                                    result.chunks.append({"text": text, "score": 0.0})
                            if "@@relationDescriptions" in r:
                                result.relations = list(r["@@relationDescriptions"])
                # Also add vector search results
                if chunks:
                    for c in chunks[0].get("@@topChunks", []):
                        result.chunks.append({
                            "text": c.get("text", c.get("chunk_id", "")),
                            "score": c.get("score", 0.0),
                        })
                # Cap combined context at top_k chunks.
                result.content = "\n\n".join(c["text"] for c in result.chunks[:top_k] if c.get("text"))
            elif retriever == "community":
                # Community retriever — use community summaries
                chunks = self._run_query("vectorSearchChunks",
                                         {"queryVec": query_emb, "topK": top_k})
                if chunks:
                    for c in chunks[0].get("@@topChunks", []):
                        result.chunks.append({"text": c.get("text", ""), "score": c.get("score", 0.0)})
                result.content = "\n\n".join(c["text"] for c in result.chunks if c.get("text"))
            elif retriever == "sibling":
                # Sibling retriever — entity neighbors
                entity_results = self._run_query("vectorSearchEntities",
                                                 {"queryVec": query_emb, "topK": top_k})
                seed_ids = [e.get("entity_id", "") for e in
                            (entity_results[0].get("@@topEntities", []) if entity_results else [])]
                if seed_ids:
                    traversal = self._run_query("graphRAGTraverse",
                                                {"seedEntityIds": seed_ids, "hops": num_hops})
                    if traversal:
                        for r in traversal:
                            if "@@chunkTexts" in r:
                                for text in r["@@chunkTexts"]:
                                    result.chunks.append({"text": text, "score": 0.0})
                            if "@@relationDescriptions" in r:
                                result.relations = list(r["@@relationDescriptions"])
                result.content = "\n\n".join(c["text"] for c in result.chunks[:top_k] if c.get("text"))
        except Exception as e:
            # Surface the failure in the content so callers can see it,
            # rather than raising out of the retrieval path.
            logger.error(f"Direct retrieval failed: {e}")
            result.content = f"[Retrieval error: {e}]"
        return result

    def _run_query(self, query_name: str, params: Dict) -> List[Dict]:
        """Run an installed GSQL query.

        Returns the query's result list, or [] on any failure (logged).
        """
        try:
            return self._conn.runInstalledQuery(query_name, params=params)
        except Exception as e:
            logger.error(f"GSQL query {query_name} failed: {e}")
            return []

    # ── Full Q&A (Retrieval + Generation) ─────────────────
    def query(
        self,
        question: str,
        retriever: str = "hybrid",
        top_k: int = 5,
        num_hops: int = 2,
        community_level: int = 1,
        llm_layer=None,
    ) -> GraphRAGAnswer:
        """
        Full GraphRAG Q&A: retrieve context → generate answer.
        If the TG GraphRAG service provides its own answer, use that.
        Otherwise, retrieve context and pass to our LLM layer for generation.

        Args:
            question: The user question.
            retriever / top_k / num_hops / community_level: forwarded to
                retrieve() unchanged.
            llm_layer: optional object exposing generate_answer(question,
                context, system_prompt=...) returning a response with
                content/token/cost attributes — see the attribute reads
                below for the expected shape.
        """
        start = time.perf_counter()
        retrieval = self.retrieve(query=question, retriever=retriever,
                                  top_k=top_k, num_hops=num_hops,
                                  community_level=community_level)
        answer_obj = GraphRAGAnswer(retrieval=retrieval)
        # If the service already returned an answer, use it
        service_answer = retrieval.metadata.get("service_answer", "")
        if service_answer:
            # Token/cost fields stay 0 here — the service did the generation.
            answer_obj.answer = service_answer
        elif llm_layer and retrieval.content:
            # Generate answer using our LLM layer with retrieved context
            resp = llm_layer.generate_answer(question, retrieval.content,
                                             system_prompt=(
                                                 "You are a knowledgeable assistant with access to a knowledge graph. "
                                                 "Use the structured context including entities, relationships, and passages "
                                                 "to answer accurately. Follow relationship chains for multi-hop reasoning. "
                                                 "Be concise and precise."
                                             ))
            answer_obj.answer = resp.content
            answer_obj.input_tokens = resp.input_tokens
            answer_obj.output_tokens = resp.output_tokens
            answer_obj.total_tokens = resp.total_tokens
            answer_obj.cost_usd = resp.cost_usd
        else:
            answer_obj.answer = "[No context retrieved and no LLM available]"
        answer_obj.latency_ms = (time.perf_counter() - start) * 1000
        return answer_obj

    # ── Document Ingestion via Service ────────────────────
    def ingest_document(
        self,
        doc_id: str,
        title: str,
        content: str,
        source: str = "",
    ) -> Dict[str, Any]:
        """
        Ingest a document via the TG GraphRAG service API.
        Falls back to direct pyTigerGraph if service is unavailable.

        Returns:
            The service's JSON response, or a {"status": ..., ...} dict
            from the direct/offline paths.
        """
        if self._service_available:
            return self._ingest_via_api(doc_id, title, content, source)
        elif self._direct_available:
            return self._ingest_via_direct(doc_id, title, content, source)
        return {"status": "error", "message": "No connection available"}

    def _ingest_via_api(self, doc_id: str, title: str, content: str, source: str) -> Dict:
        """POST the document to the first ingest endpoint that accepts it.

        Mirrors the endpoint-cascade approach of _retrieve_via_api; any
        exception (including non-404 errors) moves to the next path.
        """
        import urllib.request
        payload = json.dumps({
            "doc_id": doc_id, "title": title,
            "content": content, "source": source,
        })
        for path in ["/ingest", "/api/ingest", "/documents", "/api/documents"]:
            try:
                url = f"{self.service_url}{path}"
                req = urllib.request.Request(
                    url, data=payload.encode(), method="POST",
                    headers={"Content-Type": "application/json"})
                # Longer timeout: ingestion may trigger extraction pipelines.
                with urllib.request.urlopen(req, timeout=60) as resp:
                    return json.loads(resp.read())
            except Exception:
                continue
        return {"status": "error", "message": "All ingest endpoints failed"}

    def _ingest_via_direct(self, doc_id: str, title: str, content: str, source: str) -> Dict:
        """Upsert the document as a single Document vertex via pyTigerGraph.

        NOTE(review): no chunking/embedding happens here — only the raw
        vertex is written; confirm downstream queries expect that.
        """
        try:
            self._conn.upsertVertex("Document", doc_id, {
                "title": title, "content": content, "source": source})
            return {"status": "ok", "doc_id": doc_id}
        except Exception as e:
            return {"status": "error", "message": str(e)}

    # ── Status / Debug ────────────────────────────────────
    def status(self) -> Dict[str, Any]:
        """Return connection status and available features."""
        return {
            "mode": self.mode,
            # URLs/hosts are only reported for the mode actually in use.
            "service_url": self.service_url if self._service_available else None,
            "tg_host": self.tg_host if self._direct_available else None,
            "tg_graph": self.tg_graph,
            "service_available": self._service_available,
            "direct_available": self._direct_available,
            "available_retrievers": ["hybrid", "community", "sibling"],
            # First 20 discovered paths (empty if discovery failed).
            "openapi_endpoints": list(self._openapi_spec.get("paths", {}).keys())[:20],
        }