File size: 2,646 Bytes
cf5c011
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d3e290f
 
cf5c011
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""Walk a knowledge-base directory, chunk each file, embed, persist FAISS index.

CLI entry point: `python -m src.rag.ingest [<input_dir> [<output_dir>]]`.
Defaults: input=`data/knowledge_base/`, output=`data/processed/faiss_index/`.

Supported file types: `.md`, `.txt`, `.pdf`. Other extensions are ignored
with a logged WARNING.
"""
from __future__ import annotations

import sys
from pathlib import Path

from src.core.logger import get_logger
from src.rag.chunker import chunk_text
from src.rag.embed import EMBEDDING_DIM, Embedder
from src.rag.store import FAISSStore

logger = get_logger(__name__)


_DEFAULT_INPUT = Path("data/knowledge_base")
_DEFAULT_OUTPUT = Path("data/processed/faiss_index")
_SUPPORTED = {".md", ".txt", ".pdf"}


def _read_pdf(path: Path) -> str:
    """Extract text from a PDF at *path*, joining pages with blank lines.

    A page whose extraction yields ``None`` contributes an empty string,
    so the page count is preserved in the output structure.
    """
    from pypdf import PdfReader

    pages = PdfReader(str(path)).pages
    page_texts = [page.extract_text() or "" for page in pages]
    return "\n\n".join(page_texts)


def _read_file(path: Path) -> str:
    suffix = path.suffix.lower()
    if suffix == ".pdf":
        return _read_pdf(path)
    return path.read_text(encoding="utf-8", errors="replace")


def ingest_directory(input_dir: Path, output_dir: Path) -> int:
    """Ingest every supported file in `input_dir` into a FAISS index at `output_dir`.

    Walks `input_dir` recursively. `.md`/`.txt`/`.pdf` files are read,
    chunked, embedded, and indexed; any other regular file is skipped with
    a logged WARNING (the contract stated in the module docstring). Files
    that fail to read are likewise skipped with a WARNING. An empty index
    is still saved when nothing was ingested.

    Returns the total number of chunks indexed.
    """
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)

    # Consider regular files only (rglob also yields directories, and a
    # directory named e.g. "notes.md" must not be treated as a document),
    # and warn on unsupported extensions instead of silently dropping them.
    files: list[Path] = []
    for p in sorted(input_dir.rglob("*")):
        if not p.is_file():
            continue
        if p.suffix.lower() in _SUPPORTED:
            files.append(p)
        else:
            logger.warning("Ignoring unsupported file type: %s", p)
    logger.info("Ingesting %d file(s) from %s", len(files), input_dir)

    all_chunks: list[dict] = []
    for path in files:
        try:
            text = _read_file(path)
        except Exception as e:
            # Best-effort ingestion: one unreadable file must not abort the run.
            logger.warning("Skipping %s (read failed: %s)", path, e)
            continue
        for i, ch in enumerate(chunk_text(text)):
            all_chunks.append({
                "text": ch,
                # Source is stored relative to the KB root for stable citations.
                "source": str(path.relative_to(input_dir)),
                "chunk_index": i,
            })

    store = FAISSStore(dim=EMBEDDING_DIM)
    if all_chunks:
        embedder = Embedder()
        vectors = embedder.encode([c["text"] for c in all_chunks])
        store.add(vectors, all_chunks)

    # Always persist, even when empty, so downstream loaders find an index.
    store.save(output_dir)
    logger.info("Indexed %d chunk(s) → %s", len(all_chunks), output_dir)
    return len(all_chunks)


def main() -> None:
    """CLI entry point: `python -m src.rag.ingest [<input_dir> [<output_dir>]]`."""
    argv = sys.argv[1:]
    input_dir = Path(argv[0]) if argv else _DEFAULT_INPUT
    output_dir = Path(argv[1]) if len(argv) > 1 else _DEFAULT_OUTPUT
    ingest_directory(input_dir, output_dir)
    # Per-call summary already logged at INFO inside ingest_directory; no print() in src/.


if __name__ == "__main__":
    main()