"""
Document Ingestion Pipeline
============================
Ingests documents from HotpotQA or custom sources into TigerGraph.
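
Usage sketch (assumes the layer classes construct without required
arguments, which may not match the real GraphLayer / LLMLayer /
EmbeddingManager signatures):

    pipeline = IngestionPipeline(GraphLayer(), LLMLayer(), EmbeddingManager())
    pipeline.ingest_hotpotqa(max_docs=50)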
"""
import hashlib
import json
import logging
from typing import Dict, List, Optional
from .layers.graph_layer import GraphLayer, chunk_text, generate_entity_id, generate_chunk_id
from .layers.llm_layer import LLMLayer
from .layers.orchestration_layer import EmbeddingManager

logger = logging.getLogger(__name__)


class IngestionPipeline:
    """Full document ingestion: Docs → Chunks → Embeddings → Entities → Graph."""

    def __init__(self, graph: GraphLayer, llm: LLMLayer,
                 embedder: EmbeddingManager, config: Optional[Dict] = None):
        self.graph = graph
        self.llm = llm
        self.embedder = embedder
        self.config = config or {}
        self.stats = {"documents": 0, "chunks": 0, "entities": 0,
                      "relations": 0, "mentions": 0, "errors": 0}

    def ingest_document(self, doc_id, title, content, source="", extract_entities=True):
        """Ingest a single document into the graph."""
        content = content[:20000]  # cap at 20k chars (~5k tokens) to prevent MemoryError
        self.graph.upsert_document(doc_id, title, content, source)
        self.stats["documents"] += 1

        chunks = chunk_text(content, self.config.get("chunk_size", 1000),
                            self.config.get("chunk_overlap", 100))
        if not chunks:
            return self.stats

        embs = self.embedder.embed(chunks)
        for i, (chunk, emb) in enumerate(zip(chunks, embs)):
            cid = generate_chunk_id(doc_id, i)
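            # len(chunk.split()) is a rough whitespace-based token count for the chunk.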
            self.graph.upsert_chunk(cid, chunk, emb, i, len(chunk.split()), doc_id)
            self.stats["chunks"] += 1
            if extract_entities:
                self._extract_entities(cid, chunk)
        return self.stats

    def _extract_entities(self, chunk_id, text):
        """Extract entities from chunk and upsert to graph."""
        try:
            resp = self.llm.extract_entities(text)
            content = resp.content
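            # Isolate the outermost {...} span; LLM output often wraps the JSON in prose or code fences.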
            start = content.find("{")
            end = content.rfind("}") + 1
            if start == -1 or end == 0:
                raise ValueError("No JSON found")
            raw = content[start:end]
            try:
                data = json.loads(raw)
            except Exception:
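                # Fall back to json_repair (third-party) for common LLM JSON defects
                # such as trailing commas or single-quoted strings.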
                from json_repair import repair_json
                data = json.loads(repair_json(raw))
        except Exception as e:
            logger.error(f"Entity extraction failed: {e}")
            self.stats["errors"] += 1
            return

        name_to_id = {}
        for ent in data.get("entities", []):
            name = ent.get("name", "").strip()
            etype = ent.get("type", "CONCEPT").strip()
            if not name:
                continue
            eid = generate_entity_id(name, etype)
            name_to_id[name] = eid
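            # Embed "name (TYPE): description" so the entity vector carries type and context.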
            emb = self.embedder.embed_single(f"{name} ({etype}): {ent.get('description', '')}")
            self.graph.upsert_entity(eid, name, etype, ent.get("description", ""), emb)
            self.graph.upsert_mention(chunk_id, eid)
            self.stats["entities"] += 1
            self.stats["mentions"] += 1

        for rel in data.get("relations", []):
            sid = name_to_id.get(rel.get("source", ""))
            tid = name_to_id.get(rel.get("target", ""))
            if sid and tid:
                self.graph.upsert_relation(sid, tid, rel.get("type", "RELATED_TO"),
                                           rel.get("description", ""))
                self.stats["relations"] += 1

    def ingest_hotpotqa(self, max_docs=100, split="validation", extract_entities=True):
        """Ingest HotpotQA documents into the graph."""
        from datasets import load_dataset
        logger.info(f"Loading HotpotQA ({split}, max={max_docs})...")
        ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split=split)
        ingested, seen = 0, set()
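        # Context paragraphs repeat across HotpotQA questions; dedupe by title
        # so each Wikipedia document is ingested only once.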
        for row in ds:
            if ingested >= max_docs:
                break
            for title, sents in zip(row["context"]["title"], row["context"]["sentences"]):
                if ingested >= max_docs:
                    break
                if title in seen:
                    continue
                seen.add(title)
                content = " ".join(sents)
                if len(content) < 50:  # skip near-empty paragraphs
                    continue
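                # Stable short document id: first 10 hex chars of the title's md5 (not security-sensitive).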
                did = hashlib.md5(title.encode()).hexdigest()[:10]
                self.ingest_document(did, title, content, "hotpotqa", extract_entities)
                ingested += 1
                if ingested % 10 == 0:
                    logger.info(f"Ingested {ingested}/{max_docs} documents...")
        logger.info(f"Ingestion complete. Stats: {self.stats}")
        return self.stats

    def ingest_custom_documents(self, documents: List[Dict], extract_entities=True):
        """Ingest custom documents. Each dict: {id, title, content, source}."""
        import gc
        total = len(documents)
        for i, doc in enumerate(documents, 1):
            try:
                self.ingest_document(
                    doc_id=doc.get("id", hashlib.md5(doc.get("title", "").encode()).hexdigest()[:10]),
                    title=doc.get("title", ""), content=doc.get("content", ""),
                    source=doc.get("source", "custom"), extract_entities=extract_entities)
                logger.info(f"Ingested {i}/{total}: {doc.get('title', '')[:60]}")
            except MemoryError:
                logger.warning(f"Skipped {i}/{total} (MemoryError): {doc.get('title', '')[:60]}")
                self.stats["errors"] += 1
            gc.collect()
        return self.stats
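

if __name__ == "__main__":
    # Minimal wiring sketch, not a tested entry point: it assumes the three
    # layer classes can be constructed with no arguments, which a real
    # deployment (TigerGraph connection details, model settings) likely
    # won't allow. Because of the relative imports above, run it via
    # ``python -m <package>.<this_module>`` rather than as a plain script.
    logging.basicConfig(level=logging.INFO)
    pipeline = IngestionPipeline(GraphLayer(), LLMLayer(), EmbeddingManager())
    stats = pipeline.ingest_custom_documents([
        {"id": "demo-001", "title": "Demo Document",
         "content": "TigerGraph is a distributed graph database designed for "
                    "multi-hop queries over large connected datasets.",
         "source": "demo"},
    ])
    print(stats)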