feat(rag): wire retrieval-augmented chat into chat_fn
Browse files
RAG (retrieval-augmented generation) is now the default chat mode on the
Space when the index loads. Pulls 83K BGE-embedded passages from the
Ghostgim/GhostLM-v0.9-experimental Models repo at startup; embeds each
user query with BAAI/bge-small-en-v1.5; takes top-4 by cosine similarity;
prepends them as 'Reference passages' in front of the question. The model
is not RAFT-trained yet so it sees retrieved context as part of the user
message, but even without retrieval-specific training, retrieval lowers the hallucination floor
the bare 81M model exhibits.
RAG load is wrapped in try/except: if the index file isn't in the Models
repo (e.g. upload still in progress), the Space falls back to bare chat
with an honest 'RAG: OFF' note in the description so the user knows
what they're getting. No silent failure modes.
Adds requirements: numpy >= 1.24, transformers >= 4.38, sentencepiece >=
0.1.99 (BGE deps). All small; no LFS budget impact on the Space.
Index files (rag/index.npy fp16 64MB + rag/chunks.jsonl 57MB + rag/meta.json)
land in the Models repo separately when bandwidth cooperates; once they
do, restart_space() picks them up without a Space code change.
- app.py +130 -1
- requirements.txt +16 -4
|
@@ -32,13 +32,15 @@ the default 200-token cap on the 81M model.
|
|
| 32 |
from __future__ import annotations
|
| 33 |
|
| 34 |
import gc
|
|
|
|
| 35 |
import os
|
| 36 |
import sys
|
| 37 |
from dataclasses import fields
|
| 38 |
from pathlib import Path
|
| 39 |
-
from typing import List
|
| 40 |
|
| 41 |
import gradio as gr
|
|
|
|
| 42 |
import torch
|
| 43 |
import torch.nn.functional as F
|
| 44 |
|
|
@@ -245,6 +247,117 @@ TOKENIZER = GhostTokenizer()
|
|
| 245 |
END_ID = TOKENIZER._special_tokens[TOKENIZER.END]
|
| 246 |
|
| 247 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
# ---------------------------------------------------------------------------
|
| 249 |
# Chat handler
|
| 250 |
# ---------------------------------------------------------------------------
|
|
@@ -263,6 +376,20 @@ def chat_fn(message: str, history: list, temperature: float, top_k: int,
|
|
| 263 |
# process (e.g. a buggy plugin flipping training mode).
|
| 264 |
MODEL.eval()
|
| 265 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
turns: list = []
|
| 267 |
for h in history:
|
| 268 |
if isinstance(h, dict) and h.get("role") in ("user", "assistant"):
|
|
@@ -356,6 +483,8 @@ rented GPU compute, where literature reports factual recall on cybersec
|
|
| 356 |
MCQ starting to emerge. Spec at
|
| 357 |
[`docs/ghost_base_spec.md`](https://github.com/joemunene-by/GhostLM/blob/main/docs/ghost_base_spec.md).
|
| 358 |
|
|
|
|
|
|
|
| 359 |
**Loaded checkpoint:** `{LOADED_FROM}`
|
| 360 |
"""
|
| 361 |
|
|
|
|
| 32 |
from __future__ import annotations
|
| 33 |
|
| 34 |
import gc
|
| 35 |
+
import json
|
| 36 |
import os
|
| 37 |
import sys
|
| 38 |
from dataclasses import fields
|
| 39 |
from pathlib import Path
|
| 40 |
+
from typing import List, Optional
|
| 41 |
|
| 42 |
import gradio as gr
|
| 43 |
+
import numpy as np
|
| 44 |
import torch
|
| 45 |
import torch.nn.functional as F
|
| 46 |
|
|
|
|
| 247 |
END_ID = TOKENIZER._special_tokens[TOKENIZER.END]
|
| 248 |
|
| 249 |
|
| 250 |
+
# ---------------------------------------------------------------------------
|
| 251 |
+
# RAG (retrieval-augmented generation)
|
| 252 |
+
#
|
| 253 |
+
# The bare 81M chat model hallucinates badly because at this scale it has
|
| 254 |
+
# learned the *register* of cybersec writing but not the *facts* in any
|
| 255 |
+
# retrievable form. RAG closes that gap without retraining: we embed the
|
| 256 |
+
# user's query with a small BGE bi-encoder, look up the top-K most-similar
|
| 257 |
+
# chunks from the pretrain corpus, and inject them as "Reference passages"
|
| 258 |
+
# in front of the question. The model then conditions on real facts
|
| 259 |
+
# instead of producing register-shaped fiction.
|
| 260 |
+
#
|
| 261 |
+
# The index (83K chunks of NVD / MITRE / CWE / OWASP / CTFtime / arXiv at
|
| 262 |
+
# 384-dim BGE embeddings) lives in the Models repo at rag/, alongside the
|
| 263 |
+
# weights. The Space pulls it on first launch via hf_hub_download.
|
| 264 |
+
# ---------------------------------------------------------------------------
|
| 265 |
+
|
| 266 |
+
# Embedding matrix for the retrieval corpus; stays None until _load_rag()
# succeeds, after which it holds the float32 array loaded from rag/index.npy.
RAG_INDEX: Optional[np.ndarray] = None
|
| 267 |
+
# One dict per line of rag/chunks.jsonl; retrieve() indexes into this list
# with row numbers taken from RAG_INDEX.
RAG_CHUNKS: Optional[List[dict]] = None
|
| 268 |
+
# Tokenizer for the BAAI/bge-small-en-v1.5 query embedder (None if RAG is off).
RAG_EMBEDDER_TOK = None
|
| 269 |
+
# The BAAI/bge-small-en-v1.5 model itself, in eval mode (None if RAG is off).
RAG_EMBEDDER = None
|
| 270 |
+
# "ExceptionType: message" from a failed _load_rag(); None when nothing failed.
RAG_LOAD_ERROR: Optional[str] = None
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def _load_rag() -> None:
    """Load the RAG index, chunk records and query embedder into module globals.

    Downloads ``rag/index.npy`` and ``rag/chunks.jsonl`` from ``HUB_REPO``,
    loads the BAAI/bge-small-en-v1.5 tokenizer and model, and only then
    publishes everything to ``RAG_INDEX`` / ``RAG_CHUNKS`` /
    ``RAG_EMBEDDER_TOK`` / ``RAG_EMBEDDER``. Because the globals are assigned
    together at the very end, any failure leaves all of them ``None``.

    Never raises: every exception is caught, stored in ``RAG_LOAD_ERROR`` as
    ``"<ExceptionType>: <message>"`` so the UI can surface it, and logged.
    The chat handler treats RAG as optional, so generation still works
    (bare, without retrieval) when loading fails.
    """
    global RAG_INDEX, RAG_CHUNKS, RAG_EMBEDDER_TOK, RAG_EMBEDDER, RAG_LOAD_ERROR
    embedder_name = "BAAI/bge-small-en-v1.5"
    try:
        # Imported inside the try so a missing optional dependency disables
        # RAG instead of crashing the whole app at import time.
        from huggingface_hub import hf_hub_download

        print(f"Pulling RAG index from {HUB_REPO}...")
        index_path = hf_hub_download(repo_id=HUB_REPO, filename="rag/index.npy", repo_type="model")
        chunks_path = hf_hub_download(repo_id=HUB_REPO, filename="rag/chunks.jsonl", repo_type="model")

        idx = np.load(index_path)
        if idx.ndim != 2:
            raise ValueError(f"RAG index must be 2-D (chunks x dim), got shape {idx.shape}")
        # Index ships as fp16 (halves the upload from 128 MB to 64 MB);
        # upcast for the matmul against the fp32 query vector.
        if idx.dtype != np.float32:
            idx = idx.astype(np.float32)

        # Explicit UTF-8 so decoding does not depend on the host locale;
        # blank lines are skipped rather than aborting the whole load.
        with open(chunks_path, encoding="utf-8") as f:
            chunks: List[dict] = [json.loads(line) for line in f if line.strip()]

        # retrieve() maps index row i to chunks[i]; a length mismatch would
        # surface later as an IndexError or as silently misattributed text.
        if len(chunks) != idx.shape[0]:
            raise ValueError(
                f"RAG index has {idx.shape[0]} rows but chunks.jsonl has {len(chunks)} records"
            )

        from transformers import AutoModel, AutoTokenizer

        e_tok = AutoTokenizer.from_pretrained(embedder_name)
        e_model = AutoModel.from_pretrained(embedder_name).eval()
        if os.environ.get("SPACE_ID"):
            # Same fp16 cast we apply to GhostLM on the Space's tight CPU
            # memory budget: halving the embedder's weights keeps the working
            # set inside the cpu-basic worker ceiling alongside the v0.9
            # weights and the index matrix.
            # NOTE(review): the earlier "130M params at fp32 (~260 MB)" figure
            # was internally inconsistent (130M fp32 params would be ~520 MB);
            # bge-small is believed to be roughly 33M params - confirm
            # against the model card.
            e_model = e_model.half()

        RAG_INDEX = idx
        RAG_CHUNKS = chunks
        RAG_EMBEDDER_TOK = e_tok
        RAG_EMBEDDER = e_model
        print(f"RAG loaded: {len(chunks)} chunks, dim {idx.shape[1]}")
    except Exception as e:  # noqa: BLE001 - RAG is optional; never block startup
        RAG_LOAD_ERROR = f"{type(e).__name__}: {e}"
        print(f"RAG disabled, falling back to bare chat: {RAG_LOAD_ERROR}")
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
# Attempt the RAG load once at import time. _load_rag() catches its own
# exceptions and records them in RAG_LOAD_ERROR, so this call cannot abort
# module import.
_load_rag()
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def retrieve(query: str, k: int = 4) -> List[dict]:
    """Embed ``query`` and return the ``k`` best-scoring chunks, best first.

    Scores are dot products between the L2-normalised query embedding and
    the rows of ``RAG_INDEX``; this equals cosine similarity provided the
    index rows are themselves unit-norm (assumed - TODO confirm against the
    index build script).

    Args:
        query: Raw user question.
        k: Maximum number of chunks to return. Values <= 0 yield ``[]``;
            values larger than the index are clamped.

    Returns:
        Up to ``k`` chunk dicts from ``RAG_CHUNKS``, or an empty list if RAG
        is not loaded; the caller handles that.
    """
    if (RAG_INDEX is None or RAG_CHUNKS is None
            or RAG_EMBEDDER is None or RAG_EMBEDDER_TOK is None):
        return []
    if k <= 0:
        # A negative k would otherwise slice [:k] and return almost the
        # entire corpus.
        return []

    # BGE recommends prefixing queries with this instruction string.
    text = "Represent this sentence for searching relevant passages: " + query
    enc = RAG_EMBEDDER_TOK(text, padding=True, truncation=True,
                           max_length=512, return_tensors="pt")
    with torch.no_grad():
        out = RAG_EMBEDDER(**enc)

    # Take the first token's hidden state as the sentence embedding. Cast to
    # fp32 before normalising: the embedder may be running in fp16, and the
    # index matmul below is fp32 anyway.
    emb = out.last_hidden_state[:, 0].to(torch.float32)
    emb = F.normalize(emb, p=2, dim=-1)
    q_vec = emb.cpu().numpy().reshape(-1)

    scores = RAG_INDEX @ q_vec
    n = scores.shape[0]
    k = min(k, n)
    if k == 0:
        return []

    # Select the k best rows in O(n), then order just those k by score,
    # instead of fully sorting every row of the index.
    if k < n:
        candidates = np.argpartition(-scores, k - 1)[:k]
    else:
        candidates = np.arange(n)
    top = candidates[np.argsort(-scores[candidates])]
    return [RAG_CHUNKS[int(i)] for i in top]
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def format_rag_prompt(query: str, passages: List[dict], max_chars: int = 400) -> str:
    """Wrap ``query`` with retrieved reference passages.

    The model is not RAFT-trained yet, so it just sees the result as part of
    the user message.

    Args:
        query: The user's original question.
        passages: Chunk dicts; the keys read are ``text``, ``source`` and
            ``ref``. Passages whose ``text`` is missing, ``None`` or blank
            are skipped so they do not produce empty reference lines.
        max_chars: Per-passage character budget. Longer text is cut at
            ``max_chars``, trimmed back to the previous space, and suffixed
            with ``...``.

    Returns:
        ``query`` unchanged when no passage has usable text; otherwise a
        prompt consisting of a header, the numbered references, an
        instruction to rely on them, and ``Question: <query>``.
    """
    refs = []
    for p in passages or []:
        text = p.get("text") or ""
        if not text.strip():
            continue
        if len(text) > max_chars:
            text = text[:max_chars].rsplit(" ", 1)[0] + "..."
        # Build the "(source ref)" label from the parts that exist so a
        # missing ref does not leave a stray space inside the parentheses.
        source = p.get("source") or "?"
        ref = p.get("ref") or ""
        label = " ".join(str(part) for part in (source, ref) if part)
        refs.append(f"[{len(refs) + 1}] ({label}) {text}")

    if not refs:
        return query

    refs_block = "\n\n".join(refs)
    return (
        "Reference passages from the cybersecurity corpus:\n\n"
        f"{refs_block}\n\n"
        "Use the reference passages above to answer the question. If the "
        "passages don't contain the answer, say so rather than guessing.\n\n"
        f"Question: {query}"
    )
|
| 359 |
+
|
| 360 |
+
|
| 361 |
# ---------------------------------------------------------------------------
|
| 362 |
# Chat handler
|
| 363 |
# ---------------------------------------------------------------------------
|
|
|
|
| 376 |
# process (e.g. a buggy plugin flipping training mode).
|
| 377 |
MODEL.eval()
|
| 378 |
|
| 379 |
+
# If RAG loaded successfully at startup, retrieve top-K relevant
|
| 380 |
+
# passages from the indexed corpus and prepend them to the user's
|
| 381 |
+
# message before building the chat prompt. The retrieval cost is
|
| 382 |
+
# ~1-2 s on cpu-basic; activated by default whenever RAG is
|
| 383 |
+
# available because it's the difference between "register-shaped
|
| 384 |
+
# fiction" and "answers grounded in real CVE/MITRE/CWE text".
|
| 385 |
+
if RAG_INDEX is not None:
|
| 386 |
+
try:
|
| 387 |
+
passages = retrieve(message, k=4)
|
| 388 |
+
if passages:
|
| 389 |
+
message = format_rag_prompt(message, passages)
|
| 390 |
+
except Exception as e: # noqa: BLE001 - never break chat for retrieval issues
|
| 391 |
+
print(f"RAG retrieve failed for this turn: {type(e).__name__}: {e}")
|
| 392 |
+
|
| 393 |
turns: list = []
|
| 394 |
for h in history:
|
| 395 |
if isinstance(h, dict) and h.get("role") in ("user", "assistant"):
|
|
|
|
| 483 |
MCQ starting to emerge. Spec at
|
| 484 |
[`docs/ghost_base_spec.md`](https://github.com/joemunene-by/GhostLM/blob/main/docs/ghost_base_spec.md).
|
| 485 |
|
| 486 |
+
**Retrieval-augmented mode:** {("**ON**. Each query is augmented with top-4 passages retrieved from a 83K-chunk index of the cybersec corpus (NVD / MITRE / CWE / OWASP / CTFtime / arXiv). The model conditions on real reference text instead of producing register-shaped fiction. Retrieval adds ~1-2 s per reply." if RAG_INDEX is not None else f"**OFF**. RAG could not load at startup (`{RAG_LOAD_ERROR}`). Generation is bare; expect hallucination on factual questions.")}
|
| 487 |
+
|
| 488 |
**Loaded checkpoint:** `{LOADED_FROM}`
|
| 489 |
"""
|
| 490 |
|
|
@@ -14,12 +14,24 @@ torch>=2.0.0
|
|
| 14 |
# tiktoken is the GPT-2 BPE backend the GhostTokenizer wraps.
|
| 15 |
tiktoken>=0.5.0
|
| 16 |
|
| 17 |
-
# huggingface_hub for hf_hub_download. The v0.9 chat weights
|
| 18 |
-
# in the Models repo Ghostgim/GhostLM-v0.9-experimental
|
| 19 |
-
# Space's own LFS) so the Space stays within the 1 GB
|
| 20 |
-
# Pulled and cached on first launch.
|
| 21 |
huggingface_hub>=0.20
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
# Python 3.13 removed the stdlib audioop module that gradio's transitive
|
| 24 |
# pydub dep imports at module-load time. Without this the entire gradio
|
| 25 |
# import chain fails with ModuleNotFoundError: No module named
|
|
|
|
| 14 |
# tiktoken is the GPT-2 BPE backend the GhostTokenizer wraps.
|
| 15 |
tiktoken>=0.5.0
|
| 16 |
|
| 17 |
+
# huggingface_hub for hf_hub_download. The v0.9 chat weights and the
|
| 18 |
+
# RAG index both live in the Models repo Ghostgim/GhostLM-v0.9-experimental
|
| 19 |
+
# (not in the Space's own LFS) so the Space stays within the 1 GB
|
| 20 |
+
# free-tier cap. Pulled and cached on first launch.
|
| 21 |
huggingface_hub>=0.20
|
| 22 |
|
| 23 |
+
# transformers + sentencepiece for the BGE-small-en-v1.5 retrieval
|
| 24 |
+
# embedder used by the RAG layer. transformers is already pulled in
|
| 25 |
+
# transitively but listing it explicitly pins the version range we test
|
| 26 |
+
# against. sentencepiece is BGE's tokenizer backend.
|
| 27 |
+
transformers>=4.38
|
| 28 |
+
sentencepiece>=0.1.99
|
| 29 |
+
|
| 30 |
+
# numpy for the in-memory retrieval index (83K x 384 float32 matrix).
|
| 31 |
+
# Already a transitive dep of torch but listed explicitly because the
|
| 32 |
+
# RAG code path imports it directly.
|
| 33 |
+
numpy>=1.24
|
| 34 |
+
|
| 35 |
# Python 3.13 removed the stdlib audioop module that gradio's transitive
|
| 36 |
# pydub dep imports at module-load time. Without this the entire gradio
|
| 37 |
# import chain fails with ModuleNotFoundError: No module named
|