chore: squash history to reclaim LFS storage from removed checkpoint
Browse files- .gitattributes +8 -0
- README.md +157 -0
- app.py +265 -0
- checkpoints/phase5_chat_v3/best_model.pt +3 -0
- ghostlm/__init__.py +19 -0
- ghostlm/config.py +162 -0
- ghostlm/dataset.py +125 -0
- ghostlm/model.py +428 -0
- ghostlm/tokenizer.py +321 -0
- ghostlm/trainer.py +317 -0
- requirements.txt +22 -0
.gitattributes
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hugging Face Spaces LFS rules. Copy this to the Space repo root
|
| 2 |
+
# alongside app.py / requirements.txt / README.md / ghostlm/ /
|
| 3 |
+
# checkpoints/. Without it the ~177 MB checkpoints either fail to push
|
| 4 |
+
# or land as broken pointer files.
|
| 5 |
+
|
| 6 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: GhostLM
|
| 3 |
+
emoji: 🔐
|
| 4 |
+
colorFrom: purple
|
| 5 |
+
colorTo: gray
|
| 6 |
+
sdk: gradio
|
| 7 |
+
app_file: app.py
|
| 8 |
+
pinned: false
|
| 9 |
+
license: apache-2.0
|
| 10 |
+
short_description: From-scratch cybersecurity LM — interactive demo
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
# GhostLM Demo
|
| 14 |
+
|
| 15 |
+
Interactive Gradio UI for the canonical Phase 3.5 ghost-tiny model. Two
|
| 16 |
+
tabs: a single-checkpoint **Generate** view with curated prompt presets
|
| 17 |
+
and a generation history, and an optional **Compare** tab that runs the
|
| 18 |
+
same prompt through two checkpoints side-by-side (the canonical v0.3.5
|
| 19 |
+
vs. the v0.3.7 attempt that regressed).
|
| 20 |
+
|
| 21 |
+
This file is dual-purpose:
|
| 22 |
+
|
| 23 |
+
- **In the GitHub repo** (`demo/README.md`) — documents the demo and
|
| 24 |
+
the deploy steps.
|
| 25 |
+
- **As an HF Space README** — the YAML frontmatter at the top is parsed
|
| 26 |
+
by Hugging Face Spaces as the Space metadata. Keep it intact when
|
| 27 |
+
copying this file to a Space repo.
|
| 28 |
+
|
| 29 |
+
## Run locally
|
| 30 |
+
|
| 31 |
+
From the repo root:
|
| 32 |
+
|
| 33 |
+
```bash
|
| 34 |
+
pip install -r demo/requirements.txt
|
| 35 |
+
PYTHONPATH=. python3 demo/app.py
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
Open `http://localhost:7860`. The demo defaults to
|
| 39 |
+
`checkpoints/phase3.5_balanced/best_model.pt` — pass `--checkpoint` to
|
| 40 |
+
load a different one:
|
| 41 |
+
|
| 42 |
+
```bash
|
| 43 |
+
PYTHONPATH=. python3 demo/app.py --checkpoint checkpoints/phase3.6_exploitdb/best_model.pt
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
To enable the Compare tab, add a second checkpoint:
|
| 47 |
+
|
| 48 |
+
```bash
|
| 49 |
+
PYTHONPATH=. python3 demo/app.py \
|
| 50 |
+
--checkpoint checkpoints/phase3.5_balanced/best_model.pt \
|
| 51 |
+
--compare-checkpoint checkpoints/phase3.6_exploitdb/best_model.pt
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
The same `--share` flag Gradio supports works:
|
| 55 |
+
|
| 56 |
+
```bash
|
| 57 |
+
PYTHONPATH=. python3 demo/app.py --share
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
## Deploy to Hugging Face Spaces
|
| 61 |
+
|
| 62 |
+
A Space is a separate git repo on huggingface.co. The demo here lives
|
| 63 |
+
under `demo/` in the GhostLM repo so the source stays in one place; to
|
| 64 |
+
deploy you copy the demo files plus the `ghostlm/` package and a
|
| 65 |
+
checkpoint into a fresh Space repo.
|
| 66 |
+
|
| 67 |
+
### 1. Create the Space
|
| 68 |
+
|
| 69 |
+
Either via the Hugging Face web UI (New → Space, SDK = Gradio) or via
|
| 70 |
+
CLI:
|
| 71 |
+
|
| 72 |
+
```bash
|
| 73 |
+
pip install huggingface_hub
|
| 74 |
+
huggingface-cli login
|
| 75 |
+
huggingface-cli repo create ghostlm --type space --space-sdk gradio
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
Replace `ghostlm` with your preferred Space name.
|
| 79 |
+
|
| 80 |
+
### 2. Clone the Space repo and stage files
|
| 81 |
+
|
| 82 |
+
```bash
|
| 83 |
+
git clone https://huggingface.co/spaces/<your-user>/ghostlm hf-space
|
| 84 |
+
cd hf-space
|
| 85 |
+
|
| 86 |
+
# Track the checkpoint via LFS (it's ~177 MB)
|
| 87 |
+
git lfs install
|
| 88 |
+
git lfs track "*.pt"
|
| 89 |
+
|
| 90 |
+
# Copy the demo + the ghostlm package + the canonical checkpoint
|
| 91 |
+
cp ../demo/app.py .
|
| 92 |
+
cp ../demo/requirements.txt .
|
| 93 |
+
cp ../demo/README.md .
|
| 94 |
+
cp -r ../ghostlm .
|
| 95 |
+
mkdir -p checkpoints/phase3.5_balanced
|
| 96 |
+
cp ../checkpoints/phase3.5_balanced/best_model.pt checkpoints/phase3.5_balanced/
|
| 97 |
+
|
| 98 |
+
git add .
|
| 99 |
+
git commit -m "Initial GhostLM Space deploy"
|
| 100 |
+
git push
|
| 101 |
+
```
|
| 102 |
+
|
| 103 |
+
The Space will start building automatically; first build takes ~3–5
|
| 104 |
+
minutes (gradio + torch wheel install + checkpoint LFS pull). The
|
| 105 |
+
README's frontmatter tells HF this is a Gradio Space, sets the colors,
|
| 106 |
+
and pins `app_file: app.py`.
|
| 107 |
+
|
| 108 |
+
### 3. Optional — include the Phase 3.6 checkpoint for the Compare tab
|
| 109 |
+
|
| 110 |
+
If you want the Compare tab live in the Space, also copy the Phase 3.6
|
| 111 |
+
checkpoint (~177 MB more) and set the env var in the Space's Settings
|
| 112 |
+
page:
|
| 113 |
+
|
| 114 |
+
```bash
|
| 115 |
+
mkdir -p checkpoints/phase3.6_exploitdb
|
| 116 |
+
cp ../checkpoints/phase3.6_exploitdb/best_model.pt checkpoints/phase3.6_exploitdb/
|
| 117 |
+
git add checkpoints/phase3.6_exploitdb
|
| 118 |
+
git commit -m "Add Phase 3.6 for compare tab"
|
| 119 |
+
git push
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
In the Space's **Settings → Variables**, add:
|
| 123 |
+
|
| 124 |
+
```
|
| 125 |
+
GHOSTLM_COMPARE_CHECKPOINT = checkpoints/phase3.6_exploitdb/best_model.pt
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
The Space restarts automatically. The Compare tab will now be visible.
|
| 129 |
+
|
| 130 |
+
### 4. Updates
|
| 131 |
+
|
| 132 |
+
Push to the Space repo whenever the demo changes; the Space rebuilds.
|
| 133 |
+
For a checkpoint update push the new `.pt` file (LFS handles it).
|
| 134 |
+
|
| 135 |
+
## What it looks like
|
| 136 |
+
|
| 137 |
+
The **Generate** tab gives you a prompt textbox, three sampling sliders
|
| 138 |
+
(max tokens, temperature, top-k), and a continuation panel. Below that,
|
| 139 |
+
collapsible accordions group the preset prompts by register (CVE / MITRE
|
| 140 |
+
/ CTF / CAPEC / free-form) so visitors can immediately see what kind of
|
| 141 |
+
prose the model knows. A history panel keeps the last five generations
|
| 142 |
+
visible.
|
| 143 |
+
|
| 144 |
+
The **Compare** tab — only shown when a second checkpoint is loaded —
|
| 145 |
+
sends the same prompt + sampling settings to both models in turn so the
|
| 146 |
+
Phase 3.5 → 3.6 trajectory is visible in real text rather than just
|
| 147 |
+
accuracy numbers.
|
| 148 |
+
|
| 149 |
+
## Why this exists
|
| 150 |
+
|
| 151 |
+
The point of the demo isn't to impress visitors with fluency — at 14.7M
|
| 152 |
+
parameters trained on 8.8M tokens, the model produces register-shaped
|
| 153 |
+
fiction, not knowledge. The point is to make the project's
|
| 154 |
+
trajectory-over-absolute-quality framing concrete: visitors can poke at
|
| 155 |
+
the canonical model, see exactly what it knows and doesn't, and if both
|
| 156 |
+
checkpoints are loaded, see the empirical capacity-ceiling finding for
|
| 157 |
+
themselves.
|
app.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""GhostLM Gradio Space — chat UI for the v0.5.0 chat-v3 (CTIBench 36.9%) model.
|
| 2 |
+
|
| 3 |
+
Multi-turn chat using the model's three role tokens
|
| 4 |
+
(<|ghost_user|>, <|ghost_assistant|>, <|ghost_end|>). Generation stops the
|
| 5 |
+
moment the assistant's <|ghost_end|> is sampled. Repetition penalty is on
|
| 6 |
+
by default — without it the 45M model occasionally degenerates into
|
| 7 |
+
"Wifi Wifi Wifi…" loops on small prompts.
|
| 8 |
+
|
| 9 |
+
Runs on Spaces cpu-basic (2 vCPU). Generation is ~5-15 s per reply at
|
| 10 |
+
the default 200-token cap.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import sys
|
| 17 |
+
from dataclasses import fields
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import List
|
| 20 |
+
|
| 21 |
+
import gradio as gr
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
|
| 25 |
+
REPO_ROOT = Path(__file__).resolve().parent
|
| 26 |
+
if str(REPO_ROOT) not in sys.path:
|
| 27 |
+
sys.path.insert(0, str(REPO_ROOT))
|
| 28 |
+
|
| 29 |
+
from ghostlm.config import GhostLMConfig
|
| 30 |
+
from ghostlm.model import GhostLM
|
| 31 |
+
from ghostlm.tokenizer import GhostTokenizer
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# ---------------------------------------------------------------------------
|
| 35 |
+
# Loading
|
| 36 |
+
# ---------------------------------------------------------------------------
|
| 37 |
+
|
| 38 |
+
CHECKPOINT_CANDIDATES = [
|
| 39 |
+
"checkpoints/phase5_chat_v3/best_model.pt",
|
| 40 |
+
"checkpoints/best_model.pt", # fallback if pushed at the root
|
| 41 |
+
]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def find_checkpoint() -> str:
|
| 45 |
+
"""Return the first checkpoint path that exists, or empty string."""
|
| 46 |
+
for path in CHECKPOINT_CANDIDATES:
|
| 47 |
+
if Path(path).exists():
|
| 48 |
+
return path
|
| 49 |
+
return ""
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def load_model(path: str):
|
| 53 |
+
"""Load a GhostLM checkpoint into eval mode on CPU."""
|
| 54 |
+
if not path:
|
| 55 |
+
# Random-init fallback so the UI still launches if weights are missing.
|
| 56 |
+
config = GhostLMConfig.from_preset("ghost-tiny")
|
| 57 |
+
config.vocab_size = 50264
|
| 58 |
+
config.context_length = 256
|
| 59 |
+
model = GhostLM(config).eval()
|
| 60 |
+
return model, config, "(random ghost-tiny — weights missing on Space)"
|
| 61 |
+
|
| 62 |
+
ckpt = torch.load(path, map_location="cpu", weights_only=False)
|
| 63 |
+
saved = ckpt["config"]
|
| 64 |
+
config = GhostLMConfig(**{
|
| 65 |
+
f.name: saved[f.name]
|
| 66 |
+
for f in fields(GhostLMConfig)
|
| 67 |
+
if f.name in saved
|
| 68 |
+
})
|
| 69 |
+
model = GhostLM(config)
|
| 70 |
+
state = ckpt.get("model_state_dict", ckpt.get("model"))
|
| 71 |
+
model.load_state_dict(state, strict=False)
|
| 72 |
+
model.eval()
|
| 73 |
+
return model, config, path
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# ---------------------------------------------------------------------------
|
| 77 |
+
# Generation — inlined from scripts/chat.py so the Space stays self-contained.
|
| 78 |
+
# ---------------------------------------------------------------------------
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def sample_next(
|
| 82 |
+
logits: torch.Tensor,
|
| 83 |
+
*,
|
| 84 |
+
temperature: float,
|
| 85 |
+
top_k: int,
|
| 86 |
+
top_p: float,
|
| 87 |
+
prev_ids: List[int],
|
| 88 |
+
repetition_penalty: float,
|
| 89 |
+
) -> int:
|
| 90 |
+
"""Sample one token from logits with temperature, top-k / top-p, and rep-penalty."""
|
| 91 |
+
if prev_ids and repetition_penalty != 1.0:
|
| 92 |
+
for tok in set(prev_ids):
|
| 93 |
+
if logits[tok] > 0:
|
| 94 |
+
logits[tok] = logits[tok] / repetition_penalty
|
| 95 |
+
else:
|
| 96 |
+
logits[tok] = logits[tok] * repetition_penalty
|
| 97 |
+
logits = logits / max(temperature, 1e-6)
|
| 98 |
+
if top_k and top_k > 0:
|
| 99 |
+
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
| 100 |
+
logits[logits < v[..., -1:]] = float("-inf")
|
| 101 |
+
if top_p and top_p < 1.0:
|
| 102 |
+
sorted_logits, sorted_idx = torch.sort(logits, descending=True)
|
| 103 |
+
probs = F.softmax(sorted_logits, dim=-1)
|
| 104 |
+
cum = probs.cumsum(dim=-1)
|
| 105 |
+
cutoff = cum > top_p
|
| 106 |
+
cutoff[..., 0] = False
|
| 107 |
+
sorted_logits[cutoff] = float("-inf")
|
| 108 |
+
logits = torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)
|
| 109 |
+
probs = F.softmax(logits, dim=-1)
|
| 110 |
+
return int(torch.multinomial(probs, num_samples=1).item())
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def generate_until_end(
|
| 114 |
+
model,
|
| 115 |
+
prompt_ids: List[int],
|
| 116 |
+
*,
|
| 117 |
+
end_id: int,
|
| 118 |
+
max_new_tokens: int,
|
| 119 |
+
temperature: float,
|
| 120 |
+
top_k: int,
|
| 121 |
+
top_p: float,
|
| 122 |
+
repetition_penalty: float,
|
| 123 |
+
) -> List[int]:
|
| 124 |
+
"""Greedy-or-sampled generation that stops the moment ``end_id`` is sampled."""
|
| 125 |
+
ids = torch.tensor(prompt_ids, dtype=torch.long).unsqueeze(0)
|
| 126 |
+
new_ids: List[int] = []
|
| 127 |
+
ctx = model.config.context_length
|
| 128 |
+
with torch.no_grad():
|
| 129 |
+
for _ in range(max_new_tokens):
|
| 130 |
+
cond = ids[:, -ctx:]
|
| 131 |
+
logits, _ = model(cond)
|
| 132 |
+
next_logits = logits[:, -1, :].squeeze(0).clone()
|
| 133 |
+
tok = sample_next(
|
| 134 |
+
next_logits,
|
| 135 |
+
temperature=temperature, top_k=top_k, top_p=top_p,
|
| 136 |
+
prev_ids=new_ids[-128:], repetition_penalty=repetition_penalty,
|
| 137 |
+
)
|
| 138 |
+
if tok == end_id:
|
| 139 |
+
break
|
| 140 |
+
new_ids.append(tok)
|
| 141 |
+
ids = torch.cat([ids, torch.tensor([[tok]])], dim=1)
|
| 142 |
+
return new_ids
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
# ---------------------------------------------------------------------------
|
| 146 |
+
# Module-level state
|
| 147 |
+
# ---------------------------------------------------------------------------
|
| 148 |
+
|
| 149 |
+
CHECKPOINT_PATH = find_checkpoint()
|
| 150 |
+
MODEL, CONFIG, LOADED_FROM = load_model(CHECKPOINT_PATH)
|
| 151 |
+
TOKENIZER = GhostTokenizer()
|
| 152 |
+
END_ID = TOKENIZER._special_tokens[TOKENIZER.END]
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
# ---------------------------------------------------------------------------
|
| 156 |
+
# Chat handler
|
| 157 |
+
# ---------------------------------------------------------------------------
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def chat_fn(message: str, history: list, temperature: float, top_k: int,
|
| 161 |
+
top_p: float, max_tokens: int, repetition_penalty: float) -> str:
|
| 162 |
+
"""Generate one assistant turn given the prior history + new user message.
|
| 163 |
+
|
| 164 |
+
``history`` may arrive in either Gradio-tuples format
|
| 165 |
+
``[(user, bot), ...]`` (older) or messages format
|
| 166 |
+
``[{"role", "content"}, ...]`` (newer). We coerce to messages.
|
| 167 |
+
"""
|
| 168 |
+
turns: list = []
|
| 169 |
+
for h in history:
|
| 170 |
+
if isinstance(h, dict) and h.get("role") in ("user", "assistant"):
|
| 171 |
+
turns.append({"role": h["role"], "content": h["content"]})
|
| 172 |
+
elif isinstance(h, (list, tuple)) and len(h) == 2:
|
| 173 |
+
user_msg, bot_msg = h
|
| 174 |
+
if user_msg:
|
| 175 |
+
turns.append({"role": "user", "content": user_msg})
|
| 176 |
+
if bot_msg:
|
| 177 |
+
turns.append({"role": "assistant", "content": bot_msg})
|
| 178 |
+
turns.append({"role": "user", "content": message})
|
| 179 |
+
|
| 180 |
+
prompt_ids = TOKENIZER.format_chat_prompt(turns)
|
| 181 |
+
# Trim conversation if the prompt overflows the context budget.
|
| 182 |
+
ctx_budget = CONFIG.context_length - max_tokens - 8
|
| 183 |
+
while len(prompt_ids) > ctx_budget and len(turns) > 1:
|
| 184 |
+
# Drop the oldest user/assistant pair, but keep the just-asked turn.
|
| 185 |
+
if len(turns) >= 3:
|
| 186 |
+
del turns[:2]
|
| 187 |
+
prompt_ids = TOKENIZER.format_chat_prompt(turns)
|
| 188 |
+
else:
|
| 189 |
+
break
|
| 190 |
+
|
| 191 |
+
new_ids = generate_until_end(
|
| 192 |
+
MODEL, prompt_ids,
|
| 193 |
+
end_id=END_ID,
|
| 194 |
+
max_new_tokens=int(max_tokens),
|
| 195 |
+
temperature=float(temperature),
|
| 196 |
+
top_k=int(top_k),
|
| 197 |
+
top_p=float(top_p),
|
| 198 |
+
repetition_penalty=float(repetition_penalty),
|
| 199 |
+
)
|
| 200 |
+
return TOKENIZER.decode(new_ids).strip() or "(no response)"
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
# ---------------------------------------------------------------------------
|
| 204 |
+
# UI
|
| 205 |
+
# ---------------------------------------------------------------------------
|
| 206 |
+
|
| 207 |
+
DESCRIPTION = f"""
|
| 208 |
+
# GhostLM — chat-v3 (v0.5.0)
|
| 209 |
+
|
| 210 |
+
A 45M-parameter cybersecurity language model **trained from scratch** on
|
| 211 |
+
12.56M tokens of NVD / MITRE ATT&CK / Exploit-DB / CTFtime / arXiv cs.CR
|
| 212 |
+
text. The chat-tuned checkpoint here scored **36.9% on
|
| 213 |
+
[CTIBench MCQ](https://huggingface.co/datasets/AI4Sec/cti-bench)** — 1.48× random for a
|
| 214 |
+
2,500-question security multiple-choice benchmark.
|
| 215 |
+
|
| 216 |
+
**Honest expectations:** identity, OOD-refusal, and chat shape work. Specific
|
| 217 |
+
facts (CVE numbers, CVSS scores, dates, technique IDs) are unreliable —
|
| 218 |
+
the model often confabulates plausible-looking but wrong specifics. Always
|
| 219 |
+
verify against authoritative sources. Outside cybersecurity, the model
|
| 220 |
+
politely declines and returns to its domain.
|
| 221 |
+
|
| 222 |
+
**Loaded checkpoint:** `{LOADED_FROM}`
|
| 223 |
+
"""
|
| 224 |
+
|
| 225 |
+
EXAMPLES = [
|
| 226 |
+
"What is XSS?",
|
| 227 |
+
"Explain MITRE ATT&CK technique T1059.",
|
| 228 |
+
"What does SSRF stand for?",
|
| 229 |
+
"How does a buffer overflow work?",
|
| 230 |
+
"Walk me through a typical SQL injection attack.",
|
| 231 |
+
"What's the difference between CVE and CWE?",
|
| 232 |
+
"Where do I start learning cybersecurity?",
|
| 233 |
+
"Are you ChatGPT?",
|
| 234 |
+
]
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
with gr.Blocks(title="GhostLM Chat") as demo:
|
| 238 |
+
gr.Markdown(DESCRIPTION)
|
| 239 |
+
with gr.Row():
|
| 240 |
+
with gr.Column(scale=3):
|
| 241 |
+
chat = gr.ChatInterface(
|
| 242 |
+
fn=chat_fn,
|
| 243 |
+
# Each example needs values for every additional_input when
|
| 244 |
+
# they're configured below — list-of-lists [message, temp,
|
| 245 |
+
# top_k, top_p, max_tokens, rep_penalty]. The defaults below
|
| 246 |
+
# match the sliders so a user can click an example and get
|
| 247 |
+
# consistent generation settings.
|
| 248 |
+
examples=[[ex, 0.7, 40, 0.95, 200, 1.25] for ex in EXAMPLES],
|
| 249 |
+
additional_inputs=[
|
| 250 |
+
gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature"),
|
| 251 |
+
gr.Slider(0, 100, value=40, step=1, label="Top-k"),
|
| 252 |
+
gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"),
|
| 253 |
+
gr.Slider(32, 400, value=200, step=8, label="Max tokens"),
|
| 254 |
+
gr.Slider(1.0, 2.0, value=1.25, step=0.05, label="Repetition penalty"),
|
| 255 |
+
],
|
| 256 |
+
)
|
| 257 |
+
gr.Markdown(
|
| 258 |
+
"Source: [github.com/joemunene-by/GhostLM](https://github.com/joemunene-by/GhostLM)"
|
| 259 |
+
" · Weights: [Ghostgim/GhostLM](https://huggingface.co/Ghostgim/GhostLM)"
|
| 260 |
+
" · The model is small enough to run locally — see the GitHub README for instructions."
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
if __name__ == "__main__":
|
| 265 |
+
demo.queue().launch()
|
checkpoints/phase5_chat_v3/best_model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a1c2dbbb3f2559153953cdec8c0e8adbcdf0659fe4b61c3eb05a4e21c6b216f0
|
| 3 |
+
size 542187521
|
ghostlm/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""GhostLM — open-source cybersecurity-focused language model."""
|
| 2 |
+
|
| 3 |
+
from ghostlm.config import GhostLMConfig
|
| 4 |
+
from ghostlm.model import GhostLM
|
| 5 |
+
from ghostlm.tokenizer import GhostTokenizer
|
| 6 |
+
from ghostlm.dataset import GhostDataset, build_dataloaders
|
| 7 |
+
from ghostlm.trainer import GhostTrainer
|
| 8 |
+
|
| 9 |
+
__version__ = "0.1.0"
|
| 10 |
+
__author__ = "Joe Munene"
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"GhostLMConfig",
|
| 14 |
+
"GhostLM",
|
| 15 |
+
"GhostTokenizer",
|
| 16 |
+
"GhostDataset",
|
| 17 |
+
"build_dataloaders",
|
| 18 |
+
"GhostTrainer",
|
| 19 |
+
]
|
ghostlm/config.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""GhostLM configuration — all model and training hyperparameters live here."""
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@dataclass
|
| 7 |
+
class GhostLMConfig:
|
| 8 |
+
"""Configuration dataclass for the GhostLM transformer language model.
|
| 9 |
+
|
| 10 |
+
Holds all hyperparameters for model architecture, training, data paths,
|
| 11 |
+
and system settings. Supports preset configurations and parameter counting.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
# Model architecture
|
| 15 |
+
vocab_size: int = 50257
|
| 16 |
+
context_length: int = 1024
|
| 17 |
+
d_model: int = 512
|
| 18 |
+
n_heads: int = 8
|
| 19 |
+
n_layers: int = 6
|
| 20 |
+
d_ff: int = 2048
|
| 21 |
+
dropout: float = 0.1
|
| 22 |
+
bias: bool = True
|
| 23 |
+
use_rope: bool = False
|
| 24 |
+
use_flash_attention: bool = False
|
| 25 |
+
|
| 26 |
+
# Training
|
| 27 |
+
batch_size: int = 32
|
| 28 |
+
learning_rate: float = 3e-4
|
| 29 |
+
weight_decay: float = 0.1
|
| 30 |
+
beta1: float = 0.9
|
| 31 |
+
beta2: float = 0.95
|
| 32 |
+
grad_clip: float = 1.0
|
| 33 |
+
grad_accum_steps: int = 4
|
| 34 |
+
warmup_steps: int = 2000
|
| 35 |
+
max_steps: int = 100000
|
| 36 |
+
eval_interval: int = 500
|
| 37 |
+
save_interval: int = 1000
|
| 38 |
+
|
| 39 |
+
# Paths
|
| 40 |
+
data_dir: str = "data/processed"
|
| 41 |
+
checkpoint_dir: str = "checkpoints"
|
| 42 |
+
log_dir: str = "logs"
|
| 43 |
+
|
| 44 |
+
# System
|
| 45 |
+
device: str = "auto"
|
| 46 |
+
dtype: str = "float32"
|
| 47 |
+
seed: int = 42
|
| 48 |
+
use_wandb: bool = False
|
| 49 |
+
|
| 50 |
+
def model_size(self) -> str:
|
| 51 |
+
"""Estimate total parameter count and return a human-readable string.
|
| 52 |
+
|
| 53 |
+
Computes the approximate number of trainable parameters based on
|
| 54 |
+
vocab_size, d_model, n_heads, n_layers, and d_ff.
|
| 55 |
+
|
| 56 |
+
Returns:
|
| 57 |
+
A string like "124M" or "1.2B" representing the estimated size.
|
| 58 |
+
"""
|
| 59 |
+
embedding_params = self.vocab_size * self.d_model
|
| 60 |
+
attention_params = self.n_layers * (
|
| 61 |
+
4 * self.d_model * self.d_model + 2 * self.d_model
|
| 62 |
+
)
|
| 63 |
+
ffn_params = self.n_layers * (
|
| 64 |
+
2 * self.d_model * self.d_ff + self.d_model + self.d_ff
|
| 65 |
+
)
|
| 66 |
+
layer_norm_params = self.n_layers * 4 * self.d_model
|
| 67 |
+
output_head_params = self.d_model * self.vocab_size
|
| 68 |
+
|
| 69 |
+
total = embedding_params + attention_params + ffn_params + layer_norm_params + output_head_params
|
| 70 |
+
|
| 71 |
+
if total >= 1e9:
|
| 72 |
+
return f"{total / 1e9:.1f}B"
|
| 73 |
+
elif total >= 1e6:
|
| 74 |
+
return f"{total / 1e6:.0f}M"
|
| 75 |
+
else:
|
| 76 |
+
return f"{total:.0f}K"
|
| 77 |
+
|
| 78 |
+
@classmethod
|
| 79 |
+
def from_preset(cls, preset: str) -> "GhostLMConfig":
|
| 80 |
+
"""Return a GhostLMConfig instance from a named preset.
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
preset: One of "ghost-tiny", "ghost-small", or "ghost-medium".
|
| 84 |
+
|
| 85 |
+
Returns:
|
| 86 |
+
A GhostLMConfig configured with the preset's hyperparameters.
|
| 87 |
+
|
| 88 |
+
Raises:
|
| 89 |
+
ValueError: If the preset name is not recognized.
|
| 90 |
+
"""
|
| 91 |
+
presets = {
|
| 92 |
+
"ghost-tiny": {
|
| 93 |
+
"n_layers": 2,
|
| 94 |
+
"d_model": 256,
|
| 95 |
+
"n_heads": 4,
|
| 96 |
+
"d_ff": 1024,
|
| 97 |
+
},
|
| 98 |
+
"ghost-small": {
|
| 99 |
+
"n_layers": 6,
|
| 100 |
+
"d_model": 512,
|
| 101 |
+
"n_heads": 8,
|
| 102 |
+
"d_ff": 2048,
|
| 103 |
+
},
|
| 104 |
+
"ghost-medium": {
|
| 105 |
+
"n_layers": 12,
|
| 106 |
+
"d_model": 768,
|
| 107 |
+
"n_heads": 12,
|
| 108 |
+
"d_ff": 3072,
|
| 109 |
+
},
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
if preset not in presets:
|
| 113 |
+
raise ValueError(
|
| 114 |
+
f"Unknown preset '{preset}'. "
|
| 115 |
+
f"Available presets: {', '.join(presets.keys())}"
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
return cls(**presets[preset])
|
| 119 |
+
|
| 120 |
+
def __repr__(self) -> str:
|
| 121 |
+
"""Return a clean, grouped string summary of all config values.
|
| 122 |
+
|
| 123 |
+
Returns:
|
| 124 |
+
A formatted multi-line string with config values grouped by
|
| 125 |
+
category: Architecture, Training, Paths, and System.
|
| 126 |
+
"""
|
| 127 |
+
lines = [
|
| 128 |
+
"GhostLMConfig",
|
| 129 |
+
"=" * 40,
|
| 130 |
+
"Architecture:",
|
| 131 |
+
f" vocab_size: {self.vocab_size}",
|
| 132 |
+
f" context_length: {self.context_length}",
|
| 133 |
+
f" d_model: {self.d_model}",
|
| 134 |
+
f" n_heads: {self.n_heads}",
|
| 135 |
+
f" n_layers: {self.n_layers}",
|
| 136 |
+
f" d_ff: {self.d_ff}",
|
| 137 |
+
f" dropout: {self.dropout}",
|
| 138 |
+
f" bias: {self.bias}",
|
| 139 |
+
"Training:",
|
| 140 |
+
f" batch_size: {self.batch_size}",
|
| 141 |
+
f" learning_rate: {self.learning_rate}",
|
| 142 |
+
f" weight_decay: {self.weight_decay}",
|
| 143 |
+
f" beta1: {self.beta1}",
|
| 144 |
+
f" beta2: {self.beta2}",
|
| 145 |
+
f" grad_clip: {self.grad_clip}",
|
| 146 |
+
f" warmup_steps: {self.warmup_steps}",
|
| 147 |
+
f" max_steps: {self.max_steps}",
|
| 148 |
+
f" eval_interval: {self.eval_interval}",
|
| 149 |
+
f" save_interval: {self.save_interval}",
|
| 150 |
+
"Paths:",
|
| 151 |
+
f" data_dir: {self.data_dir}",
|
| 152 |
+
f" checkpoint_dir: {self.checkpoint_dir}",
|
| 153 |
+
f" log_dir: {self.log_dir}",
|
| 154 |
+
"System:",
|
| 155 |
+
f" device: {self.device}",
|
| 156 |
+
f" dtype: {self.dtype}",
|
| 157 |
+
f" seed: {self.seed}",
|
| 158 |
+
f" use_wandb: {self.use_wandb}",
|
| 159 |
+
"=" * 40,
|
| 160 |
+
f"Estimated size: {self.model_size()}",
|
| 161 |
+
]
|
| 162 |
+
return "\n".join(lines)
|
ghostlm/dataset.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""GhostLM dataset — converts processed JSONL data into PyTorch DataLoader-ready tensors."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Dict, List, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch.utils.data import DataLoader, Dataset
|
| 9 |
+
|
| 10 |
+
from ghostlm.config import GhostLMConfig
|
| 11 |
+
from ghostlm.tokenizer import GhostTokenizer
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class GhostDataset(Dataset):
|
| 15 |
+
"""PyTorch Dataset for GhostLM language model training.
|
| 16 |
+
|
| 17 |
+
Loads tokenized text from a JSONL file, concatenates all tokens
|
| 18 |
+
into a single flat sequence, and yields fixed-length chunks for
|
| 19 |
+
autoregressive language modeling (x, y shifted by one token).
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
def __init__(self, jsonl_path: str, tokenizer: GhostTokenizer, config: GhostLMConfig):
|
| 23 |
+
"""Initialize the dataset from a JSONL file.
|
| 24 |
+
|
| 25 |
+
Reads all records, tokenizes the "text" field of each, and
|
| 26 |
+
concatenates them into one continuous token stream.
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
jsonl_path: Path to the processed JSONL file.
|
| 30 |
+
tokenizer: GhostTokenizer instance for encoding text.
|
| 31 |
+
config: GhostLMConfig containing context_length.
|
| 32 |
+
"""
|
| 33 |
+
self.context_length = config.context_length
|
| 34 |
+
self.tokens: List[int] = []
|
| 35 |
+
|
| 36 |
+
with open(jsonl_path, "r", encoding="utf-8") as f:
|
| 37 |
+
for line in f:
|
| 38 |
+
line = line.strip()
|
| 39 |
+
if not line:
|
| 40 |
+
continue
|
| 41 |
+
record = json.loads(line)
|
| 42 |
+
text = record.get("text", "")
|
| 43 |
+
if text:
|
| 44 |
+
self.tokens.extend(tokenizer.encode(text))
|
| 45 |
+
|
| 46 |
+
print(f" Loaded {len(self.tokens):,} tokens from {jsonl_path}")
|
| 47 |
+
|
| 48 |
+
def __len__(self) -> int:
|
| 49 |
+
"""Return the number of non-overlapping context-length chunks.
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
Integer count of available training samples.
|
| 53 |
+
"""
|
| 54 |
+
return len(self.tokens) // self.context_length
|
| 55 |
+
|
| 56 |
+
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 57 |
+
"""Retrieve a single (input, target) token chunk.
|
| 58 |
+
|
| 59 |
+
The target sequence is the input sequence shifted left by one
|
| 60 |
+
token, enabling next-token prediction training.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
idx: Index of the chunk to retrieve.
|
| 64 |
+
|
| 65 |
+
Returns:
|
| 66 |
+
Tuple of (x, y) tensors, each of shape (context_length,).
|
| 67 |
+
"""
|
| 68 |
+
start = idx * self.context_length
|
| 69 |
+
end = start + self.context_length
|
| 70 |
+
|
| 71 |
+
x = self.tokens[start:end]
|
| 72 |
+
y = self.tokens[start + 1 : end + 1]
|
| 73 |
+
|
| 74 |
+
# Pad target with -1 if we hit the end of data (cross-entropy ignores -1)
|
| 75 |
+
if len(y) < len(x):
|
| 76 |
+
y = y + [-1] * (len(x) - len(y))
|
| 77 |
+
|
| 78 |
+
return (
|
| 79 |
+
torch.tensor(x, dtype=torch.long),
|
| 80 |
+
torch.tensor(y, dtype=torch.long),
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def build_dataloaders(
|
| 85 |
+
train_path: str,
|
| 86 |
+
val_path: str,
|
| 87 |
+
tokenizer: GhostTokenizer,
|
| 88 |
+
config: GhostLMConfig,
|
| 89 |
+
) -> Tuple[DataLoader, DataLoader]:
|
| 90 |
+
"""Build train and validation DataLoaders from JSONL files.
|
| 91 |
+
|
| 92 |
+
Creates GhostDataset instances for both splits and wraps them
|
| 93 |
+
in PyTorch DataLoaders with appropriate batching and shuffling.
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
train_path: Path to the training JSONL file.
|
| 97 |
+
val_path: Path to the validation JSONL file.
|
| 98 |
+
tokenizer: GhostTokenizer instance for encoding.
|
| 99 |
+
config: GhostLMConfig with batch_size and context_length.
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
Tuple of (train_loader, val_loader).
|
| 103 |
+
"""
|
| 104 |
+
train_dataset = GhostDataset(train_path, tokenizer, config)
|
| 105 |
+
val_dataset = GhostDataset(val_path, tokenizer, config)
|
| 106 |
+
|
| 107 |
+
train_loader = DataLoader(
|
| 108 |
+
train_dataset,
|
| 109 |
+
batch_size=config.batch_size,
|
| 110 |
+
shuffle=True,
|
| 111 |
+
drop_last=True,
|
| 112 |
+
num_workers=0,
|
| 113 |
+
pin_memory=True,
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
val_loader = DataLoader(
|
| 117 |
+
val_dataset,
|
| 118 |
+
batch_size=config.batch_size,
|
| 119 |
+
shuffle=False,
|
| 120 |
+
drop_last=False,
|
| 121 |
+
num_workers=0,
|
| 122 |
+
pin_memory=True,
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
return train_loader, val_loader
|
ghostlm/model.py
ADDED
|
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""GhostLM transformer model — decoder-only architecture built from scratch in PyTorch."""
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
from ghostlm.config import GhostLMConfig
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class RotaryEmbedding(nn.Module):
|
| 13 |
+
"""Rotary Position Embedding (RoPE).
|
| 14 |
+
|
| 15 |
+
Encodes relative position information directly into the attention
|
| 16 |
+
computation by rotating query and key vectors. Used by LLaMA, Mistral,
|
| 17 |
+
and most modern transformer architectures.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
def __init__(self, head_dim: int, context_length: int, base: float = 10000.0):
|
| 21 |
+
super().__init__()
|
| 22 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
|
| 23 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 24 |
+
|
| 25 |
+
# Precompute cos/sin for all positions
|
| 26 |
+
t = torch.arange(context_length).float()
|
| 27 |
+
freqs = torch.outer(t, inv_freq)
|
| 28 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 29 |
+
self.register_buffer("cos_cached", emb.cos(), persistent=False)
|
| 30 |
+
self.register_buffer("sin_cached", emb.sin(), persistent=False)
|
| 31 |
+
|
| 32 |
+
def forward(self, seq_len: int):
|
| 33 |
+
return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _rotate_half(x):
|
| 37 |
+
"""Rotate the second half of the last dimension and negate it."""
|
| 38 |
+
x1, x2 = x.chunk(2, dim=-1)
|
| 39 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def apply_rotary_pos_emb(q, k, cos, sin):
|
| 43 |
+
"""Apply rotary position embeddings to query and key tensors.
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
q: Query tensor of shape (B, n_heads, T, head_dim).
|
| 47 |
+
k: Key tensor of shape (B, n_heads, T, head_dim).
|
| 48 |
+
cos: Cosine frequencies of shape (T, head_dim).
|
| 49 |
+
sin: Sine frequencies of shape (T, head_dim).
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
Tuple of (rotated_q, rotated_k).
|
| 53 |
+
"""
|
| 54 |
+
cos = cos.unsqueeze(0).unsqueeze(0) # (1, 1, T, head_dim)
|
| 55 |
+
sin = sin.unsqueeze(0).unsqueeze(0)
|
| 56 |
+
q_embed = (q * cos) + (_rotate_half(q) * sin)
|
| 57 |
+
k_embed = (k * cos) + (_rotate_half(k) * sin)
|
| 58 |
+
return q_embed, k_embed
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class CausalSelfAttention(nn.Module):
|
| 62 |
+
"""Multi-head causal self-attention with autoregressive masking.
|
| 63 |
+
|
| 64 |
+
Uses a single combined QKV projection for efficiency, then splits
|
| 65 |
+
the result into separate query, key, and value tensors. Supports
|
| 66 |
+
optional RoPE (Rotary Position Embeddings) and Flash Attention.
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
def __init__(self, config: GhostLMConfig):
|
| 70 |
+
"""Initialize causal self-attention.
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
config: GhostLMConfig containing d_model, n_heads, dropout,
|
| 74 |
+
context_length, bias, use_rope, and use_flash_attention.
|
| 75 |
+
"""
|
| 76 |
+
super().__init__()
|
| 77 |
+
assert config.d_model % config.n_heads == 0, "d_model must be divisible by n_heads"
|
| 78 |
+
|
| 79 |
+
self.n_heads = config.n_heads
|
| 80 |
+
self.head_dim = config.d_model // config.n_heads
|
| 81 |
+
self.context_length = config.context_length
|
| 82 |
+
self.use_rope = config.use_rope
|
| 83 |
+
self.use_flash_attention = config.use_flash_attention
|
| 84 |
+
self.dropout_p = config.dropout
|
| 85 |
+
|
| 86 |
+
# Single combined QKV projection
|
| 87 |
+
self.c_qkv = nn.Linear(config.d_model, 3 * config.d_model, bias=config.bias)
|
| 88 |
+
self.proj = nn.Linear(config.d_model, config.d_model, bias=config.bias)
|
| 89 |
+
|
| 90 |
+
# Dropout applied to attention weights (manual path only)
|
| 91 |
+
self.attn_dropout = nn.Dropout(config.dropout)
|
| 92 |
+
self.resid_dropout = nn.Dropout(config.dropout)
|
| 93 |
+
|
| 94 |
+
# RoPE
|
| 95 |
+
if self.use_rope:
|
| 96 |
+
self.rope = RotaryEmbedding(self.head_dim, config.context_length)
|
| 97 |
+
|
| 98 |
+
# Causal mask buffer (only needed for manual attention path)
|
| 99 |
+
if not self.use_flash_attention:
|
| 100 |
+
self.register_buffer(
|
| 101 |
+
"causal_mask",
|
| 102 |
+
torch.tril(torch.ones(config.context_length, config.context_length))
|
| 103 |
+
.view(1, 1, config.context_length, config.context_length),
|
| 104 |
+
persistent=False,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
def forward(self, x):
|
| 108 |
+
"""Forward pass through causal self-attention.
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
x: Input tensor of shape (B, T, d_model).
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
Output tensor of shape (B, T, d_model).
|
| 115 |
+
"""
|
| 116 |
+
B, T, C = x.size()
|
| 117 |
+
|
| 118 |
+
# Combined QKV projection and split
|
| 119 |
+
qkv = self.c_qkv(x)
|
| 120 |
+
q, k, v = qkv.split(self.n_heads * self.head_dim, dim=-1)
|
| 121 |
+
|
| 122 |
+
# Reshape to (B, n_heads, T, head_dim)
|
| 123 |
+
q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
|
| 124 |
+
k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
|
| 125 |
+
v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
|
| 126 |
+
|
| 127 |
+
# Apply RoPE to Q and K (not V)
|
| 128 |
+
if self.use_rope:
|
| 129 |
+
cos, sin = self.rope(T)
|
| 130 |
+
q, k = apply_rotary_pos_emb(q, k, cos, sin)
|
| 131 |
+
|
| 132 |
+
if self.use_flash_attention:
|
| 133 |
+
# PyTorch 2.0+ scaled_dot_product_attention with automatic backend selection
|
| 134 |
+
y = F.scaled_dot_product_attention(
|
| 135 |
+
q, k, v,
|
| 136 |
+
attn_mask=None,
|
| 137 |
+
dropout_p=self.dropout_p if self.training else 0.0,
|
| 138 |
+
is_causal=True,
|
| 139 |
+
)
|
| 140 |
+
else:
|
| 141 |
+
# Manual attention path
|
| 142 |
+
scale = 1.0 / math.sqrt(self.head_dim)
|
| 143 |
+
att = (q @ k.transpose(-2, -1)) * scale
|
| 144 |
+
att = att.masked_fill(self.causal_mask[:, :, :T, :T] == 0, float("-inf"))
|
| 145 |
+
att = F.softmax(att, dim=-1)
|
| 146 |
+
att = self.attn_dropout(att)
|
| 147 |
+
y = att @ v
|
| 148 |
+
|
| 149 |
+
# Reassemble heads and project
|
| 150 |
+
y = y.transpose(1, 2).contiguous().view(B, T, C)
|
| 151 |
+
y = self.resid_dropout(self.proj(y))
|
| 152 |
+
|
| 153 |
+
return y
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class FeedForward(nn.Module):
|
| 157 |
+
"""Position-wise feed-forward network with GELU activation.
|
| 158 |
+
|
| 159 |
+
Two linear layers with an intermediate GELU non-linearity:
|
| 160 |
+
d_model -> d_ff -> d_model, with dropout after the second layer.
|
| 161 |
+
"""
|
| 162 |
+
|
| 163 |
+
def __init__(self, config: GhostLMConfig):
|
| 164 |
+
"""Initialize the feed-forward network.
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
config: GhostLMConfig containing d_model, d_ff, dropout, and bias.
|
| 168 |
+
"""
|
| 169 |
+
super().__init__()
|
| 170 |
+
self.fc1 = nn.Linear(config.d_model, config.d_ff, bias=config.bias)
|
| 171 |
+
self.fc2 = nn.Linear(config.d_ff, config.d_model, bias=config.bias)
|
| 172 |
+
self.dropout = nn.Dropout(config.dropout)
|
| 173 |
+
|
| 174 |
+
def forward(self, x):
|
| 175 |
+
"""Forward pass through the feed-forward network.
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
x: Input tensor of shape (B, T, d_model).
|
| 179 |
+
|
| 180 |
+
Returns:
|
| 181 |
+
Output tensor of shape (B, T, d_model).
|
| 182 |
+
"""
|
| 183 |
+
x = self.fc1(x)
|
| 184 |
+
x = F.gelu(x)
|
| 185 |
+
x = self.fc2(x)
|
| 186 |
+
x = self.dropout(x)
|
| 187 |
+
return x
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class TransformerBlock(nn.Module):
|
| 191 |
+
"""Single transformer decoder block with pre-normalization.
|
| 192 |
+
|
| 193 |
+
Applies LayerNorm before both the self-attention and feed-forward
|
| 194 |
+
sub-layers (pre-norm architecture), with residual connections
|
| 195 |
+
around each sub-layer.
|
| 196 |
+
"""
|
| 197 |
+
|
| 198 |
+
def __init__(self, config: GhostLMConfig):
|
| 199 |
+
"""Initialize the transformer block.
|
| 200 |
+
|
| 201 |
+
Args:
|
| 202 |
+
config: GhostLMConfig passed to sub-modules.
|
| 203 |
+
"""
|
| 204 |
+
super().__init__()
|
| 205 |
+
self.ln_1 = nn.LayerNorm(config.d_model)
|
| 206 |
+
self.attn = CausalSelfAttention(config)
|
| 207 |
+
self.ln_2 = nn.LayerNorm(config.d_model)
|
| 208 |
+
self.ffn = FeedForward(config)
|
| 209 |
+
|
| 210 |
+
def forward(self, x):
|
| 211 |
+
"""Forward pass through the transformer block.
|
| 212 |
+
|
| 213 |
+
Args:
|
| 214 |
+
x: Input tensor of shape (B, T, d_model).
|
| 215 |
+
|
| 216 |
+
Returns:
|
| 217 |
+
Output tensor of shape (B, T, d_model).
|
| 218 |
+
"""
|
| 219 |
+
# Pre-norm + self-attention with residual
|
| 220 |
+
x = x + self.attn(self.ln_1(x))
|
| 221 |
+
# Pre-norm + feed-forward with residual
|
| 222 |
+
x = x + self.ffn(self.ln_2(x))
|
| 223 |
+
return x
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
class GhostLM(nn.Module):
|
| 227 |
+
"""GhostLM decoder-only transformer language model.
|
| 228 |
+
|
| 229 |
+
Built from scratch in PyTorch with learned positional embeddings,
|
| 230 |
+
stacked transformer blocks, and weight-tied output projection.
|
| 231 |
+
"""
|
| 232 |
+
|
| 233 |
+
def __init__(self, config: GhostLMConfig):
|
| 234 |
+
"""Initialize the GhostLM model.
|
| 235 |
+
|
| 236 |
+
Args:
|
| 237 |
+
config: GhostLMConfig with all model hyperparameters.
|
| 238 |
+
"""
|
| 239 |
+
super().__init__()
|
| 240 |
+
self.config = config
|
| 241 |
+
|
| 242 |
+
# Embeddings
|
| 243 |
+
self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)
|
| 244 |
+
if not config.use_rope:
|
| 245 |
+
self.pos_embedding = nn.Embedding(config.context_length, config.d_model)
|
| 246 |
+
self.dropout = nn.Dropout(config.dropout)
|
| 247 |
+
|
| 248 |
+
# Transformer blocks
|
| 249 |
+
self.blocks = nn.ModuleList(
|
| 250 |
+
[TransformerBlock(config) for _ in range(config.n_layers)]
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
# Final layer norm
|
| 254 |
+
self.ln_f = nn.LayerNorm(config.d_model)
|
| 255 |
+
|
| 256 |
+
# Output head with weight tying (no bias)
|
| 257 |
+
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
| 258 |
+
self.lm_head.weight = self.token_embedding.weight
|
| 259 |
+
|
| 260 |
+
# Initialize weights
|
| 261 |
+
self.apply(self._init_weights)
|
| 262 |
+
|
| 263 |
+
# Apply scaled residual initialization for deeper models
|
| 264 |
+
for pn, p in self.named_parameters():
|
| 265 |
+
if pn.endswith("proj.weight") or pn.endswith("fc2.weight"):
|
| 266 |
+
torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layers))
|
| 267 |
+
|
| 268 |
+
def _init_weights(self, module):
|
| 269 |
+
"""Initialize module weights with a normal distribution.
|
| 270 |
+
|
| 271 |
+
Args:
|
| 272 |
+
module: nn.Module to initialize.
|
| 273 |
+
"""
|
| 274 |
+
if isinstance(module, nn.Linear):
|
| 275 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 276 |
+
if module.bias is not None:
|
| 277 |
+
torch.nn.init.zeros_(module.bias)
|
| 278 |
+
elif isinstance(module, nn.Embedding):
|
| 279 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 280 |
+
|
| 281 |
+
def forward(self, idx, targets=None):
|
| 282 |
+
"""Forward pass of the model.
|
| 283 |
+
|
| 284 |
+
Args:
|
| 285 |
+
idx: Input token ids of shape (B, T).
|
| 286 |
+
targets: Optional target token ids of shape (B, T) for loss computation.
|
| 287 |
+
|
| 288 |
+
Returns:
|
| 289 |
+
Tuple of (logits, loss). Logits have shape (B, T, vocab_size).
|
| 290 |
+
Loss is returned only if targets are provided.
|
| 291 |
+
|
| 292 |
+
Raises:
|
| 293 |
+
AssertionError: If sequence length exceeds context_length.
|
| 294 |
+
"""
|
| 295 |
+
B, T = idx.size()
|
| 296 |
+
assert T <= self.config.context_length, (
|
| 297 |
+
f"Sequence length {T} exceeds context length {self.config.context_length}"
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
# Token + positional embeddings
|
| 301 |
+
tok_emb = self.token_embedding(idx)
|
| 302 |
+
if self.config.use_rope:
|
| 303 |
+
x = self.dropout(tok_emb)
|
| 304 |
+
else:
|
| 305 |
+
pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
|
| 306 |
+
pos_emb = self.pos_embedding(pos)
|
| 307 |
+
x = self.dropout(tok_emb + pos_emb)
|
| 308 |
+
|
| 309 |
+
# Transformer blocks
|
| 310 |
+
for block in self.blocks:
|
| 311 |
+
x = block(x)
|
| 312 |
+
|
| 313 |
+
# Final layer norm
|
| 314 |
+
x = self.ln_f(x)
|
| 315 |
+
|
| 316 |
+
# Output logits
|
| 317 |
+
logits = self.lm_head(x)
|
| 318 |
+
|
| 319 |
+
loss = None
|
| 320 |
+
if targets is not None:
|
| 321 |
+
loss = F.cross_entropy(
|
| 322 |
+
logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
return logits, loss
|
| 326 |
+
|
| 327 |
+
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
|
| 328 |
+
"""Autoregressively generate new tokens.
|
| 329 |
+
|
| 330 |
+
Args:
|
| 331 |
+
idx: Input token ids of shape (B, T) serving as the prompt.
|
| 332 |
+
max_new_tokens: Number of tokens to generate.
|
| 333 |
+
temperature: Sampling temperature (higher = more random).
|
| 334 |
+
top_k: If set, only sample from the top-k most likely tokens.
|
| 335 |
+
|
| 336 |
+
Returns:
|
| 337 |
+
Tensor of shape (B, T + max_new_tokens) with generated tokens.
|
| 338 |
+
"""
|
| 339 |
+
for _ in range(max_new_tokens):
|
| 340 |
+
# Crop context if needed
|
| 341 |
+
idx_cond = idx[:, -self.config.context_length:]
|
| 342 |
+
|
| 343 |
+
# Forward pass
|
| 344 |
+
logits, _ = self(idx_cond)
|
| 345 |
+
|
| 346 |
+
# Take logits at the last position
|
| 347 |
+
logits = logits[:, -1, :] / temperature
|
| 348 |
+
|
| 349 |
+
# Optional top-k filtering
|
| 350 |
+
if top_k is not None:
|
| 351 |
+
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
| 352 |
+
logits[logits < v[:, [-1]]] = float("-inf")
|
| 353 |
+
|
| 354 |
+
# Apply softmax and sample
|
| 355 |
+
probs = F.softmax(logits, dim=-1)
|
| 356 |
+
idx_next = torch.multinomial(probs, num_samples=1)
|
| 357 |
+
|
| 358 |
+
# Append to sequence
|
| 359 |
+
idx = torch.cat((idx, idx_next), dim=1)
|
| 360 |
+
|
| 361 |
+
return idx
|
| 362 |
+
|
| 363 |
+
def num_params(self) -> int:
|
| 364 |
+
"""Return the total number of trainable parameters.
|
| 365 |
+
|
| 366 |
+
Returns:
|
| 367 |
+
Integer count of trainable parameters in the model.
|
| 368 |
+
"""
|
| 369 |
+
return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
| 370 |
+
|
| 371 |
+
def configure_optimizers(self, config: GhostLMConfig):
|
| 372 |
+
"""Create an AdamW optimizer with weight decay separation.
|
| 373 |
+
|
| 374 |
+
Separates parameters into two groups: those that should receive
|
| 375 |
+
weight decay (linear weights) and those that should not
|
| 376 |
+
(biases, LayerNorm weights, embeddings).
|
| 377 |
+
|
| 378 |
+
Args:
|
| 379 |
+
config: GhostLMConfig containing learning_rate, betas, and weight_decay.
|
| 380 |
+
|
| 381 |
+
Returns:
|
| 382 |
+
torch.optim.AdamW optimizer with properly configured parameter groups.
|
| 383 |
+
"""
|
| 384 |
+
decay = set()
|
| 385 |
+
no_decay = set()
|
| 386 |
+
|
| 387 |
+
whitelist = (nn.Linear,)
|
| 388 |
+
blacklist = (nn.LayerNorm, nn.Embedding)
|
| 389 |
+
|
| 390 |
+
for mn, m in self.named_modules():
|
| 391 |
+
for pn, p in m.named_parameters():
|
| 392 |
+
fpn = f"{mn}.{pn}" if mn else pn
|
| 393 |
+
|
| 394 |
+
if pn.endswith("bias"):
|
| 395 |
+
no_decay.add(fpn)
|
| 396 |
+
elif pn.endswith("weight") and isinstance(m, whitelist):
|
| 397 |
+
decay.add(fpn)
|
| 398 |
+
elif pn.endswith("weight") and isinstance(m, blacklist):
|
| 399 |
+
no_decay.add(fpn)
|
| 400 |
+
|
| 401 |
+
# Remove lm_head.weight from decay if present — it is tied to token_embedding.weight
|
| 402 |
+
decay.discard("lm_head.weight")
|
| 403 |
+
no_decay.discard("lm_head.weight")
|
| 404 |
+
|
| 405 |
+
# Validate all parameters are accounted for (excluding tied weight)
|
| 406 |
+
param_dict = {pn: p for pn, p in self.named_parameters()}
|
| 407 |
+
all_params = decay | no_decay
|
| 408 |
+
uncategorized = {k for k in param_dict.keys() if k not in all_params and k != "lm_head.weight"}
|
| 409 |
+
assert len(uncategorized) == 0, f"Parameters {uncategorized} not categorized"
|
| 410 |
+
|
| 411 |
+
optim_groups = [
|
| 412 |
+
{
|
| 413 |
+
"params": [param_dict[pn] for pn in sorted(decay)],
|
| 414 |
+
"weight_decay": config.weight_decay,
|
| 415 |
+
},
|
| 416 |
+
{
|
| 417 |
+
"params": [param_dict[pn] for pn in sorted(no_decay)],
|
| 418 |
+
"weight_decay": 0.0,
|
| 419 |
+
},
|
| 420 |
+
]
|
| 421 |
+
|
| 422 |
+
optimizer = torch.optim.AdamW(
|
| 423 |
+
optim_groups,
|
| 424 |
+
lr=config.learning_rate,
|
| 425 |
+
betas=(config.beta1, config.beta2),
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
+
return optimizer
|
ghostlm/tokenizer.py
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""GhostLM tokenizer — wraps tiktoken's GPT-2 BPE tokenizer with cybersecurity-aware utilities."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import List, Optional, Union
|
| 7 |
+
|
| 8 |
+
import tiktoken
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class GhostTokenizer:
|
| 13 |
+
"""Wrapper around tiktoken GPT-2 BPE tokenizer with GhostLM utilities.
|
| 14 |
+
|
| 15 |
+
Provides encoding, decoding, batching, padding, and text chunking
|
| 16 |
+
utilities tailored for cybersecurity document processing.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
# Special token strings
|
| 20 |
+
BOS = "<|ghost_bos|>"
|
| 21 |
+
EOS = "<|ghost_eos|>"
|
| 22 |
+
PAD = "<|ghost_pad|>"
|
| 23 |
+
UNK = "<|ghost_unk|>"
|
| 24 |
+
# Chat role markers (added in v0.5 chat-tuning) — IDs appended after the
|
| 25 |
+
# original four so pre-chat checkpoints can be expanded by 3 rows rather
|
| 26 |
+
# than reshuffled.
|
| 27 |
+
USER = "<|ghost_user|>"
|
| 28 |
+
ASSISTANT = "<|ghost_assistant|>"
|
| 29 |
+
END = "<|ghost_end|>"
|
| 30 |
+
|
| 31 |
+
def __init__(self):
|
| 32 |
+
"""Initialize the GhostTokenizer with the GPT-2 BPE encoding.
|
| 33 |
+
|
| 34 |
+
Loads the tiktoken gpt2 encoding and assigns special token IDs
|
| 35 |
+
beyond the standard vocabulary for begin-of-sequence, end-of-sequence,
|
| 36 |
+
padding, unknown, and chat role markers.
|
| 37 |
+
"""
|
| 38 |
+
self._encoder = tiktoken.get_encoding("gpt2")
|
| 39 |
+
self._vocab_size = self._encoder.n_vocab
|
| 40 |
+
|
| 41 |
+
# Assign special token IDs beyond the base vocabulary
|
| 42 |
+
self._special_tokens = {
|
| 43 |
+
self.BOS: self._vocab_size,
|
| 44 |
+
self.EOS: self._vocab_size + 1,
|
| 45 |
+
self.PAD: self._vocab_size + 2,
|
| 46 |
+
self.UNK: self._vocab_size + 3,
|
| 47 |
+
self.USER: self._vocab_size + 4,
|
| 48 |
+
self.ASSISTANT: self._vocab_size + 5,
|
| 49 |
+
self.END: self._vocab_size + 6,
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
# Reverse mapping for quick lookup
|
| 53 |
+
self._id_to_special = {v: k for k, v in self._special_tokens.items()}
|
| 54 |
+
|
| 55 |
+
@property
|
| 56 |
+
def vocab_size(self) -> int:
|
| 57 |
+
"""Return the effective vocabulary size including special tokens.
|
| 58 |
+
|
| 59 |
+
Returns:
|
| 60 |
+
Total vocabulary size (base vocab + 7 special tokens).
|
| 61 |
+
"""
|
| 62 |
+
return self._vocab_size + len(self._special_tokens)
|
| 63 |
+
|
| 64 |
+
def _special_token_ids(self) -> set:
|
| 65 |
+
"""Return a set of all special token IDs.
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
Set of integer token IDs reserved for special tokens.
|
| 69 |
+
"""
|
| 70 |
+
return set(self._special_tokens.values())
|
| 71 |
+
|
| 72 |
+
def encode(self, text: str, add_bos: bool = False, add_eos: bool = False) -> List[int]:
|
| 73 |
+
"""Encode a text string into a list of token IDs.
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
text: Input text to encode.
|
| 77 |
+
add_bos: If True, prepend the BOS token ID.
|
| 78 |
+
add_eos: If True, append the EOS token ID.
|
| 79 |
+
|
| 80 |
+
Returns:
|
| 81 |
+
List of integer token IDs.
|
| 82 |
+
"""
|
| 83 |
+
ids = self._encoder.encode(text, allowed_special="all")
|
| 84 |
+
|
| 85 |
+
if add_bos:
|
| 86 |
+
ids = [self._special_tokens[self.BOS]] + ids
|
| 87 |
+
if add_eos:
|
| 88 |
+
ids = ids + [self._special_tokens[self.EOS]]
|
| 89 |
+
|
| 90 |
+
return ids
|
| 91 |
+
|
| 92 |
+
def decode(self, ids: List[int], skip_special: bool = True) -> str:
|
| 93 |
+
"""Decode a list of token IDs back into a text string.
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
ids: List of integer token IDs to decode.
|
| 97 |
+
skip_special: If True, filter out special token IDs before decoding.
|
| 98 |
+
|
| 99 |
+
Returns:
|
| 100 |
+
Decoded text string.
|
| 101 |
+
"""
|
| 102 |
+
if skip_special:
|
| 103 |
+
special_ids = self._special_token_ids()
|
| 104 |
+
ids = [i for i in ids if i not in special_ids]
|
| 105 |
+
|
| 106 |
+
return self._encoder.decode(ids)
|
| 107 |
+
|
| 108 |
+
def encode_chat(self, turns: List[dict]) -> tuple:
|
| 109 |
+
"""Encode a multi-turn chat conversation with role markers and a loss mask.
|
| 110 |
+
|
| 111 |
+
Format: <|ghost_user|>{content}<|ghost_end|><|ghost_assistant|>{content}<|ghost_end|>...
|
| 112 |
+
The loss mask is 1 on assistant content tokens and the assistant's trailing
|
| 113 |
+
<|ghost_end|> (so the model learns to stop), and 0 everywhere else (user
|
| 114 |
+
prompts and role markers themselves).
|
| 115 |
+
|
| 116 |
+
Args:
|
| 117 |
+
turns: List of {"role": "user"|"assistant", "content": str} dicts,
|
| 118 |
+
strictly alternating starting with "user".
|
| 119 |
+
|
| 120 |
+
Returns:
|
| 121 |
+
Tuple (token_ids, loss_mask) — same length, both lists of int.
|
| 122 |
+
"""
|
| 123 |
+
user_id = self._special_tokens[self.USER]
|
| 124 |
+
assistant_id = self._special_tokens[self.ASSISTANT]
|
| 125 |
+
end_id = self._special_tokens[self.END]
|
| 126 |
+
|
| 127 |
+
ids: List[int] = []
|
| 128 |
+
mask: List[int] = []
|
| 129 |
+
|
| 130 |
+
for turn in turns:
|
| 131 |
+
role = turn["role"]
|
| 132 |
+
content_ids = self._encoder.encode(turn["content"], allowed_special="all")
|
| 133 |
+
if role == "user":
|
| 134 |
+
ids.append(user_id)
|
| 135 |
+
mask.append(0)
|
| 136 |
+
ids.extend(content_ids)
|
| 137 |
+
mask.extend([0] * len(content_ids))
|
| 138 |
+
ids.append(end_id)
|
| 139 |
+
mask.append(0)
|
| 140 |
+
elif role == "assistant":
|
| 141 |
+
ids.append(assistant_id)
|
| 142 |
+
mask.append(0)
|
| 143 |
+
ids.extend(content_ids)
|
| 144 |
+
mask.extend([1] * len(content_ids))
|
| 145 |
+
ids.append(end_id)
|
| 146 |
+
mask.append(1)
|
| 147 |
+
else:
|
| 148 |
+
raise ValueError(f"Unknown role: {role!r}")
|
| 149 |
+
|
| 150 |
+
return ids, mask
|
| 151 |
+
|
| 152 |
+
def format_chat_prompt(self, turns: List[dict]) -> List[int]:
|
| 153 |
+
"""Encode a chat history and append <|ghost_assistant|> ready for generation.
|
| 154 |
+
|
| 155 |
+
Used at inference: feed the resulting token ids to the model; it should
|
| 156 |
+
generate the assistant's reply followed by <|ghost_end|>.
|
| 157 |
+
|
| 158 |
+
Args:
|
| 159 |
+
turns: List of {"role": "user"|"assistant", "content": str}, ending
|
| 160 |
+
with a "user" turn (the prompt awaiting a reply).
|
| 161 |
+
|
| 162 |
+
Returns:
|
| 163 |
+
List of token IDs ending in the assistant role marker.
|
| 164 |
+
"""
|
| 165 |
+
ids, _ = self.encode_chat(turns)
|
| 166 |
+
ids.append(self._special_tokens[self.ASSISTANT])
|
| 167 |
+
return ids
|
| 168 |
+
|
| 169 |
+
def encode_batch(self, texts: List[str], add_bos: bool = False, add_eos: bool = False) -> List[List[int]]:
|
| 170 |
+
"""Encode a list of text strings into lists of token IDs.
|
| 171 |
+
|
| 172 |
+
Args:
|
| 173 |
+
texts: List of input text strings to encode.
|
| 174 |
+
add_bos: If True, prepend BOS token ID to each sequence.
|
| 175 |
+
add_eos: If True, append EOS token ID to each sequence.
|
| 176 |
+
|
| 177 |
+
Returns:
|
| 178 |
+
List of lists of integer token IDs, one per input text.
|
| 179 |
+
"""
|
| 180 |
+
return [self.encode(text, add_bos=add_bos, add_eos=add_eos) for text in texts]
|
| 181 |
+
|
| 182 |
+
def to_tensor(self, ids: List[int], device: str = "cpu") -> torch.Tensor:
|
| 183 |
+
"""Convert a list of token IDs to a PyTorch tensor.
|
| 184 |
+
|
| 185 |
+
Args:
|
| 186 |
+
ids: List of integer token IDs.
|
| 187 |
+
device: Target device for the tensor (default: "cpu").
|
| 188 |
+
|
| 189 |
+
Returns:
|
| 190 |
+
torch.LongTensor of shape (1, len(ids)).
|
| 191 |
+
"""
|
| 192 |
+
return torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)
|
| 193 |
+
|
| 194 |
+
def pad_batch(self, batch: List[List[int]], pad_left: bool = False) -> tuple:
|
| 195 |
+
"""Pad a batch of token ID lists to the same length.
|
| 196 |
+
|
| 197 |
+
Pads all sequences in the batch to the length of the longest sequence
|
| 198 |
+
using the PAD token ID. Returns both the padded tensor and an attention
|
| 199 |
+
mask indicating real tokens (1) vs padding (0).
|
| 200 |
+
|
| 201 |
+
Args:
|
| 202 |
+
batch: List of token ID lists, each potentially different length.
|
| 203 |
+
pad_left: If True, pad on the left side (useful for generation).
|
| 204 |
+
If False, pad on the right side (default).
|
| 205 |
+
|
| 206 |
+
Returns:
|
| 207 |
+
Tuple of (padded_tensor, attention_mask) where:
|
| 208 |
+
- padded_tensor: torch.LongTensor of shape (batch_size, max_len)
|
| 209 |
+
- attention_mask: torch.LongTensor of shape (batch_size, max_len)
|
| 210 |
+
"""
|
| 211 |
+
max_len = max(len(seq) for seq in batch)
|
| 212 |
+
pad_id = self._special_tokens[self.PAD]
|
| 213 |
+
|
| 214 |
+
padded = []
|
| 215 |
+
masks = []
|
| 216 |
+
|
| 217 |
+
for seq in batch:
|
| 218 |
+
pad_count = max_len - len(seq)
|
| 219 |
+
if pad_left:
|
| 220 |
+
padded_seq = [pad_id] * pad_count + seq
|
| 221 |
+
mask = [0] * pad_count + [1] * len(seq)
|
| 222 |
+
else:
|
| 223 |
+
padded_seq = seq + [pad_id] * pad_count
|
| 224 |
+
mask = [1] * len(seq) + [0] * pad_count
|
| 225 |
+
|
| 226 |
+
padded.append(padded_seq)
|
| 227 |
+
masks.append(mask)
|
| 228 |
+
|
| 229 |
+
padded_tensor = torch.tensor(padded, dtype=torch.long)
|
| 230 |
+
mask_tensor = torch.tensor(masks, dtype=torch.long)
|
| 231 |
+
|
| 232 |
+
return padded_tensor, mask_tensor
|
| 233 |
+
|
| 234 |
+
def chunk_text(self, text: str, chunk_size: int = 1024, overlap: int = 64) -> List[List[int]]:
|
| 235 |
+
"""Encode text and split into overlapping token chunks.
|
| 236 |
+
|
| 237 |
+
Useful for processing long cybersecurity documents that exceed
|
| 238 |
+
the model's context length. Overlapping chunks preserve context
|
| 239 |
+
continuity across boundaries.
|
| 240 |
+
|
| 241 |
+
Args:
|
| 242 |
+
text: Input text string to chunk.
|
| 243 |
+
chunk_size: Maximum number of tokens per chunk.
|
| 244 |
+
overlap: Number of overlapping tokens between consecutive chunks.
|
| 245 |
+
|
| 246 |
+
Returns:
|
| 247 |
+
List of token ID lists, each of length at most chunk_size.
|
| 248 |
+
"""
|
| 249 |
+
ids = self.encode(text)
|
| 250 |
+
|
| 251 |
+
if len(ids) <= chunk_size:
|
| 252 |
+
return [ids]
|
| 253 |
+
|
| 254 |
+
chunks = []
|
| 255 |
+
stride = chunk_size - overlap
|
| 256 |
+
|
| 257 |
+
for i in range(0, len(ids), stride):
|
| 258 |
+
chunk = ids[i : i + chunk_size]
|
| 259 |
+
chunks.append(chunk)
|
| 260 |
+
if i + chunk_size >= len(ids):
|
| 261 |
+
break
|
| 262 |
+
|
| 263 |
+
return chunks
|
| 264 |
+
|
| 265 |
+
def save(self, path: str) -> None:
|
| 266 |
+
"""Save tokenizer metadata to a JSON file.
|
| 267 |
+
|
| 268 |
+
Stores vocabulary size, special token strings, and their assigned
|
| 269 |
+
IDs so the tokenizer can be reconstructed later.
|
| 270 |
+
|
| 271 |
+
Args:
|
| 272 |
+
path: File path to save the JSON metadata.
|
| 273 |
+
"""
|
| 274 |
+
metadata = {
|
| 275 |
+
"vocab_size": self._vocab_size,
|
| 276 |
+
"special_tokens": self._special_tokens,
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
Path(path).parent.mkdir(parents=True, exist_ok=True)
|
| 280 |
+
with open(path, "w") as f:
|
| 281 |
+
json.dump(metadata, f, indent=2)
|
| 282 |
+
|
| 283 |
+
@classmethod
|
| 284 |
+
def load(cls, path: str) -> "GhostTokenizer":
|
| 285 |
+
"""Load a GhostTokenizer from saved metadata JSON.
|
| 286 |
+
|
| 287 |
+
Reconstructs the tokenizer by reading special token assignments
|
| 288 |
+
from the saved metadata file.
|
| 289 |
+
|
| 290 |
+
Args:
|
| 291 |
+
path: File path to the saved JSON metadata.
|
| 292 |
+
|
| 293 |
+
Returns:
|
| 294 |
+
GhostTokenizer instance loaded with the saved configuration.
|
| 295 |
+
"""
|
| 296 |
+
with open(path, "r") as f:
|
| 297 |
+
metadata = json.load(f)
|
| 298 |
+
|
| 299 |
+
tokenizer = cls()
|
| 300 |
+
|
| 301 |
+
# Restore special token mappings
|
| 302 |
+
tokenizer._special_tokens = {k: int(v) for k, v in metadata["special_tokens"].items()}
|
| 303 |
+
tokenizer._id_to_special = {v: k for k, v in tokenizer._special_tokens.items()}
|
| 304 |
+
|
| 305 |
+
return tokenizer
|
| 306 |
+
|
| 307 |
+
def __len__(self) -> int:
|
| 308 |
+
"""Return the effective vocabulary size.
|
| 309 |
+
|
| 310 |
+
Returns:
|
| 311 |
+
Integer count of tokens including special tokens.
|
| 312 |
+
"""
|
| 313 |
+
return self.vocab_size
|
| 314 |
+
|
| 315 |
+
def __repr__(self) -> str:
|
| 316 |
+
"""Return a concise string representation of the tokenizer.
|
| 317 |
+
|
| 318 |
+
Returns:
|
| 319 |
+
String like: GhostTokenizer(vocab_size=50261, special_tokens=4)
|
| 320 |
+
"""
|
| 321 |
+
return f"GhostTokenizer(vocab_size={self.vocab_size}, special_tokens={len(self._special_tokens)})"
|
ghostlm/trainer.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""GhostLM trainer — handles the full training loop, evaluation, checkpointing, and logging."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
from dataclasses import asdict
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Dict, Optional, Tuple
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
from ghostlm.config import GhostLMConfig
|
| 15 |
+
from ghostlm.model import GhostLM
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class GhostTrainer:
|
| 19 |
+
"""Manages the GhostLM training loop with evaluation, checkpointing, and logging.
|
| 20 |
+
|
| 21 |
+
Handles device placement, optimizer setup, cosine learning rate scheduling
|
| 22 |
+
with warmup, gradient clipping, periodic evaluation, checkpoint saving,
|
| 23 |
+
and JSON-based training log persistence. Supports mixed precision (AMP)
|
| 24 |
+
training on CUDA devices for faster throughput and lower memory usage.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(self, model: GhostLM, config: GhostLMConfig, use_amp: Optional[bool] = None):
|
| 28 |
+
"""Initialize the trainer.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
model: GhostLM model instance to train.
|
| 32 |
+
config: GhostLMConfig with training hyperparameters and paths.
|
| 33 |
+
use_amp: Enable mixed precision (AMP) training. Defaults to True
|
| 34 |
+
when running on CUDA, False otherwise. AMP is only supported
|
| 35 |
+
on CUDA devices — setting True on CPU/MPS will be ignored.
|
| 36 |
+
"""
|
| 37 |
+
self.model = model
|
| 38 |
+
self.config = config
|
| 39 |
+
|
| 40 |
+
# Resolve device
|
| 41 |
+
if config.device == "auto":
|
| 42 |
+
if torch.cuda.is_available():
|
| 43 |
+
self.device = "cuda"
|
| 44 |
+
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
| 45 |
+
self.device = "mps"
|
| 46 |
+
else:
|
| 47 |
+
self.device = "cpu"
|
| 48 |
+
else:
|
| 49 |
+
self.device = config.device
|
| 50 |
+
|
| 51 |
+
self.model = self.model.to(self.device)
|
| 52 |
+
|
| 53 |
+
# Mixed precision (AMP) — only effective on CUDA
|
| 54 |
+
if use_amp is None:
|
| 55 |
+
self.use_amp = self.device == "cuda"
|
| 56 |
+
else:
|
| 57 |
+
self.use_amp = use_amp and self.device == "cuda"
|
| 58 |
+
|
| 59 |
+
self.grad_scaler = torch.amp.GradScaler("cuda", enabled=self.use_amp)
|
| 60 |
+
|
| 61 |
+
# Optimizer
|
| 62 |
+
self.optimizer = self.model.configure_optimizers(config)
|
| 63 |
+
|
| 64 |
+
# Create directories
|
| 65 |
+
self.checkpoint_dir = Path(config.checkpoint_dir)
|
| 66 |
+
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
| 67 |
+
|
| 68 |
+
self.log_dir = Path(config.log_dir)
|
| 69 |
+
self.log_dir.mkdir(parents=True, exist_ok=True)
|
| 70 |
+
|
| 71 |
+
# State
|
| 72 |
+
self.step = 0
|
| 73 |
+
self.accum_steps = getattr(config, 'grad_accum_steps', 4)
|
| 74 |
+
self.best_val_loss = float("inf")
|
| 75 |
+
self.log: list = []
|
| 76 |
+
|
| 77 |
+
def get_lr(self) -> float:
|
| 78 |
+
"""Compute the current learning rate using cosine decay with linear warmup.
|
| 79 |
+
|
| 80 |
+
During the warmup phase (step < warmup_steps), the learning rate scales
|
| 81 |
+
linearly from 0 to config.learning_rate. After warmup, it follows a
|
| 82 |
+
cosine decay schedule down to a minimum of 1e-5.
|
| 83 |
+
|
| 84 |
+
Returns:
|
| 85 |
+
Current learning rate as a float.
|
| 86 |
+
"""
|
| 87 |
+
step = self.step
|
| 88 |
+
warmup = self.config.warmup_steps
|
| 89 |
+
max_steps = self.config.max_steps
|
| 90 |
+
base_lr = self.config.learning_rate
|
| 91 |
+
min_lr = 1e-5
|
| 92 |
+
|
| 93 |
+
if step < warmup:
|
| 94 |
+
return base_lr * (step + 1) / warmup
|
| 95 |
+
|
| 96 |
+
decay_ratio = (step - warmup) / max(1, max_steps - warmup)
|
| 97 |
+
decay_ratio = min(decay_ratio, 1.0)
|
| 98 |
+
|
| 99 |
+
cosine_decay = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
|
| 100 |
+
return min_lr + (base_lr - min_lr) * cosine_decay
|
| 101 |
+
|
| 102 |
+
def _set_lr(self) -> None:
|
| 103 |
+
"""Apply the current learning rate from get_lr() to all optimizer parameter groups."""
|
| 104 |
+
lr = self.get_lr()
|
| 105 |
+
for group in self.optimizer.param_groups:
|
| 106 |
+
group["lr"] = lr
|
| 107 |
+
|
| 108 |
+
def train_step(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> float:
|
| 109 |
+
"""Execute a single training step with gradient accumulation and optional AMP.
|
| 110 |
+
|
| 111 |
+
Accumulates gradients over self.accum_steps micro-steps before
|
| 112 |
+
updating weights, effectively multiplying the batch size without
|
| 113 |
+
increasing memory usage. When AMP is enabled, the forward pass runs
|
| 114 |
+
in float16 and the GradScaler handles loss scaling for stable training.
|
| 115 |
+
|
| 116 |
+
Args:
|
| 117 |
+
batch: Tuple of (input_ids, target_ids) tensors.
|
| 118 |
+
|
| 119 |
+
Returns:
|
| 120 |
+
Training loss as a float.
|
| 121 |
+
"""
|
| 122 |
+
x, y = batch
|
| 123 |
+
x = x.to(self.device)
|
| 124 |
+
y = y.to(self.device)
|
| 125 |
+
|
| 126 |
+
self.model.train()
|
| 127 |
+
|
| 128 |
+
# Split batch into micro-batches for gradient accumulation
|
| 129 |
+
micro_x = x.split(max(1, x.size(0) // self.accum_steps), dim=0)
|
| 130 |
+
micro_y = y.split(max(1, y.size(0) // self.accum_steps), dim=0)
|
| 131 |
+
|
| 132 |
+
total_loss = 0.0
|
| 133 |
+
|
| 134 |
+
for mx, my in zip(micro_x, micro_y):
|
| 135 |
+
with torch.amp.autocast("cuda", enabled=self.use_amp):
|
| 136 |
+
_, loss = self.model(mx, targets=my)
|
| 137 |
+
# Scale loss by number of accumulation steps
|
| 138 |
+
scaled_loss = loss / len(micro_x)
|
| 139 |
+
|
| 140 |
+
self.grad_scaler.scale(scaled_loss).backward()
|
| 141 |
+
total_loss += loss.item()
|
| 142 |
+
|
| 143 |
+
# Gradient clipping and optimizer step after accumulation
|
| 144 |
+
self.grad_scaler.unscale_(self.optimizer)
|
| 145 |
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)
|
| 146 |
+
self.grad_scaler.step(self.optimizer)
|
| 147 |
+
self.grad_scaler.update()
|
| 148 |
+
self.optimizer.zero_grad(set_to_none=True)
|
| 149 |
+
|
| 150 |
+
self.step += 1
|
| 151 |
+
self._set_lr()
|
| 152 |
+
|
| 153 |
+
return total_loss / len(micro_x)
|
| 154 |
+
|
| 155 |
+
def eval_step(self, val_loader, num_batches: int = 20) -> float:
|
| 156 |
+
"""Run evaluation over a number of validation batches.
|
| 157 |
+
|
| 158 |
+
Args:
|
| 159 |
+
val_loader: DataLoader yielding (input_ids, target_ids) batches.
|
| 160 |
+
num_batches: Maximum number of batches to evaluate over.
|
| 161 |
+
|
| 162 |
+
Returns:
|
| 163 |
+
Average validation loss as a float.
|
| 164 |
+
"""
|
| 165 |
+
self.model.eval()
|
| 166 |
+
total_loss = 0.0
|
| 167 |
+
count = 0
|
| 168 |
+
|
| 169 |
+
with torch.no_grad():
|
| 170 |
+
for i, batch in enumerate(val_loader):
|
| 171 |
+
if i >= num_batches:
|
| 172 |
+
break
|
| 173 |
+
x, y = batch
|
| 174 |
+
x = x.to(self.device)
|
| 175 |
+
y = y.to(self.device)
|
| 176 |
+
|
| 177 |
+
with torch.amp.autocast("cuda", enabled=self.use_amp):
|
| 178 |
+
_, loss = self.model(x, targets=y)
|
| 179 |
+
total_loss += loss.item()
|
| 180 |
+
count += 1
|
| 181 |
+
|
| 182 |
+
return total_loss / max(count, 1)
|
| 183 |
+
|
| 184 |
+
def save_checkpoint(self, val_loss: float) -> None:
|
| 185 |
+
"""Save a model checkpoint to disk.
|
| 186 |
+
|
| 187 |
+
Saves the current step, validation loss, model state dict, optimizer
|
| 188 |
+
state dict, and config. Also saves as "best_model.pt" if the current
|
| 189 |
+
validation loss is the best seen so far.
|
| 190 |
+
|
| 191 |
+
Args:
|
| 192 |
+
val_loss: Current validation loss for comparison.
|
| 193 |
+
"""
|
| 194 |
+
checkpoint = {
|
| 195 |
+
"step": self.step,
|
| 196 |
+
"val_loss": val_loss,
|
| 197 |
+
"model_state_dict": self.model.state_dict(),
|
| 198 |
+
"optimizer_state_dict": self.optimizer.state_dict(),
|
| 199 |
+
"grad_scaler_state_dict": self.grad_scaler.state_dict(),
|
| 200 |
+
"config": asdict(self.config),
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
filename = f"checkpoint_step_{self.step}.pt"
|
| 204 |
+
path = self.checkpoint_dir / filename
|
| 205 |
+
torch.save(checkpoint, path)
|
| 206 |
+
print(f" Saved checkpoint: {path}")
|
| 207 |
+
|
| 208 |
+
if val_loss < self.best_val_loss:
|
| 209 |
+
self.best_val_loss = val_loss
|
| 210 |
+
best_path = self.checkpoint_dir / "best_model.pt"
|
| 211 |
+
torch.save(checkpoint, best_path)
|
| 212 |
+
print(f" New best model saved: {best_path} (val_loss={val_loss:.4f})")
|
| 213 |
+
|
| 214 |
+
def load_checkpoint(self, path: str) -> None:
|
| 215 |
+
"""Load a model checkpoint from disk.
|
| 216 |
+
|
| 217 |
+
Restores the model state dict, optimizer state dict, training step,
|
| 218 |
+
and best validation loss from the saved checkpoint file.
|
| 219 |
+
|
| 220 |
+
Args:
|
| 221 |
+
path: File path to the checkpoint .pt file.
|
| 222 |
+
"""
|
| 223 |
+
checkpoint = torch.load(path, map_location=self.device, weights_only=False)
|
| 224 |
+
|
| 225 |
+
self.model.load_state_dict(checkpoint["model_state_dict"])
|
| 226 |
+
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
| 227 |
+
if "grad_scaler_state_dict" in checkpoint:
|
| 228 |
+
self.grad_scaler.load_state_dict(checkpoint["grad_scaler_state_dict"])
|
| 229 |
+
self.step = checkpoint["step"]
|
| 230 |
+
self.best_val_loss = checkpoint["val_loss"]
|
| 231 |
+
|
| 232 |
+
print(f"Loaded checkpoint from step {self.step} (val_loss={self.best_val_loss:.4f})")
|
| 233 |
+
|
| 234 |
+
def _log(self, data: dict) -> None:
|
| 235 |
+
"""Append a data dict to the training log and persist as JSON.
|
| 236 |
+
|
| 237 |
+
Args:
|
| 238 |
+
data: Dictionary of metrics and metadata to log.
|
| 239 |
+
"""
|
| 240 |
+
self.log.append(data)
|
| 241 |
+
log_path = self.log_dir / "training_log.json"
|
| 242 |
+
with open(log_path, "w") as f:
|
| 243 |
+
json.dump(self.log, f, indent=2)
|
| 244 |
+
|
| 245 |
+
def train(self, train_loader, val_loader) -> None:
|
| 246 |
+
"""Run the main training loop.
|
| 247 |
+
|
| 248 |
+
Iterates from the current step to config.max_steps, performing training
|
| 249 |
+
steps with a tqdm progress bar. Evaluates periodically at config.eval_interval
|
| 250 |
+
and saves checkpoints at config.save_interval. Performs a final evaluation
|
| 251 |
+
and saves the final checkpoint at the end of training.
|
| 252 |
+
|
| 253 |
+
Args:
|
| 254 |
+
train_loader: DataLoader yielding (input_ids, target_ids) training batches.
|
| 255 |
+
val_loader: DataLoader yielding (input_ids, target_ids) validation batches.
|
| 256 |
+
"""
|
| 257 |
+
print(f"Training on device: {self.device}")
|
| 258 |
+
print(f"Mixed precision (AMP): {'enabled' if self.use_amp else 'disabled'}")
|
| 259 |
+
print(f"Model size: {self.model.num_params():,} parameters")
|
| 260 |
+
print(f"Training from step {self.step} to {self.config.max_steps}")
|
| 261 |
+
|
| 262 |
+
# Create iterator that cycles through train_loader
|
| 263 |
+
def cycle(loader):
|
| 264 |
+
while True:
|
| 265 |
+
for batch in loader:
|
| 266 |
+
yield batch
|
| 267 |
+
|
| 268 |
+
train_iter = cycle(train_loader)
|
| 269 |
+
|
| 270 |
+
with tqdm(initial=self.step, total=self.config.max_steps, desc="Training") as pbar:
|
| 271 |
+
while self.step < self.config.max_steps:
|
| 272 |
+
t0 = time.time()
|
| 273 |
+
|
| 274 |
+
# Training step
|
| 275 |
+
batch = next(train_iter)
|
| 276 |
+
loss = self.train_step(batch)
|
| 277 |
+
|
| 278 |
+
dt = time.time() - t0
|
| 279 |
+
lr = self.get_lr()
|
| 280 |
+
|
| 281 |
+
pbar.set_postfix(loss=f"{loss:.4f}", lr=f"{lr:.2e}", dt=f"{dt:.3f}s")
|
| 282 |
+
pbar.update(1)
|
| 283 |
+
|
| 284 |
+
# Periodic evaluation
|
| 285 |
+
if self.step % self.config.eval_interval == 0:
|
| 286 |
+
val_loss = self.eval_step(val_loader)
|
| 287 |
+
print(f"\n Step {self.step} | val_loss={val_loss:.4f} | train_loss={loss:.4f}")
|
| 288 |
+
|
| 289 |
+
self._log({
|
| 290 |
+
"step": self.step,
|
| 291 |
+
"train_loss": loss,
|
| 292 |
+
"val_loss": val_loss,
|
| 293 |
+
"lr": lr,
|
| 294 |
+
"time": dt,
|
| 295 |
+
})
|
| 296 |
+
|
| 297 |
+
# Periodic checkpoint
|
| 298 |
+
if self.step % self.config.save_interval == 0:
|
| 299 |
+
val_loss = self.eval_step(val_loader)
|
| 300 |
+
self.save_checkpoint(val_loss)
|
| 301 |
+
|
| 302 |
+
# Final evaluation and checkpoint
|
| 303 |
+
print("\nTraining complete. Running final evaluation...")
|
| 304 |
+
val_loss = self.eval_step(val_loader)
|
| 305 |
+
print(f"Final val_loss: {val_loss:.4f}")
|
| 306 |
+
self.save_checkpoint(val_loss)
|
| 307 |
+
|
| 308 |
+
self._log({
|
| 309 |
+
"step": self.step,
|
| 310 |
+
"train_loss": loss,
|
| 311 |
+
"val_loss": val_loss,
|
| 312 |
+
"lr": lr,
|
| 313 |
+
"time": dt,
|
| 314 |
+
"status": "complete",
|
| 315 |
+
})
|
| 316 |
+
|
| 317 |
+
print(f"Training log saved to {self.log_dir / 'training_log.json'}")
|
requirements.txt
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hugging Face Spaces installs from this file at build time.
|
| 2 |
+
# Pinned conservatively so a Space build doesn't regress on a future
|
| 3 |
+
# breaking change in any of the deps.
|
| 4 |
+
|
| 5 |
+
# Note: gradio is intentionally NOT listed here. HF Spaces auto-installs
|
| 6 |
+
# `gradio[oauth,mcp]==<sdk_version>` on top of this file based on the SDK
|
| 7 |
+
# selection in README.md frontmatter. Listing it here causes a pip
|
| 8 |
+
# version-conflict at build time when our pin disagrees with HF's.
|
| 9 |
+
|
| 10 |
+
# torch >= 2.0 for the scaled_dot_product_attention path. CPU-only is
|
| 11 |
+
# fine on free Spaces.
|
| 12 |
+
torch>=2.0.0
|
| 13 |
+
|
| 14 |
+
# tiktoken is the GPT-2 BPE backend the GhostTokenizer wraps.
|
| 15 |
+
tiktoken>=0.5.0
|
| 16 |
+
|
| 17 |
+
# Python 3.13 removed the stdlib audioop module that gradio's transitive
|
| 18 |
+
# pydub dep imports at module-load time. Without this the entire gradio
|
| 19 |
+
# import chain fails with ModuleNotFoundError: No module named
|
| 20 |
+
# 'pyaudioop'. The PEP 594 replacement is audioop-lts. Conditional so
|
| 21 |
+
# 3.12 and earlier (where stdlib audioop still exists) skip it.
|
| 22 |
+
audioop-lts; python_version >= '3.13'
|