File size: 6,778 Bytes
3b6130d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
"""
rag.py — Plexi RAG Engine
=========================
Handles everything related to the LlamaIndex vector index:
  - Downloading the pre-built index from GitHub
  - Loading HuggingFace sentence-transformer embeddings
  - Embedding queries and retrieving top-k chunks scoped by semester + subject
  - Extracting text from PDFs for full-context fallback
  - Formatting retrieved chunks for the LLM system prompt
"""

import io
import logging
import os
import shutil
import tempfile
from pathlib import Path

import requests

# ---------------------------------------------------------------------------
# Optional LlamaIndex — graceful degradation if not installed
# ---------------------------------------------------------------------------
try:
    from llama_index.core import Settings, StorageContext, load_index_from_storage
    from llama_index.embeddings.huggingface import HuggingFaceEmbedding

    LLAMA_INDEX_AVAILABLE = True
except ImportError:
    LLAMA_INDEX_AVAILABLE = False

try:
    import PyPDF2

    PYPDF2_AVAILABLE = True
except ImportError:
    PYPDF2_AVAILABLE = False

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# GitHub "owner/repo" that hosts the pre-built index (overridable via env var).
MATERIALS_REPO = os.environ.get("MATERIALS_REPO", "KunalGupta25/plexi-materials")
# Branch of MATERIALS_REPO from which the index is downloaded.
MANIFEST_BRANCH = os.environ.get("MANIFEST_BRANCH", "main")
# HuggingFace model id used for the embedding model.
EMBED_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"

# Every file load_index() fetches from the repo's `index/` directory.
INDEX_FILES = [
    "default__vector_store.json",
    "docstore.json",
    "graph_store.json",
    "image__vector_store.json",
    "index_store.json",
]

# Chunks returned by retrieve_chunks() when the caller passes no top_k.
DEFAULT_TOP_K = 5


# ---------------------------------------------------------------------------
# Index loading (called once at FastAPI startup)
# ---------------------------------------------------------------------------

def load_index():
    """
    Download the pre-built LlamaIndex from the materials repo and load it.

    Every file in INDEX_FILES is fetched from
    ``https://raw.githubusercontent.com/<MATERIALS_REPO>/<MANIFEST_BRANCH>/index``
    into a fresh temporary directory. The global LlamaIndex ``Settings`` are then
    pointed at the HuggingFace embedding model (``EMBED_MODEL_ID``) with
    ``Settings.llm`` set to None, and the index is loaded from that directory.

    Returns:
        (index, error_msg): on success ``index`` is the loaded index and
        ``error_msg`` is None. On failure ``index`` is None, ``error_msg``
        explains what went wrong, and the temporary directory is removed.
    """
    if not LLAMA_INDEX_AVAILABLE:
        return None, "llama-index-core is not installed."

    index_base_url = (
        f"https://raw.githubusercontent.com/{MATERIALS_REPO}/{MANIFEST_BRANCH}/index"
    )
    index_dir = tempfile.mkdtemp(prefix="plexi_index_")

    def _fail(message):
        # Remove the (possibly half-populated) temp dir so failures don't leak it.
        shutil.rmtree(index_dir, ignore_errors=True)
        return None, message

    # One session for all downloads: reuses the connection and is always closed.
    with requests.Session() as session:
        for filename in INDEX_FILES:
            url = f"{index_base_url}/{filename}"
            try:
                resp = session.get(url, timeout=30)
                resp.raise_for_status()
                with open(os.path.join(index_dir, filename), "wb") as fh:
                    fh.write(resp.content)
            except Exception as err:
                return _fail(f"Failed to download index file '{filename}': {err}")

    try:
        embed_model = HuggingFaceEmbedding(model_name=EMBED_MODEL_ID)
        Settings.embed_model = embed_model
        Settings.llm = None

        storage_ctx = StorageContext.from_defaults(persist_dir=index_dir)
        index = load_index_from_storage(storage_ctx)
    except Exception as err:
        return _fail(f"Failed to load index from storage: {err}")
    return index, None


def load_embed_model():
    """
    Construct the HuggingFace embedding model named by EMBED_MODEL_ID.

    Meant for health checks. Returns None when llama-index is not installed.
    """
    if LLAMA_INDEX_AVAILABLE:
        return HuggingFaceEmbedding(model_name=EMBED_MODEL_ID)
    return None


# ---------------------------------------------------------------------------
# Retrieval
# ---------------------------------------------------------------------------

def _matches_scope(node, semester: str, subject: str) -> bool:
    """Return True when a retrieved node belongs to the active semester + subject."""
    metadata = getattr(node.node, "metadata", {}) or {}
    return (
        metadata.get("semester") == semester
        and metadata.get("subject") == subject
    )


def retrieve_chunks(
    index,
    query: str,
    semester: str,
    subject: str,
    top_k: int = DEFAULT_TOP_K,
) -> list[dict]:
    """
    Retrieve up to ``top_k`` chunks for ``query`` scoped to one semester + subject.

    Scoping is applied after the similarity search. The retriever is therefore
    asked for a wider candidate pool (``max(top_k * 5, 10)`` nodes), and only
    nodes whose metadata matches both ``semester`` and ``subject`` are kept,
    in retriever order. Fewer than ``top_k`` results come back when not enough
    in-scope nodes appear in that pool.

    Args:
        index: A loaded index exposing ``as_retriever`` (e.g. from
            ``load_index``), or None.
        query: Text to search for.
        semester: Required value of the node's "semester" metadata.
        subject: Required value of the node's "subject" metadata.
        top_k: Maximum number of chunks to return.

    Returns:
        A list of dicts ``{text, score, filename, subject}``. ``score`` is
        rounded to 4 decimals, or None when the retriever gave no score.
        The list is empty when ``index`` is None, ``top_k`` is not positive,
        nothing is in scope, or retrieval raises (the error is logged).
    """
    if index is None:
        return []

    try:
        if top_k <= 0:
            # A non-positive top_k would otherwise slice from the END of the
            # scoped list and return results instead of none.
            return []

        retriever = index.as_retriever(similarity_top_k=max(top_k * 5, 10))
        candidates = retriever.retrieve(query)

        results = []
        for hit in candidates:
            if not _matches_scope(hit, semester, subject):
                continue
            metadata = getattr(hit.node, "metadata", {}) or {}
            results.append(
                {
                    "text": hit.node.get_content(),
                    "score": round(float(hit.score), 4) if hit.score is not None else None,
                    "filename": metadata.get("filename"),
                    "subject": metadata.get("subject"),
                }
            )
            if len(results) >= top_k:
                break
        return results
    except Exception as err:
        # Best-effort: a retrieval failure degrades to "no context" for the
        # caller, but is logged with its traceback instead of printed.
        logging.getLogger(__name__).exception("Retrieval error: %s", err)
        return []


# ---------------------------------------------------------------------------
# Context formatting (for system prompt injection)
# ---------------------------------------------------------------------------

def format_context(chunks: list[dict]) -> str:
    """
    Render retrieved chunks as a numbered block for the LLM system prompt.

    Each chunk becomes a header line
    ``--- Chunk <i> | <source>  [relevance: <score>] ---`` followed by its
    text. ``source`` is the chunk's filename, else its subject, else
    "Unknown source". The relevance tag is omitted only when the score is
    missing or None; a score of 0 is still shown.

    Returns a placeholder sentence when ``chunks`` is empty.
    """
    if not chunks:
        return "(No relevant context retrieved for this query.)"
    parts = []
    for i, chunk in enumerate(chunks, start=1):
        score = chunk.get("score")
        # Compare against None so a legitimate 0.0 score is still reported.
        score_info = f"  [relevance: {score}]" if score is not None else ""
        source = chunk.get("filename") or chunk.get("subject") or "Unknown source"
        parts.append(f"--- Chunk {i} | {source}{score_info} ---\n{chunk['text']}\n")
    return "\n".join(parts)


# ---------------------------------------------------------------------------
# PDF text extraction (used for full-context fallback loading)
# ---------------------------------------------------------------------------

def read_pdf_text(pdf_bytes: bytes) -> str:
    """
    Extract plain text from PDF bytes (used for full-context fallback loading).

    Each page is extracted with PyPDF2. Pages that raise or yield no text are
    skipped, and the remaining page texts are joined with newlines.

    Returns "" when PyPDF2 is not installed or ``pdf_bytes`` is empty.

    When PyPDF2 cannot parse the input:
      - If the bytes carry a ``%PDF`` header (a real but damaged PDF), the page
        text extracted before the failure is returned (possibly ""). The raw
        PDF bytes are not decoded, since they would only be binary noise.
      - Otherwise the bytes are decoded as UTF-8 with invalid bytes dropped,
        so plain-text material passed here still yields its content.
    """
    if not PYPDF2_AVAILABLE or not pdf_bytes:
        return ""

    text_parts = []
    try:
        reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
        for page in reader.pages:
            try:
                page_text = page.extract_text()
            except Exception:
                # Best-effort: one unreadable page must not drop the others.
                continue
            if page_text:
                # Round-trip through UTF-16 so surrogate code points that some
                # PDFs produce are recombined into real characters or dropped.
                text_parts.append(
                    page_text.encode("utf-16", "surrogatepass").decode("utf-16", "ignore")
                )
    except Exception:
        # Readers commonly tolerate junk before the header, so look in the
        # first 1 KiB rather than only at offset 0.
        if b"%PDF" in pdf_bytes[:1024]:
            return "\n".join(text_parts)
        return pdf_bytes.decode("utf-8", errors="ignore")
    return "\n".join(text_parts)