Ghostgim committed
Commit ae0fa76 · verified
1 Parent(s): 551cb99

feat(rag): wire retrieval-augmented chat into chat_fn

RAG (retrieval-augmented generation) is now the default chat mode on the
Space when the index loads. At startup the Space pulls 83K BGE-embedded
passages from the Ghostgim/GhostLM-v0.9-experimental Models repo. Each
user query is embedded with BAAI/bge-small-en-v1.5, the top-4 passages
by cosine similarity are selected, and they are prepended as 'Reference
passages' in front of the question. The model is not RAFT-trained yet,
so it sees the retrieved context as part of the user message, but even
without retrieval-aware training this cuts the hallucination the bare
81M model exhibits.
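
The top-4 lookup described above can be sketched on toy data. This is a minimal illustration, not the code in app.py: the helper name `top_k_cosine` and the 3-dim vectors are hypothetical, and it assumes every index row and the query are already L2-normalized, so a dot product equals cosine similarity.

```python
import numpy as np


def top_k_cosine(index: np.ndarray, query: np.ndarray, k: int = 4) -> np.ndarray:
    """Return the row indices of the k most similar rows, best first.

    Assumes the rows of `index` and the `query` vector are unit length,
    so the matrix-vector product below yields cosine similarities.
    """
    scores = index.astype(np.float32) @ query.astype(np.float32)
    k = min(k, scores.shape[0])
    if k <= 0:
        return np.empty(0, dtype=np.int64)
    # argpartition selects the k best candidates without a full sort;
    # only those k candidates are then ordered by score.
    candidates = np.argpartition(-scores, k - 1)[:k]
    return candidates[np.argsort(-scores[candidates])]


# Hypothetical 3-dim "index" stored as fp16, rows already unit length.
toy_index = np.array(
    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.6, 0.8, 0.0], [0.0, 0.0, 1.0]],
    dtype=np.float16,
)
toy_query = np.array([1.0, 0.0, 0.0], dtype=np.float32)
top = top_k_cosine(toy_index, toy_query, k=2)
print(top.tolist())
```

At 83K rows a full `np.argsort` is also cheap; the partition step is shown only as one possible design choice.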

RAG load is wrapped in try/except: if the index file isn't in the Models
repo (e.g. upload still in progress), the Space falls back to bare chat
with an honest 'RAG: OFF' note in the description so the user knows
what they're getting. No silent failure modes.
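
The fallback pattern in the paragraph above can be sketched in isolation. The names `load_optional`, `status_note` and `missing_index` are hypothetical and exist only for this illustration; the real loader in app.py stores its state in module globals instead of returning a tuple.

```python
from typing import Callable, Optional, Tuple


def load_optional(loader: Callable[[], object]) -> Tuple[Optional[object], Optional[str]]:
    """Run `loader`; on any failure return (None, error message) instead of raising."""
    try:
        return loader(), None
    except Exception as exc:  # the optional feature must never break startup
        return None, f"{type(exc).__name__}: {exc}"


def status_note(resource: Optional[object], error: Optional[str]) -> str:
    """Build the user-facing note so a disabled feature is never silent."""
    if resource is not None:
        return "RAG: ON"
    return f"RAG: OFF ({error})"


def missing_index() -> object:
    # Stand-in for a download that fails because the file is not uploaded yet.
    raise FileNotFoundError("rag/index.npy")


index, err = load_optional(missing_index)
note = status_note(index, err)
print(note)
```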

Adds requirements: numpy >= 1.24, transformers >= 4.38, sentencepiece >=
0.1.99 (RAG deps). All small; no LFS budget impact on the Space.

Index files (rag/index.npy fp16 64MB + rag/chunks.jsonl 57MB + rag/meta.json)
land in the Models repo separately when bandwidth cooperates; once they
do, restart_space() picks them up without a Space code change.
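
As a sanity check on the quoted index size, the arithmetic can be worked through directly. This sketch assumes "83K" means exactly 83,000 rows (the commit gives only the rounded figure) and uses the 384-dim embedding width stated in the app.py comments; the small .npy header is ignored.

```python
N_CHUNKS = 83_000  # assumption: "83K" read as exactly 83,000 rows
DIM = 384          # embedding width stated in the app.py comments


def index_bytes(rows: int, dim: int, bytes_per_value: int) -> int:
    """Raw size of a dense rows x dim matrix, ignoring the .npy header."""
    return rows * dim * bytes_per_value


fp16_mb = index_bytes(N_CHUNKS, DIM, 2) / 1e6  # on-disk format
fp32_mb = index_bytes(N_CHUNKS, DIM, 4) / 1e6  # in-memory after the upcast
print(f"fp16: {fp16_mb:.1f} MB, fp32: {fp32_mb:.1f} MB")
```

Under these assumptions the figures come out near the 64 MB and 128 MB quoted in the commit.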

Files changed (2)
  1. app.py +130 -1
  2. requirements.txt +16 -4
app.py CHANGED
@@ -32,13 +32,15 @@ the default 200-token cap on the 81M model.
 from __future__ import annotations
 
 import gc
+import json
 import os
 import sys
 from dataclasses import fields
 from pathlib import Path
-from typing import List
+from typing import List, Optional
 
 import gradio as gr
+import numpy as np
 import torch
 import torch.nn.functional as F
 
@@ -245,6 +247,117 @@ TOKENIZER = GhostTokenizer()
 END_ID = TOKENIZER._special_tokens[TOKENIZER.END]
 
 
+# ---------------------------------------------------------------------------
+# RAG (retrieval-augmented generation)
+#
+# The bare 81M chat model hallucinates badly because at this scale it has
+# learned the *register* of cybersec writing but not the *facts* in any
+# retrievable form. RAG closes that gap without retraining: we embed the
+# user's query with a small BGE bi-encoder, look up the top-K most-similar
+# chunks from the pretrain corpus, and inject them as "Reference passages"
+# in front of the question. The model then conditions on real facts
+# instead of producing register-shaped fiction.
+#
+# The index (83K chunks of NVD / MITRE / CWE / OWASP / CTFtime / arXiv at
+# 384-dim BGE embeddings) lives in the Models repo at rag/, alongside the
+# weights. The Space pulls it on first launch via hf_hub_download.
+# ---------------------------------------------------------------------------
+
+RAG_INDEX: Optional[np.ndarray] = None
+RAG_CHUNKS: Optional[List[dict]] = None
+RAG_EMBEDDER_TOK = None
+RAG_EMBEDDER = None
+RAG_LOAD_ERROR: Optional[str] = None
+
+
+def _load_rag() -> None:
+    """Load RAG index + embedder. On any failure leaves everything None
+    and stores the error message so the UI can surface it. The chat
+    handler treats RAG as optional: if it didn't load, generation still
+    works, just bare without retrieval."""
+    global RAG_INDEX, RAG_CHUNKS, RAG_EMBEDDER_TOK, RAG_EMBEDDER, RAG_LOAD_ERROR
+    try:
+        from huggingface_hub import hf_hub_download
+        print(f"Pulling RAG index from {HUB_REPO}...")
+        index_path = hf_hub_download(repo_id=HUB_REPO, filename="rag/index.npy", repo_type="model")
+        chunks_path = hf_hub_download(repo_id=HUB_REPO, filename="rag/chunks.jsonl", repo_type="model")
+
+        idx = np.load(index_path)
+        # Index ships as fp16 (halves the upload from 128 MB to 64 MB);
+        # upcast for the matmul against the fp32 query vector.
+        if idx.dtype != np.float32:
+            idx = idx.astype(np.float32)
+        chunks: List[dict] = []
+        with open(chunks_path) as f:
+            for line in f:
+                chunks.append(json.loads(line))
+
+        from transformers import AutoModel, AutoTokenizer
+        e_tok = AutoTokenizer.from_pretrained("BAAI/bge-small-en-v1.5")
+        e_model = AutoModel.from_pretrained("BAAI/bge-small-en-v1.5").eval()
+        if os.environ.get("SPACE_ID"):
+            # Same fp16 cast we apply to GhostLM on the Space's tight CPU
+            # memory budget. BGE-small is ~33M params (~130 MB at fp32);
+            # halving keeps the working set inside the cpu-basic worker
+            # ceiling alongside the v0.9 weights and the index matrix.
+            e_model = e_model.half()
+
+        RAG_INDEX = idx
+        RAG_CHUNKS = chunks
+        RAG_EMBEDDER_TOK = e_tok
+        RAG_EMBEDDER = e_model
+        print(f"RAG loaded: {len(chunks)} chunks, dim {idx.shape[1]}")
+    except Exception as e:
+        RAG_LOAD_ERROR = f"{type(e).__name__}: {e}"
+        print(f"RAG disabled, falling back to bare chat: {RAG_LOAD_ERROR}")
+
+
+_load_rag()
+
+
+def retrieve(query: str, k: int = 4) -> List[dict]:
+    """Embed the query and return the top-K chunks by cosine similarity.
+    Returns an empty list if RAG isn't loaded; caller handles that."""
+    if RAG_INDEX is None or RAG_EMBEDDER is None or RAG_EMBEDDER_TOK is None:
+        return []
+    # BGE recommends prefixing queries with this instruction string.
+    text = "Represent this sentence for searching relevant passages: " + query
+    enc = RAG_EMBEDDER_TOK(text, padding=True, truncation=True,
+                           max_length=512, return_tensors="pt")
+    with torch.no_grad():
+        out = RAG_EMBEDDER(**enc)
+        emb = out.last_hidden_state[:, 0]
+        emb = F.normalize(emb, p=2, dim=-1)
+    q_vec = emb.cpu().to(torch.float32).numpy().reshape(-1)
+    scores = RAG_INDEX @ q_vec
+    top = np.argsort(-scores)[:k]
+    return [RAG_CHUNKS[i] for i in top]
+
+
+def format_rag_prompt(query: str, passages: List[dict]) -> str:
+    """Wrap the query with retrieved reference passages. The model is
+    not RAFT-trained yet so it just sees this as part of the user
+    message; even without a RAFT pass, retrieval-augmented chat
+    dramatically reduces the bare 81M model's hallucination rate on
+    factual cybersec questions."""
+    if not passages:
+        return query
+    refs = []
+    for i, p in enumerate(passages):
+        text = p.get("text", "")
+        if len(text) > 400:
+            text = text[:400].rsplit(" ", 1)[0] + "..."
+        refs.append(f"[{i + 1}] ({p.get('source', '?')} {p.get('ref', '')}) {text}")
+    refs_block = "\n\n".join(refs)
+    return (
+        "Reference passages from the cybersecurity corpus:\n\n"
+        f"{refs_block}\n\n"
+        "Use the reference passages above to answer the question. If the "
+        "passages don't contain the answer, say so rather than guessing.\n\n"
+        f"Question: {query}"
+    )
+
+
 # ---------------------------------------------------------------------------
 # Chat handler
 # ---------------------------------------------------------------------------
@@ -263,6 +376,20 @@ def chat_fn(message: str, history: list, temperature: float, top_k: int,
     # process (e.g. a buggy plugin flipping training mode).
     MODEL.eval()
 
+    # If RAG loaded successfully at startup, retrieve top-K relevant
+    # passages from the indexed corpus and prepend them to the user's
+    # message before building the chat prompt. The retrieval cost is
+    # ~1-2 s on cpu-basic; activated by default whenever RAG is
+    # available because it's the difference between "register-shaped
+    # fiction" and "answers grounded in real CVE/MITRE/CWE text".
+    if RAG_INDEX is not None:
+        try:
+            passages = retrieve(message, k=4)
+            if passages:
+                message = format_rag_prompt(message, passages)
+        except Exception as e:  # noqa: BLE001 - never break chat for retrieval issues
+            print(f"RAG retrieve failed for this turn: {type(e).__name__}: {e}")
+
     turns: list = []
     for h in history:
         if isinstance(h, dict) and h.get("role") in ("user", "assistant"):
@@ -356,6 +483,8 @@ rented GPU compute, where literature reports factual recall on cybersec
 MCQ starting to emerge. Spec at
 [`docs/ghost_base_spec.md`](https://github.com/joemunene-by/GhostLM/blob/main/docs/ghost_base_spec.md).
 
+**Retrieval-augmented mode:** {("**ON**. Each query is augmented with top-4 passages retrieved from an 83K-chunk index of the cybersec corpus (NVD / MITRE / CWE / OWASP / CTFtime / arXiv). The model conditions on real reference text instead of producing register-shaped fiction. Retrieval adds ~1-2 s per reply." if RAG_INDEX is not None else f"**OFF**. RAG could not load at startup (`{RAG_LOAD_ERROR}`). Generation is bare; expect hallucination on factual questions.")}
+
 **Loaded checkpoint:** `{LOADED_FROM}`
 """
 
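
The passage clipping inside `format_rag_prompt` is easy to misread, so here is the same expression isolated as a sketch. The helper name `truncate_at_word` is hypothetical; it mirrors the `text[:400].rsplit(" ", 1)[0] + "..."` expression from the diff with the limit made a parameter.

```python
def truncate_at_word(text: str, limit: int = 400) -> str:
    """Clip `text` to at most `limit` characters, backing up to the last space.

    If the first `limit` characters contain no space, rsplit returns them
    unchanged, so the result is simply the hard-cut prefix plus "...".
    """
    if len(text) <= limit:
        return text
    return text[:limit].rsplit(" ", 1)[0] + "..."


short = truncate_at_word("buffer overflow in parser")
clipped = truncate_at_word("stack based buffer overflow", limit=14)
no_space = truncate_at_word("A" * 10, limit=4)
print(short, clipped, no_space, sep="\n")
```

Cutting at a word boundary keeps a half-word from reaching the model, at the cost of occasionally dropping a few extra characters.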
requirements.txt CHANGED
@@ -14,12 +14,24 @@ torch>=2.0.0
 # tiktoken is the GPT-2 BPE backend the GhostTokenizer wraps.
 tiktoken>=0.5.0
 
-# huggingface_hub for hf_hub_download. The v0.9 chat weights are stored
-# in the Models repo Ghostgim/GhostLM-v0.9-experimental (not in the
-# Space's own LFS) so the Space stays within the 1 GB free-tier cap.
-# Pulled and cached on first launch.
+# huggingface_hub for hf_hub_download. The v0.9 chat weights and the
+# RAG index both live in the Models repo Ghostgim/GhostLM-v0.9-experimental
+# (not in the Space's own LFS) so the Space stays within the 1 GB
+# free-tier cap. Pulled and cached on first launch.
 huggingface_hub>=0.20
 
+# transformers + sentencepiece for the BGE-small-en-v1.5 retrieval
+# embedder used by the RAG layer. transformers is already pulled in
+# transitively but listing it explicitly pins the version range we test
+# against. sentencepiece is listed defensively; BGE itself uses WordPiece.
+transformers>=4.38
+sentencepiece>=0.1.99
+
+# numpy for the in-memory retrieval index (83K x 384 float32 matrix).
+# Already a transitive dep of torch but listed explicitly because the
+# RAG code path imports it directly.
+numpy>=1.24
+
 # Python 3.13 removed the stdlib audioop module that gradio's transitive
 # pydub dep imports at module-load time. Without this the entire gradio
 # import chain fails with ModuleNotFoundError: No module named