File size: 6,778 Bytes
3b6130d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 | """
rag.py — Plexi RAG Engine
=========================
Handles everything related to the LlamaIndex vector index:
- Downloading the pre-built index from GitHub
- Loading HuggingFace sentence-transformer embeddings
- Embedding queries and retrieving top-k chunks scoped by semester + subject
- Extracting text from PDFs for full-context fallback
- Formatting retrieved chunks for the LLM system prompt
"""
import io
import os
import shutil
import tempfile
from pathlib import Path

import requests
# ---------------------------------------------------------------------------
# Optional LlamaIndex — graceful degradation if not installed
# ---------------------------------------------------------------------------
try:
from llama_index.core import Settings, StorageContext, load_index_from_storage
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
LLAMA_INDEX_AVAILABLE = True
except ImportError:
LLAMA_INDEX_AVAILABLE = False
try:
import PyPDF2
PYPDF2_AVAILABLE = True
except ImportError:
PYPDF2_AVAILABLE = False
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# GitHub "owner/repo" and branch hosting the pre-built index; both can be
# overridden through environment variables of the same name.
MATERIALS_REPO = os.environ.get(
    "MATERIALS_REPO", "KunalGupta25/plexi-materials"
)
MANIFEST_BRANCH = os.environ.get("MANIFEST_BRANCH", "main")

# HuggingFace sentence-transformer used to embed queries.
# NOTE(review): presumably this must match the model the index was built
# with — confirm before changing it.
EMBED_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"

# Persisted LlamaIndex storage files fetched from "<repo>/<branch>/index/".
INDEX_FILES = [
    "default__vector_store.json",
    "docstore.json",
    "graph_store.json",
    "image__vector_store.json",
    "index_store.json",
]

# Number of chunks returned by retrieval when the caller does not specify one.
DEFAULT_TOP_K = 5
# ---------------------------------------------------------------------------
# Index loading (called once at FastAPI startup)
# ---------------------------------------------------------------------------
def load_index():
    """
    Download the pre-built LlamaIndex from the materials repo and load it.

    Every file in INDEX_FILES is fetched from
    ``https://raw.githubusercontent.com/<MATERIALS_REPO>/<MANIFEST_BRANCH>/index``
    into a fresh temporary directory. The HuggingFace embedding model is then
    installed as the global ``Settings.embed_model`` (``Settings.llm`` is set
    to None), and the index is loaded from that directory.

    Returns:
        A ``(index, error_msg)`` tuple: ``(index, None)`` on success, or
        ``(None, message)`` on any failure. Errors are reported through the
        message rather than raised.
    """
    if not LLAMA_INDEX_AVAILABLE:
        return None, "llama-index-core is not installed."

    index_base_url = (
        f"https://raw.githubusercontent.com/{MATERIALS_REPO}/{MANIFEST_BRANCH}/index"
    )
    index_dir = tempfile.mkdtemp(prefix="plexi_index_")

    # A single session lets the downloads reuse one HTTPS connection.
    with requests.Session() as session:
        for filename in INDEX_FILES:
            url = f"{index_base_url}/{filename}"
            try:
                resp = session.get(url, timeout=30)
                resp.raise_for_status()
                (Path(index_dir) / filename).write_bytes(resp.content)
            except Exception as err:
                # Don't leave a half-populated temp directory behind.
                shutil.rmtree(index_dir, ignore_errors=True)
                return None, f"Failed to download index file '{filename}': {err}"

    try:
        embed_model = HuggingFaceEmbedding(model_name=EMBED_MODEL_ID)
        Settings.embed_model = embed_model
        Settings.llm = None
        storage_ctx = StorageContext.from_defaults(persist_dir=index_dir)
        index = load_index_from_storage(storage_ctx)
    except Exception as err:
        shutil.rmtree(index_dir, ignore_errors=True)
        return None, f"Failed to load index from storage: {err}"

    # NOTE(review): on success the directory is kept for the process lifetime
    # (as before); whether the loaded index still reads these files is not
    # visible here — confirm before deleting it after loading.
    return index, None
def load_embed_model():
    """
    Build the HuggingFace embedding model named by EMBED_MODEL_ID (for health checks).

    Returns None when the LlamaIndex packages could not be imported.
    Otherwise, a new HuggingFaceEmbedding is constructed on every call.
    """
    if LLAMA_INDEX_AVAILABLE:
        return HuggingFaceEmbedding(model_name=EMBED_MODEL_ID)
    return None
# ---------------------------------------------------------------------------
# Retrieval
# ---------------------------------------------------------------------------
def _matches_scope(node, semester: str, subject: str) -> bool:
"""Return True when a retrieved node belongs to the active semester + subject."""
metadata = getattr(node.node, "metadata", {}) or {}
return (
metadata.get("semester") == semester
and metadata.get("subject") == subject
)
def retrieve_chunks(
    index,
    query: str,
    semester: str,
    subject: str,
    top_k: int = DEFAULT_TOP_K,
) -> list[dict]:
    """
    Retrieve up to ``top_k`` chunks for ``query``, scoped to a semester + subject.

    The retriever over-fetches ``max(top_k * 5, 10)`` candidates because the
    scope filter runs *after* similarity search. If few candidates fall in
    the requested scope, fewer than ``top_k`` chunks are returned.

    Args:
        index: A loaded index exposing ``as_retriever`` (or None).
        query: The user query to embed and search for.
        semester: Required value of the node's ``semester`` metadata.
        subject: Required value of the node's ``subject`` metadata.
        top_k: Maximum number of chunks to return.

    Returns:
        A list of dicts ``{text, score, filename, subject}``, in the order the
        retriever produced them. ``score`` is rounded to 4 decimals, or None
        when the retriever gave none. Returns ``[]`` when ``index`` is None,
        ``top_k`` is not positive, or retrieval raises (the error is printed).
    """
    # A negative top_k would otherwise turn scoped[:top_k] into "all but the
    # last |top_k|" results.
    if index is None or top_k <= 0:
        return []
    try:
        # Fetch more than needed so we have room to filter by scope.
        # NOTE(review): LlamaIndex retrievers can take metadata filters, which
        # would apply the scope before ranking — consider it if scoped results
        # often come back short.
        retriever = index.as_retriever(similarity_top_k=max(top_k * 5, 10))
        hits = retriever.retrieve(query)
        scoped = [hit for hit in hits if _matches_scope(hit, semester, subject)]

        results = []
        for hit in scoped[:top_k]:
            metadata = getattr(hit.node, "metadata", {}) or {}
            score = None if hit.score is None else round(float(hit.score), 4)
            results.append(
                {
                    "text": hit.node.get_content(),
                    "score": score,
                    "filename": metadata.get("filename"),
                    "subject": metadata.get("subject"),
                }
            )
        return results
    except Exception as err:
        print(f"Retrieval error: {err}")
        return []
# ---------------------------------------------------------------------------
# Context formatting (for system prompt injection)
# ---------------------------------------------------------------------------
def format_context(chunks: list[dict]) -> str:
    """
    Format retrieved chunks as a numbered block for the LLM system prompt.

    Each chunk becomes ``--- Chunk <n> | <source>[ [relevance: <score>]] ---``
    followed by its ``text`` and a newline; chunks are separated by a blank
    line. ``<source>`` is the chunk's ``filename``, else its ``subject``, else
    ``"Unknown source"``. The relevance tag is omitted only when ``score`` is
    missing or None.

    Returns a fixed placeholder sentence when ``chunks`` is empty or None.
    """
    if not chunks:
        return "(No relevant context retrieved for this query.)"
    parts = []
    for i, chunk in enumerate(chunks, start=1):
        score = chunk.get("score")
        # Compare with None: a genuine score of 0.0 is falsy but still worth showing.
        score_info = f" [relevance: {score}]" if score is not None else ""
        source = chunk.get("filename") or chunk.get("subject") or "Unknown source"
        parts.append(
            f"--- Chunk {i} | {source}{score_info} ---\n{chunk['text']}\n"
        )
    return "\n".join(parts)
# ---------------------------------------------------------------------------
# PDF text extraction (used for full-context fallback loading)
# ---------------------------------------------------------------------------
def read_pdf_text(pdf_bytes: bytes) -> str:
    """
    Extract plain text from PDF bytes, page by page, joined with newlines.

    Pages whose extraction raises or yields no text are skipped. Each page's
    text goes through a UTF-16 round trip that drops lone surrogate code
    points (and recombines correctly paired ones), which some PDFs produce.

    Fallbacks:
      * PyPDF2 not installed -> ``""``.
      * PyPDF2 cannot open or iterate the document before any page text was
        extracted -> the raw bytes decoded as UTF-8 with undecodable bytes
        dropped (``""`` for empty input).
      * PyPDF2 fails part-way through the pages -> the text extracted so far.
    """
    if not PYPDF2_AVAILABLE:
        return ""
    text_parts = []
    try:
        reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
        for page in reader.pages:
            try:
                page_text = page.extract_text()
            except Exception:
                # Best-effort: one unreadable page should not sink the rest.
                continue
            if page_text:
                # Drop lone surrogates that would break later UTF-8 encoding.
                filtered = page_text.encode("utf-16", "surrogatepass").decode(
                    "utf-16", "ignore"
                )
                text_parts.append(filtered)
    except Exception:
        # Keep any pages already extracted instead of replacing them with the
        # raw (binary) PDF bytes; only fall back to raw decoding when nothing
        # was read.
        if not text_parts:
            return pdf_bytes.decode("utf-8", errors="ignore") if pdf_bytes else ""
    return "\n".join(text_parts)
|