File size: 4,744 Bytes
a60c61c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""
Document Ingestion Pipeline
============================
Ingests documents from HotpotQA or custom sources into TigerGraph.
"""
import hashlib
import json
import logging
from typing import Dict, List, Tuple
from .layers.graph_layer import GraphLayer, chunk_text, generate_entity_id, generate_chunk_id
from .layers.llm_layer import LLMLayer
from .layers.orchestration_layer import EmbeddingManager

logger = logging.getLogger(__name__)


class IngestionPipeline:
    """Full document ingestion: Docs → Chunks → Embeddings → Entities → Graph.

    Collaborators (interfaces as used by the calls below):
      * ``graph``    – ``upsert_document``, ``upsert_chunk``, ``upsert_entity``,
        ``upsert_mention`` and ``upsert_relation``.
      * ``llm``      – ``extract_entities(text)`` returning an object whose
        ``.content`` is a JSON string, expected to be an object with
        ``"entities"`` and ``"relations"`` lists.
      * ``embedder`` – ``embed(texts)`` for chunk batches and
        ``embed_single(text)`` for entity descriptions.
      * ``config``   – optional dict; ``chunk_size`` (default 1000) and
        ``chunk_overlap`` (default 100) are passed to ``chunk_text``.

    ``self.stats`` holds running counters accumulated across every call on
    this instance; each public ingest method returns that same dict.
    """

    def __init__(self, graph, llm, embedder, config=None):
        self.graph = graph
        self.llm = llm
        self.embedder = embedder
        self.config = config or {}
        self.stats = {"documents": 0, "chunks": 0, "entities": 0,
                      "relations": 0, "mentions": 0, "errors": 0}

    def ingest_document(self, doc_id, title, content, source="", extract_entities=True):
        """Ingest a single document into the graph.

        Upserts the document, splits ``content`` into chunks, embeds the
        chunks in one batch, upserts each chunk and, if requested, runs LLM
        entity extraction per chunk.

        Args:
            doc_id: Identifier for the document vertex.
            title: Document title.
            content: Raw document text to chunk.
            source: Free-form provenance label.
            extract_entities: Whether to run entity/relation extraction.

        Returns:
            The cumulative ``self.stats`` dict.
        """
        self.graph.upsert_document(doc_id, title, content, source)
        self.stats["documents"] += 1

        chunks = chunk_text(content, self.config.get("chunk_size", 1000),
                            self.config.get("chunk_overlap", 100))
        if not chunks:
            return self.stats

        # Materialize so a generator result can be length-checked; iterating a
        # list of rows is equivalent to iterating the original container.
        embs = list(self.embedder.embed(chunks))
        if len(embs) != len(chunks):
            # zip() below stops at the shorter input, so chunks without an
            # embedding would otherwise vanish from the graph without a trace.
            logger.error("Embedder returned %d vectors for %d chunks of document %s",
                         len(embs), len(chunks), doc_id)
            self.stats["errors"] += 1

        for i, (chunk, emb) in enumerate(zip(chunks, embs)):
            cid = generate_chunk_id(doc_id, i)
            self.graph.upsert_chunk(cid, chunk, emb, i, len(chunk.split()), doc_id)
            self.stats["chunks"] += 1
            if extract_entities:
                self._extract_entities(cid, chunk)
        return self.stats

    def _extract_entities(self, chunk_id, text):
        """Extract entities/relations from one chunk and upsert them to the graph.

        LLM or JSON failures, and a top-level JSON value that is not an
        object, are logged and counted in ``stats["errors"]``. Malformed
        individual entries (non-dict items, missing/null names) are skipped
        rather than aborting the whole chunk.

        Args:
            chunk_id: Id of the chunk vertex that the entities are mentioned in.
            text: Chunk text sent to the LLM.
        """
        try:
            resp = self.llm.extract_entities(text)
            data = json.loads(resp.content)
        except Exception as e:
            # Best-effort: one bad LLM response must not abort the whole ingest.
            logger.error("Entity extraction failed: %s", e)
            self.stats["errors"] += 1
            return
        if not isinstance(data, dict):
            logger.error("Entity extraction failed: expected a JSON object, got %s",
                         type(data).__name__)
            self.stats["errors"] += 1
            return

        name_to_id = {}
        upserted_ids = set()  # avoid re-embedding/re-upserting duplicates within this chunk
        for ent in data.get("entities") or []:
            if not isinstance(ent, dict):
                continue
            # `or ""` also covers explicit JSON nulls, which .get(key, default) does not.
            name = str(ent.get("name") or "").strip()
            if not name:
                continue
            etype = str(ent.get("type") or "").strip() or "CONCEPT"
            desc = ent.get("description") or ""
            eid = generate_entity_id(name, etype)
            name_to_id[name] = eid
            if eid in upserted_ids:
                continue
            upserted_ids.add(eid)
            emb = self.embedder.embed_single(f"{name} ({etype}): {desc}")
            self.graph.upsert_entity(eid, name, etype, desc, emb)
            self.graph.upsert_mention(chunk_id, eid)
            self.stats["entities"] += 1
            self.stats["mentions"] += 1

        for rel in data.get("relations") or []:
            if not isinstance(rel, dict):
                continue
            # Strip endpoints the same way entity names were stripped above so
            # lookups are not defeated by stray whitespace.
            sid = name_to_id.get(str(rel.get("source") or "").strip())
            tid = name_to_id.get(str(rel.get("target") or "").strip())
            if sid and tid:
                self.graph.upsert_relation(sid, tid, rel.get("type") or "RELATED_TO",
                                           rel.get("description") or "")
                self.stats["relations"] += 1

    def ingest_hotpotqa(self, max_docs=100, split="validation", extract_entities=True):
        """Ingest HotpotQA context paragraphs into the graph.

        Each unique context title becomes one document whose id is the first
        10 hex chars of the title's MD5. Paragraphs shorter than 50
        characters are skipped.

        Args:
            max_docs: Maximum number of documents to ingest.
            split: Dataset split passed to ``load_dataset``.
            extract_entities: Forwarded to ``ingest_document``.

        Returns:
            The cumulative ``self.stats`` dict.
        """
        if max_docs <= 0:
            # Nothing to ingest; skip loading the dataset entirely.
            logger.info("Ingestion complete. Stats: %s", self.stats)
            return self.stats

        from datasets import load_dataset
        logger.info("Loading HotpotQA (%s, max=%s)...", split, max_docs)
        ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split=split)
        ingested, seen = 0, set()
        for row in ds:
            if ingested >= max_docs:
                break
            ctx = row["context"]
            for title, sents in zip(ctx["title"], ctx["sentences"]):
                if ingested >= max_docs:
                    break
                if title in seen:
                    continue
                seen.add(title)
                content = " ".join(sents)
                if len(content) < 50:
                    continue
                did = hashlib.md5(title.encode()).hexdigest()[:10]
                self.ingest_document(did, title, content, "hotpotqa", extract_entities)
                ingested += 1
                if ingested % 10 == 0:
                    logger.info("Ingested %d/%s documents...", ingested, max_docs)
        logger.info("Ingestion complete. Stats: %s", self.stats)
        return self.stats

    def ingest_custom_documents(self, documents: List[Dict], extract_entities=True):
        """Ingest custom documents. Each dict: {id, title, content, source}.

        Every key is optional. When ``id`` is missing or empty, a stable id is
        derived from the MD5 of the title. If the title is also empty, the
        content is hashed instead, so untitled documents do not all collide
        on the same id. ``source`` defaults to ``"custom"``.

        Args:
            documents: Iterable of document dicts.
            extract_entities: Forwarded to ``ingest_document``.

        Returns:
            The cumulative ``self.stats`` dict.
        """
        for doc in documents:
            title = doc.get("title") or ""
            content = doc.get("content") or ""
            doc_id = doc.get("id")
            if doc_id is None or doc_id == "":
                # Computed lazily: the old eager default demanded a "title" key
                # even when an explicit id was supplied.
                doc_id = hashlib.md5((title or content).encode()).hexdigest()[:10]
            self.ingest_document(
                doc_id=doc_id, title=title, content=content,
                source=doc.get("source", "custom"), extract_entities=extract_entities)
        return self.stats