Add ingestion pipeline
graphrag/ingestion.py (ADDED: +110, -0)

@@ -0,0 +1,110 @@
```python
"""
Document Ingestion Pipeline
============================
Ingests documents from HotpotQA or custom sources into TigerGraph.
"""
import hashlib
import json
import logging
from typing import Dict, List, Tuple
from .layers.graph_layer import GraphLayer, chunk_text, generate_entity_id, generate_chunk_id
from .layers.llm_layer import LLMLayer
from .layers.orchestration_layer import EmbeddingManager

logger = logging.getLogger(__name__)


class IngestionPipeline:
    """Full document ingestion: Docs → Chunks → Embeddings → Entities → Graph."""

    def __init__(self, graph, llm, embedder, config=None):
        self.graph = graph
        self.llm = llm
        self.embedder = embedder
        self.config = config or {}
        self.stats = {"documents": 0, "chunks": 0, "entities": 0,
                      "relations": 0, "mentions": 0, "errors": 0}

    def ingest_document(self, doc_id, title, content, source="", extract_entities=True):
        """Ingest a single document into the graph."""
        self.graph.upsert_document(doc_id, title, content, source)
        self.stats["documents"] += 1

        chunks = chunk_text(content, self.config.get("chunk_size", 1000),
                            self.config.get("chunk_overlap", 100))
        if not chunks:
            return self.stats

        embs = self.embedder.embed(chunks)
        for i, (chunk, emb) in enumerate(zip(chunks, embs)):
            cid = generate_chunk_id(doc_id, i)
            self.graph.upsert_chunk(cid, chunk, emb, i, len(chunk.split()), doc_id)
            self.stats["chunks"] += 1
            if extract_entities:
                self._extract_entities(cid, chunk)
        return self.stats

    def _extract_entities(self, chunk_id, text):
        """Extract entities from chunk and upsert to graph."""
        try:
            resp = self.llm.extract_entities(text)
            data = json.loads(resp.content)
        except Exception as e:
            logger.error(f"Entity extraction failed: {e}")
            self.stats["errors"] += 1
            return

        name_to_id = {}
        for ent in data.get("entities", []):
            name = ent.get("name", "").strip()
            etype = ent.get("type", "CONCEPT").strip()
            if not name:
                continue
            eid = generate_entity_id(name, etype)
            name_to_id[name] = eid
            emb = self.embedder.embed_single(f"{name} ({etype}): {ent.get('description', '')}")
            self.graph.upsert_entity(eid, name, etype, ent.get("description", ""), emb)
            self.graph.upsert_mention(chunk_id, eid)
            self.stats["entities"] += 1
            self.stats["mentions"] += 1

        for rel in data.get("relations", []):
            sid = name_to_id.get(rel.get("source", ""))
            tid = name_to_id.get(rel.get("target", ""))
            if sid and tid:
                self.graph.upsert_relation(sid, tid, rel.get("type", "RELATED_TO"),
                                           rel.get("description", ""))
                self.stats["relations"] += 1

    def ingest_hotpotqa(self, max_docs=100, split="validation", extract_entities=True):
        """Ingest HotpotQA documents into the graph."""
        from datasets import load_dataset
        logger.info(f"Loading HotpotQA ({split}, max={max_docs})...")
        ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split=split)
        ingested, seen = 0, set()
        for row in ds:
            if ingested >= max_docs:
                break
            for title, sents in zip(row["context"]["title"], row["context"]["sentences"]):
                if title in seen or ingested >= max_docs:
                    continue
                seen.add(title)
                content = " ".join(sents)
                if len(content) < 50:
                    continue
                did = hashlib.md5(title.encode()).hexdigest()[:10]
                self.ingest_document(did, title, content, "hotpotqa", extract_entities)
                ingested += 1
                if ingested % 10 == 0:
                    logger.info(f"Ingested {ingested}/{max_docs} documents...")
        logger.info(f"Ingestion complete. Stats: {self.stats}")
        return self.stats

    def ingest_custom_documents(self, documents: List[Dict], extract_entities=True):
        """Ingest custom documents. Each dict: {id, title, content, source}."""
        for doc in documents:
            self.ingest_document(
                doc_id=doc.get("id", hashlib.md5(doc["title"].encode()).hexdigest()[:10]),
                title=doc.get("title", ""), content=doc.get("content", ""),
                source=doc.get("source", "custom"), extract_entities=extract_entities)
        return self.stats
```