| """GhostLM Gradio Space, multi-turn chat for the v0.9 chat (81M wide) model. |
| |
| The canonical chat model on the Space is now ``phase19_chat_v09`` (v0.9 |
| chat, 81M params, wide v0.7 architecture, pretrained on the 273M-token |
| PRIMUS + CWE + OWASP + RFC + fact-QA corpus and chat-tuned with the |
| chat-v3 SFT recipe). It is the ghost-small bench winner: 28.9% on |
| debiased CTIBench full (n=2500), 59.2% on the in-repo CTF MCQ eval, and |
| 39.3% on SecQA. Free-form fact recall is at floor across the entire |
| ghost-small line by design: at 81M params the model has the register of |
| cybersec writing but not the facts. The next rung (ghost-base ~360M) is |
| gated on GPU compute. |
| |
| The v0.9 chat weights (~324 MB slim) live in the Hub model repo |
| ``Ghostgim/GhostLM-v0.9-experimental`` rather than in the Space's own |
| LFS. The Space pulls them with ``huggingface_hub.hf_hub_download`` on |
| first launch and caches them locally; this keeps the Space well within |
| HF's 1 GB free-tier LFS budget. The previous Space checkpoint |
| (v0.5.0 chat-v3 on the v0.4 base, 45M, 36.9% single-order CTIBench) was |
| removed; it remains in the GitHub repo's checkpoint history at |
| ``checkpoints/phase5_chat_v3/best_model.pt``. |
| |
| Multi-turn chat using the tokenizer's three role tokens |
| (<|ghost_user|>, <|ghost_assistant|>, <|ghost_end|>). Generation stops the |
| moment the assistant's <|ghost_end|> is sampled. Repetition penalty is on |
| by default. Without it the small model occasionally degenerates into |
| "Wifi Wifi Wifi" loops on short prompts. |
| |
| Runs on Spaces cpu-basic (2 vCPU). Generation is ~10-25 s per reply at |
| the default 200-token cap on the 81M model. |
| """ |
|
|
from __future__ import annotations

import gc
import json
import os
import sys
from dataclasses import fields
from pathlib import Path
from typing import Iterator, List, Optional

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
|
|
REPO_ROOT = Path(__file__).resolve().parent
# Put this file's directory at the front of sys.path (only once) so the
# ``ghostlm`` package can be imported from here regardless of the working
# directory the app was launched from.
if not any(entry == str(REPO_ROOT) for entry in sys.path):
    sys.path.insert(0, str(REPO_ROOT))
|
|
| from ghostlm.config import GhostLMConfig |
| from ghostlm.model import GhostLM |
| from ghostlm.tokenizer import GhostTokenizer |
|
|
|
|
| |
| |
| |
|
|
HUB_REPO = "Ghostgim/GhostLM-v0.9-experimental"
HUB_FILENAME = "best_model.pt"

# Local checkpoints, most preferred first. Each is a relative path; see
# find_checkpoint() for which base directories are tried.
CHECKPOINT_CANDIDATES = [
    "checkpoints/phase19_chat_v09/best_model.pt",
    "checkpoints/best_model.pt",
]


def find_checkpoint() -> str:
    """Return a usable checkpoint path, or ``""`` if none can be found.

    Local paths win first so local dev doesn't need network. Each entry of
    ``CHECKPOINT_CANDIDATES`` is tried relative to the current working
    directory and then relative to ``REPO_ROOT`` (this file's directory), so
    launching the app from another directory still finds a local checkpoint
    instead of silently downloading one. If none exist (the normal case on
    the Space), the v0.9 weights are pulled from the Hub model repo and the
    cached local path is returned.

    Any failure of the Hub fallback (``huggingface_hub`` missing, no network,
    missing file) is logged and reported as ``""``; ``load_model`` treats an
    empty path as "no weights".
    """
    for rel in CHECKPOINT_CANDIDATES:
        for candidate in (Path(rel), REPO_ROOT / rel):
            if candidate.exists():
                return str(candidate)
    try:
        from huggingface_hub import hf_hub_download
        print(f"Local checkpoint missing; downloading {HUB_REPO}/{HUB_FILENAME} from the Hub...")
        return hf_hub_download(repo_id=HUB_REPO, filename=HUB_FILENAME, repo_type="model")
    except Exception as e:
        # Deliberately best-effort: the app still boots on a random model.
        print(f"Hub fallback also failed: {type(e).__name__}: {e}")
        return ""
|
|
|
|
def load_model(path: str):
    """Load a GhostLM checkpoint into eval mode on CPU.

    Args:
        path: checkpoint file. An empty string means "no weights available"
            and yields a randomly initialised ghost-tiny model so the UI can
            still start.

    Returns:
        ``(model, config, label)``. ``label`` is ``path`` itself, or a
        placeholder string describing the random fallback model.

    Raises:
        KeyError: if the checkpoint has no ``config`` entry, or neither a
            ``model_state_dict`` nor a ``model`` entry.
    """
    if not path:
        # No weights: random ghost-tiny placeholder so the Space still boots.
        config = GhostLMConfig.from_preset("ghost-tiny")
        config.vocab_size = 50264
        config.context_length = 256
        model = GhostLM(config).eval()
        return model, config, "(random ghost-tiny, weights missing on Space)"

    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects, i.e. run code embedded in the file. That is only acceptable
    # because the path is a local checkpoint or a download from this
    # project's own Hub repo; never point this at untrusted files.
    ckpt = torch.load(path, map_location="cpu", weights_only=False)
    saved = ckpt["config"]
    # Keep only keys the current GhostLMConfig defines, so checkpoints saved
    # with extra/legacy config fields still load.
    config = GhostLMConfig(**{
        f.name: saved[f.name]
        for f in fields(GhostLMConfig)
        if f.name in saved
    })
    model = GhostLM(config)

    state = ckpt.get("model_state_dict", ckpt.get("model"))
    if state is None:
        raise KeyError(f"{path}: checkpoint has neither 'model_state_dict' nor 'model'")
    # strict=False tolerates small key drift, but a wholesale mismatch would
    # silently leave random weights in place; report what didn't line up.
    result = model.load_state_dict(state, strict=False)
    if result.missing_keys:
        print(f"load_model: {len(result.missing_keys)} model keys not in checkpoint, "
              f"e.g. {result.missing_keys[:5]}")
    if result.unexpected_keys:
        print(f"load_model: {len(result.unexpected_keys)} checkpoint keys unused, "
              f"e.g. {result.unexpected_keys[:5]}")
    # load_state_dict copied the tensors into the model; release the
    # checkpoint's own copy now instead of holding it until return.
    del ckpt, state
    model.eval()

    # On the Space (SPACE_ID is presumably set by the HF Spaces runtime) keep
    # the weights in fp16.
    if os.environ.get("SPACE_ID"):
        model = model.half()

    return model, config, path
|
|
|
|
| |
| |
| |
|
|
|
|
def sample_next(
    logits: torch.Tensor,
    *,
    temperature: float,
    top_k: int,
    top_p: float,
    prev_ids: List[int],
    repetition_penalty: float,
) -> int:
    """Sample one token id from a 1-D logits vector.

    Applies, in order: a repetition penalty over ``prev_ids`` (positive
    logits divided by the penalty, non-positive ones multiplied), temperature
    scaling, top-k filtering and nucleus (top-p) filtering, then draws with
    ``torch.multinomial``.

    Work is done on a float32 copy, so the caller's tensor is never modified
    and half-precision logits (the model is converted to fp16 on the Space)
    get a full-precision softmax / cumsum.

    Args:
        logits: 1-D tensor of unnormalised scores over the vocabulary.
        temperature: logits are divided by ``max(temperature, 1e-6)``.
        top_k: keep only the ``top_k`` highest logits; 0/None disables.
        top_p: keep the smallest set of highest-probability tokens whose
            total probability reaches ``top_p``; 0/None or >= 1.0 disables.
        prev_ids: recently generated ids to penalise.
        repetition_penalty: 1.0 disables the penalty.

    Returns:
        The sampled token id as a Python int.
    """
    logits = logits.to(dtype=torch.float32, copy=True)
    if prev_ids and repetition_penalty != 1.0:
        seen = torch.tensor(sorted(set(prev_ids)), dtype=torch.long, device=logits.device)
        picked = logits[seen]
        logits[seen] = torch.where(
            picked > 0, picked / repetition_penalty, picked * repetition_penalty
        )
    logits = logits / max(temperature, 1e-6)
    if top_k and top_k > 0:
        kth_best = torch.topk(logits, min(top_k, logits.size(-1))).values[..., -1:]
        logits[logits < kth_best] = float("-inf")
    if top_p and top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        # Probability mass of the tokens ranked strictly above each token.
        # A token is dropped only once that mass already reaches top_p, so the
        # token that crosses the threshold stays inside the nucleus.
        mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
        cutoff = mass_before >= top_p
        cutoff[..., 0] = False  # always keep the single most likely token
        sorted_logits[cutoff] = float("-inf")
        logits = torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)
    probs = F.softmax(logits, dim=-1)
    return int(torch.multinomial(probs, num_samples=1).item())
|
|
|
|
def generate_until_end(
    model,
    prompt_ids: List[int],
    *,
    end_id: int,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
    repetition_penalty: float,
) -> List[int]:
    """Sample a complete reply and return its new token ids.

    Generation is always sampled through ``sample_next``; there is no
    separate greedy path (``top_k=1`` makes it effectively greedy). It stops
    as soon as ``end_id`` is sampled (``end_id`` is not included) or after
    ``max_new_tokens`` tokens.

    This is a non-streaming wrapper that drains
    ``generate_until_end_stream``, so the sampling loop exists in exactly one
    place and the two entry points cannot drift apart.
    """
    new_ids: List[int] = []
    for new_ids in generate_until_end_stream(
        model,
        prompt_ids,
        end_id=end_id,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    ):
        pass  # every yield is the same growing list; keep the final state
    return new_ids
|
|
|
|
def generate_until_end_stream(
    model,
    prompt_ids: List[int],
    *,
    end_id: int,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
    repetition_penalty: float,
) -> Iterator[List[int]]:
    """Streaming variant of ``generate_until_end``: yields the growing list
    of new token ids after every sampled token.

    Used by Gradio's chat interface so the user sees text appear
    incrementally rather than waiting for the full response. The yields add
    no forward-pass cost. Each yield is the *same* list object, extended in
    place; copy it if you need a snapshot. ``end_id`` is never yielded; the
    generator simply ends when it is sampled or after ``max_new_tokens``
    tokens. The model only ever sees the trailing
    ``model.config.context_length`` tokens, so longer inputs are truncated
    from the front.
    """
    ids = torch.tensor(prompt_ids, dtype=torch.long).unsqueeze(0)
    new_ids: List[int] = []
    ctx = model.config.context_length
    for _ in range(max_new_tokens):
        # no_grad is scoped to the forward pass on purpose. A ``with`` block
        # spanning ``yield`` would leave grad disabled in the consumer between
        # steps. Because grad mode is thread-local, it would also not apply
        # at all if the consumer resumes this generator from a different
        # thread (presumably what a thread-pool-backed UI does; verify).
        with torch.no_grad():
            logits, _aux = model(ids[:, -ctx:])
        next_logits = logits[:, -1, :].squeeze(0).clone()
        tok = sample_next(
            next_logits,
            temperature=temperature, top_k=top_k, top_p=top_p,
            prev_ids=new_ids[-128:], repetition_penalty=repetition_penalty,
        )
        if tok == end_id:
            break
        new_ids.append(tok)
        ids = torch.cat([ids, torch.tensor([[tok]])], dim=1)
        yield new_ids
|
|
|
|
| |
| |
| |
|
|
# Import-time wiring. find_checkpoint() returns a local checkpoint path, a
# Hub-cached download, or "" (which load_model() turns into a randomly
# initialised ghost-tiny so the UI still starts). LOADED_FROM is the label
# shown under "Loaded checkpoint" in the UI description.
CHECKPOINT_PATH = find_checkpoint()
MODEL, CONFIG, LOADED_FROM = load_model(CHECKPOINT_PATH)
TOKENIZER = GhostTokenizer()
# Id of the turn-terminator token; generation stops as soon as it is sampled.
# NOTE(review): reads the tokenizer's private ``_special_tokens`` map; a
# public accessor on GhostTokenizer would be sturdier, if one exists.
END_ID = TOKENIZER._special_tokens[TOKENIZER.END]
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
# Retrieval state, populated by _load_rag() at import time. If loading fails,
# everything except RAG_LOAD_ERROR stays None and chat runs without retrieval.
RAG_INDEX: Optional[np.ndarray] = None  # float32 matrix, one row per chunk
RAG_CHUNKS: Optional[List[dict]] = None  # chunk records, row-aligned with RAG_INDEX
RAG_EMBEDDER_TOK = None  # tokenizer for the query embedder
RAG_EMBEDDER = None  # query embedding model
RAG_LOAD_ERROR: Optional[str] = None  # "<ExcType>: <msg>" when loading failed


def _load_rag() -> None:
    """Load the RAG index, chunk records and query embedder.

    On success, sets RAG_INDEX, RAG_CHUNKS, RAG_EMBEDDER_TOK and RAG_EMBEDDER
    together. On any failure, leaves them all None and stores
    ``"<ExcType>: <msg>"`` in RAG_LOAD_ERROR so the UI can surface it. The
    chat handler treats RAG as optional: if it didn't load, generation still
    works, just bare without retrieval.
    """
    global RAG_INDEX, RAG_CHUNKS, RAG_EMBEDDER_TOK, RAG_EMBEDDER, RAG_LOAD_ERROR
    try:
        from huggingface_hub import hf_hub_download
        print(f"Pulling RAG index from {HUB_REPO}...")
        index_path = hf_hub_download(repo_id=HUB_REPO, filename="rag/index.npy", repo_type="model")
        chunks_path = hf_hub_download(repo_id=HUB_REPO, filename="rag/chunks.jsonl", repo_type="model")

        idx = np.load(index_path)
        # retrieve() scores against a float32 query vector.
        if idx.dtype != np.float32:
            idx = idx.astype(np.float32)
        # JSONL is UTF-8; don't depend on the platform's locale encoding.
        # Blank lines (e.g. a stray trailing newline) are skipped rather than
        # failing json.loads and disabling RAG altogether.
        with open(chunks_path, encoding="utf-8") as f:
            chunks: List[dict] = [json.loads(line) for line in f if line.strip()]
        # retrieve() maps index rows to chunks by position; a mismatch would
        # return wrong passages or raise IndexError at query time.
        if idx.ndim != 2 or idx.shape[0] != len(chunks):
            raise ValueError(
                f"RAG index shape {idx.shape} does not match {len(chunks)} chunks"
            )

        from transformers import AutoModel, AutoTokenizer
        e_tok = AutoTokenizer.from_pretrained("BAAI/bge-small-en-v1.5")
        e_model = AutoModel.from_pretrained("BAAI/bge-small-en-v1.5").eval()
        if os.environ.get("SPACE_ID"):
            # Same policy as the chat model: fp16 weights on the Space.
            e_model = e_model.half()

        # Publish only after every step has succeeded.
        RAG_INDEX = idx
        RAG_CHUNKS = chunks
        RAG_EMBEDDER_TOK = e_tok
        RAG_EMBEDDER = e_model
        print(f"RAG loaded: {len(chunks)} chunks, dim {idx.shape[1]}")
    except Exception as e:
        # Deliberately best-effort: record why and continue without RAG.
        RAG_LOAD_ERROR = f"{type(e).__name__}: {e}"
        print(f"RAG disabled, falling back to bare chat: {RAG_LOAD_ERROR}")


_load_rag()
|
|
|
|
def retrieve(query: str, k: int = 4) -> List[dict]:
    """Embed ``query`` and return up to ``k`` chunks, highest score first.

    Scores are dot products between the L2-normalised query embedding (first
    token of the embedder's last hidden state) and the rows of RAG_INDEX.
    That is cosine similarity provided the index rows were normalised when
    the index was built (assumed; not checked here). Returns an empty list
    if RAG isn't loaded or ``k <= 0``; the caller handles that.
    """
    if (RAG_INDEX is None or RAG_CHUNKS is None
            or RAG_EMBEDDER is None or RAG_EMBEDDER_TOK is None):
        return []
    k = min(int(k), RAG_INDEX.shape[0])
    if k <= 0:
        return []

    # Query-side instruction prefix used by the bge-*-v1.5 retrieval models.
    text = "Represent this sentence for searching relevant passages: " + query
    enc = RAG_EMBEDDER_TOK(text, padding=True, truncation=True,
                           max_length=512, return_tensors="pt")
    with torch.no_grad():
        out = RAG_EMBEDDER(**enc)
    emb = F.normalize(out.last_hidden_state[:, 0], p=2, dim=-1)
    # Upcast: the embedder may run in fp16, the index is float32.
    q_vec = emb.cpu().to(torch.float32).numpy().reshape(-1)
    scores = RAG_INDEX @ q_vec
    # Linear-time selection of the k best rows, then sort only those k.
    top = np.argpartition(-scores, k - 1)[:k]
    top = top[np.argsort(-scores[top])]
    return [RAG_CHUNKS[i] for i in top]
|
|
|
|
def format_rag_prompt(query: str, passages: List[dict], max_chars: int = 400) -> str:
    """Wrap ``query`` with numbered reference passages for the model.

    The model is not RAFT-trained, so it simply sees this block as part of
    the user message; the aim is to put real reference text in front of the
    small model for factual cybersec questions.

    Args:
        query: the user's question.
        passages: retrieved chunk records. Reads the optional ``text``,
            ``source`` and ``ref`` keys; missing or None values are tolerated.
        max_chars: per-passage text budget. Longer text is cut at the last
            space before the limit and marked with ``...``.

    Returns:
        ``query`` unchanged when ``passages`` is empty; otherwise the full
        prompt, ending in ``Question: <query>``.
    """
    if not passages:
        return query
    refs = []
    for i, p in enumerate(passages, start=1):
        text = p.get("text")
        text = "" if text is None else str(text)
        if len(text) > max_chars:
            text = text[:max_chars].rsplit(" ", 1)[0] + "..."
        # "(source ref)", without a dangling space when ref is absent.
        source = p.get("source")
        label = "?" if source is None else str(source)
        ref = p.get("ref")
        if ref is not None and str(ref):
            label = f"{label} {ref}"
        refs.append(f"[{i}] ({label}) {text}")
    refs_block = "\n\n".join(refs)
    return (
        "Reference passages from the cybersecurity corpus:\n\n"
        f"{refs_block}\n\n"
        "Use the reference passages above to answer the question. If the "
        "passages don't contain the answer, say so rather than guessing.\n\n"
        f"Question: {query}"
    )
|
|
|
|
| |
| |
| |
|
|
|
|
def _history_to_turns(history: Optional[list]) -> list:
    """Normalise Gradio chat history into ``[{"role", "content"}, ...]``.

    Accepts the older tuples format ``[(user, bot), ...]`` and the newer
    messages format ``[{"role": ..., "content": ...}, ...]``. Roles other
    than user/assistant and unrecognised entries are dropped; empty halves
    of a tuple pair are skipped.
    """
    turns: list = []
    for entry in history or []:
        if isinstance(entry, dict) and entry.get("role") in ("user", "assistant"):
            turns.append({"role": entry["role"], "content": entry["content"]})
        elif isinstance(entry, (list, tuple)) and len(entry) == 2:
            user_msg, bot_msg = entry
            if user_msg:
                turns.append({"role": "user", "content": user_msg})
            if bot_msg:
                turns.append({"role": "assistant", "content": bot_msg})
    return turns


def _fit_prompt(turns: list, budget: int) -> List[int]:
    """Tokenise ``turns``, dropping the two oldest turns at a time while the
    prompt exceeds ``budget`` tokens and at least three turns remain.

    The final turn (the new user message) is never dropped, so the result
    can still exceed ``budget``. ``turns`` itself is not modified.
    """
    turns = list(turns)
    prompt_ids = TOKENIZER.format_chat_prompt(turns)
    while len(prompt_ids) > budget and len(turns) >= 3:
        del turns[:2]
        prompt_ids = TOKENIZER.format_chat_prompt(turns)
    return prompt_ids


def _release_memory() -> None:
    """Empty the MPS or CUDA allocator cache (if present) and run a GC pass."""
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
    elif torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()


def chat_fn(message: str, history: list, temperature: float, top_k: int,
            top_p: float, max_tokens: int, repetition_penalty: float) -> Iterator[str]:
    """Stream one assistant turn given the prior history + new user message.

    ``history`` may arrive in either Gradio-tuples format
    ``[(user, bot), ...]`` (older) or messages format
    ``[{"role", "content"}, ...]`` (newer); both are coerced to messages.

    When RAG is loaded, the new message is wrapped with up to 4 retrieved
    passages. The prompt budget is ``context_length - max_tokens - 8``
    tokens. Oldest history turns are dropped first; if the prompt still does
    not fit, passages are dropped one at a time down to the bare message.
    Otherwise the generation loop's sliding window would cut off the start of
    the prompt.

    This is a generator: it yields the decoded reply so far whenever it
    changes, or ``"(no response)"`` if nothing was generated. Cleanup runs in
    ``finally`` so it also happens when the stream is closed early.
    """
    MODEL.eval()
    max_new_tokens = int(max_tokens)
    # Room for the reply plus a few tokens of slack.
    ctx_budget = CONFIG.context_length - max_new_tokens - 8
    prior_turns = _history_to_turns(history)

    # Candidate texts for the new user turn, most retrieved context first;
    # the bare message is always the last resort.
    user_texts = [message]
    if RAG_INDEX is not None:
        try:
            passages = retrieve(message, k=4)
            user_texts = [
                format_rag_prompt(message, passages[:n])
                for n in range(len(passages), 0, -1)
            ] + [message]
        except Exception as e:
            print(f"RAG retrieve failed for this turn: {type(e).__name__}: {e}")
            user_texts = [message]

    for user_text in user_texts:
        prompt_ids = _fit_prompt(
            prior_turns + [{"role": "user", "content": user_text}], ctx_budget
        )
        if len(prompt_ids) <= ctx_budget:
            break

    stream = generate_until_end_stream(
        MODEL, prompt_ids,
        end_id=END_ID,
        max_new_tokens=max_new_tokens,
        temperature=float(temperature),
        top_k=int(top_k),
        top_p=float(top_p),
        repetition_penalty=float(repetition_penalty),
    )
    last_text = ""
    try:
        for new_ids in stream:
            text = TOKENIZER.decode(new_ids).strip()
            if text and text != last_text:
                last_text = text
                yield text
        if not last_text:
            yield "(no response)"
    finally:
        # Also runs when the consumer closes this generator mid-reply (e.g.
        # the user stops generation), where code after the loop is skipped.
        stream.close()
        _release_memory()
|
|
|
|
| |
| |
| |
|
|
# Retrieval status line for DESCRIPTION. The chunk count is read from the
# loaded index so the text can't go stale when the index is rebuilt.
if RAG_INDEX is not None:
    _rag_status = (
        "**ON**. Each query is augmented with up to 4 top-ranked passages "
        f"retrieved from an index of {len(RAG_CHUNKS or []):,} chunks of the "
        "cybersec corpus (NVD / MITRE / CWE / OWASP / CTFtime / arXiv). The "
        "model conditions on real reference text instead of producing "
        "register-shaped fiction. Retrieval adds ~1-2 s per reply."
    )
else:
    _rag_status = (
        f"**OFF**. RAG could not load at startup (`{RAG_LOAD_ERROR}`). "
        "Generation is bare; expect hallucination on factual questions."
    )

DESCRIPTION = f"""
# GhostLM chat (v0.9)

An 81M-parameter cybersecurity language model **trained from scratch** in
PyTorch. The pretrain corpus is 273M tokens (PRIMUS-Seed, PRIMUS-FineWeb,
NVD CVEs, MITRE ATT&CK, CWE, CAPEC, OWASP, IETF RFCs, Exploit-DB, CTFtime,
arXiv cs.CR, plus a fact-dense Q&A set). Architecture: 6 layers · d_model
768 · 12 heads, with RoPE + SwiGLU + RMSNorm.

Chat-tuned with supervised fine-tuning on the chat-v3 SFT recipe. The
v0.9 chat checkpoint is the **bench winner of the ghost-small line**:

- **28.9%** on [CTIBench MCQ](https://huggingface.co/datasets/AI4Sec/cti-bench)
  full test split (n=2500, 2-permutation debiased text-scoring)
- **59.2%** on the in-repo CTF MCQ eval (n=30)
- **39.3%** on SecQA (n=210, external)

**Honest expectations.** v0.9 wins on multiple-choice, but **free-form
fact recall is at the floor of the entire ghost-small rung** (1/50 on a
hand-written 50-question fact-recall set, the one "hit" arguably spurious).
The model has learned the *register* of cybersec writing (sentence
shape, technique vocabulary, OWASP-style cadence) but not the *facts* in
any retrievable form. Treat outputs as register-shaped fiction: identity,
OOD-refusal, and chat shape work; specific CVE numbers, CVSS scores, dates,
and technique IDs are unreliable. Always verify against authoritative
sources.

The next rung is **ghost-base (~360M, SmolLM2-360M shape)**, gated on
rented GPU compute, where literature reports factual recall on cybersec
MCQ starting to emerge. Spec at
[`docs/ghost_base_spec.md`](https://github.com/joemunene-by/GhostLM/blob/main/docs/ghost_base_spec.md).

**Retrieval-augmented mode:** {_rag_status}

**Loaded checkpoint:** `{LOADED_FROM}`
"""
|
|
EXAMPLES = [
    "What is XSS?",
    "Explain MITRE ATT&CK technique T1059.",
    "What does SSRF stand for?",
    "How does a buffer overflow work?",
    "Walk me through a typical SQL injection attack.",
    "What's the difference between CVE and CWE?",
    "Where do I start learning cybersecurity?",
    "Are you ChatGPT?",
]

# Default sampling settings, defined once so the slider defaults and the
# per-example values below cannot drift apart. EXAMPLE_SETTINGS follows the
# order of additional_inputs (and of chat_fn's trailing parameters).
DEFAULT_TEMPERATURE = 0.7
DEFAULT_TOP_K = 40
DEFAULT_TOP_P = 0.95
DEFAULT_MAX_TOKENS = 200
DEFAULT_REPETITION_PENALTY = 1.25
EXAMPLE_SETTINGS = [
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_K,
    DEFAULT_TOP_P,
    DEFAULT_MAX_TOKENS,
    DEFAULT_REPETITION_PENALTY,
]


with gr.Blocks(title="GhostLM Chat") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=3):
            chat = gr.ChatInterface(
                fn=chat_fn,
                # Each example row: the message, then one value per additional input.
                examples=[[ex, *EXAMPLE_SETTINGS] for ex in EXAMPLES],
                additional_inputs=[
                    gr.Slider(0.1, 1.5, value=DEFAULT_TEMPERATURE, step=0.1, label="Temperature"),
                    gr.Slider(0, 100, value=DEFAULT_TOP_K, step=1, label="Top-k"),
                    gr.Slider(0.1, 1.0, value=DEFAULT_TOP_P, step=0.05, label="Top-p"),
                    gr.Slider(32, 400, value=DEFAULT_MAX_TOKENS, step=8, label="Max tokens"),
                    gr.Slider(1.0, 2.0, value=DEFAULT_REPETITION_PENALTY, step=0.05,
                              label="Repetition penalty"),
                ],
            )
    gr.Markdown(
        "Source: [github.com/joemunene-by/GhostLM](https://github.com/joemunene-by/GhostLM)"
        " · v0.9 weights: [Ghostgim/GhostLM-v0.9-experimental](https://huggingface.co/Ghostgim/GhostLM-v0.9-experimental)"
        " · The model is small enough to run locally on a laptop CPU. See the GitHub README for instructions."
    )
|
|
|
|
if __name__ == "__main__":
    # Serve one generation at a time and hold at most 20 pending requests
    # in the queue.
    queued_app = demo.queue(default_concurrency_limit=1, max_size=20)
    queued_app.launch()
|
|