muthuk1 committed on
Commit a60c61c · verified · 1 Parent(s): da5b779

Add ingestion pipeline

Files changed (1)
  1. graphrag/ingestion.py +110 -0
graphrag/ingestion.py ADDED
@@ -0,0 +1,110 @@
+ """
+ Document Ingestion Pipeline
+ ============================
+ Ingests documents from HotpotQA or custom sources into TigerGraph.
+ """
+ import hashlib
+ import json
+ import logging
+ from typing import Dict, List, Tuple
+ from .layers.graph_layer import GraphLayer, chunk_text, generate_entity_id, generate_chunk_id
+ from .layers.llm_layer import LLMLayer
+ from .layers.orchestration_layer import EmbeddingManager
+
+ logger = logging.getLogger(__name__)
+
+
+ class IngestionPipeline:
+     """Full document ingestion: Docs → Chunks → Embeddings → Entities → Graph."""
+
+     def __init__(self, graph, llm, embedder, config=None):
+         self.graph = graph
+         self.llm = llm
+         self.embedder = embedder
+         self.config = config or {}
+         self.stats = {"documents": 0, "chunks": 0, "entities": 0,
+                       "relations": 0, "mentions": 0, "errors": 0}
+
+     def ingest_document(self, doc_id, title, content, source="", extract_entities=True):
+         """Ingest a single document into the graph."""
+         self.graph.upsert_document(doc_id, title, content, source)
+         self.stats["documents"] += 1
+
+         chunks = chunk_text(content, self.config.get("chunk_size", 1000),
+                             self.config.get("chunk_overlap", 100))
+         if not chunks:
+             return self.stats
+
+         embs = self.embedder.embed(chunks)
+         for i, (chunk, emb) in enumerate(zip(chunks, embs)):
+             cid = generate_chunk_id(doc_id, i)
+             self.graph.upsert_chunk(cid, chunk, emb, i, len(chunk.split()), doc_id)
+             self.stats["chunks"] += 1
+             if extract_entities:
+                 self._extract_entities(cid, chunk)
+         return self.stats
+
+     def _extract_entities(self, chunk_id, text):
+         """Extract entities from chunk and upsert to graph."""
+         try:
+             resp = self.llm.extract_entities(text)
+             data = json.loads(resp.content)
+         except Exception as e:
+             logger.error(f"Entity extraction failed: {e}")
+             self.stats["errors"] += 1
+             return
+
+         name_to_id = {}
+         for ent in data.get("entities", []):
+             name = ent.get("name", "").strip()
+             etype = ent.get("type", "CONCEPT").strip()
+             if not name:
+                 continue
+             eid = generate_entity_id(name, etype)
+             name_to_id[name] = eid
+             emb = self.embedder.embed_single(f"{name} ({etype}): {ent.get('description', '')}")
+             self.graph.upsert_entity(eid, name, etype, ent.get("description", ""), emb)
+             self.graph.upsert_mention(chunk_id, eid)
+             self.stats["entities"] += 1
+             self.stats["mentions"] += 1
+
+         for rel in data.get("relations", []):
+             sid = name_to_id.get(rel.get("source", ""))
+             tid = name_to_id.get(rel.get("target", ""))
+             if sid and tid:
+                 self.graph.upsert_relation(sid, tid, rel.get("type", "RELATED_TO"),
+                                            rel.get("description", ""))
+                 self.stats["relations"] += 1
+
+     def ingest_hotpotqa(self, max_docs=100, split="validation", extract_entities=True):
+         """Ingest HotpotQA documents into the graph."""
+         from datasets import load_dataset
+         logger.info(f"Loading HotpotQA ({split}, max={max_docs})...")
+         ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split=split)
+         ingested, seen = 0, set()
+         for row in ds:
+             if ingested >= max_docs:
+                 break
+             for title, sents in zip(row["context"]["title"], row["context"]["sentences"]):
+                 if title in seen or ingested >= max_docs:
+                     continue
+                 seen.add(title)
+                 content = " ".join(sents)
+                 if len(content) < 50:
+                     continue
+                 did = hashlib.md5(title.encode()).hexdigest()[:10]
+                 self.ingest_document(did, title, content, "hotpotqa", extract_entities)
+                 ingested += 1
+                 if ingested % 10 == 0:
+                     logger.info(f"Ingested {ingested}/{max_docs} documents...")
+         logger.info(f"Ingestion complete. Stats: {self.stats}")
+         return self.stats
+
+     def ingest_custom_documents(self, documents: List[Dict], extract_entities=True):
+         """Ingest custom documents. Each dict: {id, title, content, source}."""
+         for doc in documents:
+             self.ingest_document(
+                 doc_id=doc.get("id", hashlib.md5(doc["title"].encode()).hexdigest()[:10]),
+                 title=doc.get("title", ""), content=doc.get("content", ""),
+                 source=doc.get("source", "custom"), extract_entities=extract_entities)
+         return self.stats
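
Usage sketch (illustrative, not part of this commit). It assumes GraphLayer, LLMLayer, and EmbeddingManager can be constructed as shown; their real constructor signatures live in graphrag/layers/ and are not visible in this diff, so the zero-argument calls and the sample document below are placeholders. The trailing comment shows the JSON shape that _extract_entities expects llm.extract_entities to return, inferred from the parsing code above.

# example_ingest.py -- illustrative wiring only, constructor arguments are hypothetical
import logging

from graphrag.ingestion import IngestionPipeline
from graphrag.layers.graph_layer import GraphLayer
from graphrag.layers.llm_layer import LLMLayer
from graphrag.layers.orchestration_layer import EmbeddingManager

logging.basicConfig(level=logging.INFO)

graph = GraphLayer()           # wraps the TigerGraph connection and upsert_* calls
llm = LLMLayer()               # must expose extract_entities(text) returning a response with .content
embedder = EmbeddingManager()  # must expose embed(list[str]) and embed_single(str)

pipeline = IngestionPipeline(
    graph, llm, embedder,
    config={"chunk_size": 1000, "chunk_overlap": 100},
)

# Option 1: ingest a small slice of HotpotQA (requires the `datasets` package).
stats = pipeline.ingest_hotpotqa(max_docs=20, split="validation")

# Option 2: ingest custom documents.
stats = pipeline.ingest_custom_documents([
    {"id": "doc-001", "title": "TigerGraph",
     "content": "TigerGraph is a distributed graph database ...", "source": "custom"},
])
print(stats)  # {"documents": ..., "chunks": ..., "entities": ..., "relations": ..., "mentions": ..., "errors": ...}

# _extract_entities expects llm.extract_entities(text) to return JSON like:
# {
#   "entities":  [{"name": "TigerGraph", "type": "PRODUCT", "description": "..."}],
#   "relations": [{"source": "TigerGraph", "target": "HotpotQA", "type": "RELATED_TO", "description": "..."}]
# }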