Ghostgim ghost committed on
Commit
4fcda01
·
0 Parent(s):

chore: squash history to reclaim LFS storage from removed checkpoint

.gitattributes ADDED
@@ -0,0 +1,8 @@
+ # Hugging Face Spaces LFS rules. Copy this to the Space repo root
+ # alongside app.py / requirements.txt / README.md / ghostlm/ /
+ # checkpoints/. Without it the ~177 MB checkpoints either fail to push
+ # or land as broken pointer files.
+
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
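As a rough illustration of which files these three rules select, the sketch below checks file names against the same glob patterns with Python's `fnmatch`. It is only an approximation: Git's own attribute matching has more rules (directory patterns, `**`, negation), and the helper name `is_lfs_tracked` is made up for this example.

```python
from fnmatch import fnmatch

# Patterns copied from the .gitattributes rules above.
LFS_PATTERNS = ["*.pt", "*.bin", "*.safetensors"]


def is_lfs_tracked(path: str) -> bool:
    """Approximate check: does the file's basename match an LFS pattern?"""
    basename = path.rsplit("/", 1)[-1]
    return any(fnmatch(basename, pattern) for pattern in LFS_PATTERNS)


if __name__ == "__main__":
    for p in ["checkpoints/phase5_chat_v3/best_model.pt", "app.py", "README.md"]:
        print(p, "->", "LFS" if is_lfs_tracked(p) else "regular git")
```

On a real checkout, `git check-attr filter -- <path>` is the authoritative way to ask Git the same question.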
README.md ADDED
@@ -0,0 +1,157 @@
+ ---
+ title: GhostLM
+ emoji: 🔐
+ colorFrom: purple
+ colorTo: gray
+ sdk: gradio
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ short_description: From-scratch cybersecurity LM — interactive demo
+ ---
+
+ # GhostLM Demo
+
+ Interactive Gradio UI for the canonical Phase 3.5 ghost-tiny model. Two
+ tabs: a single-checkpoint **Generate** view with curated prompt presets
+ and a generation history, and an optional **Compare** tab that runs the
+ same prompt through two checkpoints side-by-side (the canonical v0.3.5
+ vs. the v0.3.7 attempt that regressed).
+
+ This file is dual-purpose:
+
+ - **In the GitHub repo** (`demo/README.md`) — documents the demo and
+   the deploy steps.
+ - **As an HF Space README** — the YAML frontmatter at the top is parsed
+   by Hugging Face Spaces as the Space metadata. Keep it intact when
+   copying this file to a Space repo.
+
+ ## Run locally
+
+ From the repo root:
+
+ ```bash
+ pip install -r demo/requirements.txt
+ PYTHONPATH=. python3 demo/app.py
+ ```
+
+ Open `http://localhost:7860`. The demo defaults to
+ `checkpoints/phase3.5_balanced/best_model.pt` — pass `--checkpoint` to
+ load a different one:
+
+ ```bash
+ PYTHONPATH=. python3 demo/app.py --checkpoint checkpoints/phase3.6_exploitdb/best_model.pt
+ ```
+
+ To enable the Compare tab, add a second checkpoint:
+
+ ```bash
+ PYTHONPATH=. python3 demo/app.py \
+   --checkpoint checkpoints/phase3.5_balanced/best_model.pt \
+   --compare-checkpoint checkpoints/phase3.6_exploitdb/best_model.pt
+ ```
+
+ Gradio's `share` option is available through the `--share` flag:
+
+ ```bash
+ PYTHONPATH=. python3 demo/app.py --share
+ ```
+
+ ## Deploy to Hugging Face Spaces
+
+ A Space is a separate git repo on huggingface.co. The demo here lives
+ under `demo/` in the GhostLM repo so the source stays in one place; to
+ deploy you copy the demo files plus the `ghostlm/` package and a
+ checkpoint into a fresh Space repo.
+
+ ### 1. Create the Space
+
+ Either via the Hugging Face web UI (New → Space, SDK = Gradio) or via
+ CLI:
+
+ ```bash
+ pip install huggingface_hub
+ huggingface-cli login
+ huggingface-cli repo create ghostlm --type space --space-sdk gradio
+ ```
+
+ Replace `ghostlm` with your preferred Space name.
+
+ ### 2. Clone the Space repo and stage files
+
+ ```bash
+ git clone https://huggingface.co/spaces/<your-user>/ghostlm hf-space
+ cd hf-space
+
+ # Track the checkpoint via LFS (it's ~177 MB)
+ git lfs install
+ git lfs track "*.pt"
+
+ # Copy the demo + the ghostlm package + the canonical checkpoint
+ cp ../demo/app.py .
+ cp ../demo/requirements.txt .
+ cp ../demo/README.md .
+ cp -r ../ghostlm .
+ mkdir -p checkpoints/phase3.5_balanced
+ cp ../checkpoints/phase3.5_balanced/best_model.pt checkpoints/phase3.5_balanced/
+
+ git add .
+ git commit -m "Initial GhostLM Space deploy"
+ git push
+ ```
+
+ The Space will start building automatically; first build takes ~3–5
+ minutes (gradio + torch wheel install + checkpoint LFS pull). The
+ README's frontmatter tells HF this is a Gradio Space, sets the colors,
+ and pins `app_file: app.py`.
+
+ ### 3. Optional — include the Phase 3.6 checkpoint for the Compare tab
+
+ If you want the Compare tab live in the Space, also copy the Phase 3.6
+ checkpoint (~177 MB more) and set the env var in the Space's Settings
+ page:
+
+ ```bash
+ mkdir -p checkpoints/phase3.6_exploitdb
+ cp ../checkpoints/phase3.6_exploitdb/best_model.pt checkpoints/phase3.6_exploitdb/
+ git add checkpoints/phase3.6_exploitdb
+ git commit -m "Add Phase 3.6 for compare tab"
+ git push
+ ```
+
+ In the Space's **Settings → Variables**, add:
+
+ ```
+ GHOSTLM_COMPARE_CHECKPOINT = checkpoints/phase3.6_exploitdb/best_model.pt
+ ```
+
+ The Space restarts automatically. The Compare tab will now be visible.
+
+ ### 4. Updates
+
+ Push to the Space repo whenever the demo changes; the Space rebuilds.
+ For a checkpoint update push the new `.pt` file (LFS handles it).
+
+ ## What it looks like
+
+ The **Generate** tab gives you a prompt textbox, three sampling sliders
+ (max tokens, temperature, top-k), and a continuation panel. Below that,
+ collapsible accordions group the preset prompts by register (CVE / MITRE
+ / CTF / CAPEC / free-form) so visitors can immediately see what kind of
+ prose the model knows. A history panel keeps the last five generations
+ visible.
+
+ The **Compare** tab — only shown when a second checkpoint is loaded —
+ sends the same prompt + sampling settings to both models in turn so the
+ Phase 3.5 → 3.6 trajectory is visible in real text rather than just
+ accuracy numbers.
+
+ ## Why this exists
+
+ The point of the demo isn't to impress visitors with fluency — at 14.7M
+ parameters trained on 8.8M tokens, the model produces register-shaped
+ fiction, not knowledge. The point is to make the project's
+ trajectory-over-absolute-quality framing concrete: visitors can poke at
+ the canonical model, see exactly what it knows and doesn't, and if both
+ checkpoints are loaded, see the empirical capacity-ceiling finding for
+ themselves.
app.py ADDED
@@ -0,0 +1,265 @@
+ """GhostLM Gradio Space — chat UI for the v0.5.0 chat-v3 (CTIBench 36.9%) model.
+
+ Multi-turn chat using the model's three role tokens
+ (<|ghost_user|>, <|ghost_assistant|>, <|ghost_end|>). Generation stops the
+ moment the assistant's <|ghost_end|> is sampled. Repetition penalty is on
+ by default — without it the 45M model occasionally degenerates into
+ "Wifi Wifi Wifi…" loops on small prompts.
+
+ Runs on Spaces cpu-basic (2 vCPU). Generation is ~5-15 s per reply at
+ the default 200-token cap.
+ """
+
+ from __future__ import annotations
+
+ import os
+ import sys
+ from dataclasses import fields
+ from pathlib import Path
+ from typing import List
+
+ import gradio as gr
+ import torch
+ import torch.nn.functional as F
+
+ REPO_ROOT = Path(__file__).resolve().parent
+ if str(REPO_ROOT) not in sys.path:
+     sys.path.insert(0, str(REPO_ROOT))
+
+ from ghostlm.config import GhostLMConfig
+ from ghostlm.model import GhostLM
+ from ghostlm.tokenizer import GhostTokenizer
+
+
+ # ---------------------------------------------------------------------------
+ # Loading
+ # ---------------------------------------------------------------------------
+
+ CHECKPOINT_CANDIDATES = [
+     "checkpoints/phase5_chat_v3/best_model.pt",
+     "checkpoints/best_model.pt",  # fallback if pushed at the root
+ ]
+
+
+ def find_checkpoint() -> str:
+     """Return the first checkpoint path that exists, or empty string."""
+     for path in CHECKPOINT_CANDIDATES:
+         if Path(path).exists():
+             return path
+     return ""
+
+
+ def load_model(path: str):
+     """Load a GhostLM checkpoint into eval mode on CPU."""
+     if not path:
+         # Random-init fallback so the UI still launches if weights are missing.
+         config = GhostLMConfig.from_preset("ghost-tiny")
+         config.vocab_size = 50264
+         config.context_length = 256
+         model = GhostLM(config).eval()
+         return model, config, "(random ghost-tiny — weights missing on Space)"
+
+     ckpt = torch.load(path, map_location="cpu", weights_only=False)
+     saved = ckpt["config"]
+     config = GhostLMConfig(**{
+         f.name: saved[f.name]
+         for f in fields(GhostLMConfig)
+         if f.name in saved
+     })
+     model = GhostLM(config)
+     state = ckpt.get("model_state_dict", ckpt.get("model"))
+     model.load_state_dict(state, strict=False)
+     model.eval()
+     return model, config, path
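`load_model` rebuilds the config by keeping only the saved keys that are still fields of the dataclass, so a checkpoint with extra or missing keys does not crash the constructor. A minimal sketch of that pattern follows; `TinyConfig` and `config_from_saved` are hypothetical stand-ins, not part of the GhostLM package.

```python
from dataclasses import dataclass, fields


@dataclass
class TinyConfig:
    """Hypothetical stand-in for GhostLMConfig with only two fields."""
    d_model: int = 256
    n_layers: int = 2


def config_from_saved(saved: dict) -> TinyConfig:
    """Build a config from a saved dict, ignoring keys the dataclass lacks."""
    known = {f.name for f in fields(TinyConfig)}
    return TinyConfig(**{k: v for k, v in saved.items() if k in known})


if __name__ == "__main__":
    # "legacy_option" is dropped; a missing "n_layers" would fall back to its default.
    saved = {"d_model": 512, "n_layers": 6, "legacy_option": True}
    print(config_from_saved(saved))
```

The trade-off is that unknown keys are dropped silently, much as `strict=False` hides mismatched weights, so logging the ignored keys may be worthwhile.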
+
+
+ # ---------------------------------------------------------------------------
+ # Generation — inlined from scripts/chat.py so the Space stays self-contained.
+ # ---------------------------------------------------------------------------
+
+
+ def sample_next(
+     logits: torch.Tensor,
+     *,
+     temperature: float,
+     top_k: int,
+     top_p: float,
+     prev_ids: List[int],
+     repetition_penalty: float,
+ ) -> int:
+     """Sample one token from logits with temperature, top-k / top-p, and rep-penalty."""
+     if prev_ids and repetition_penalty != 1.0:
+         for tok in set(prev_ids):
+             if logits[tok] > 0:
+                 logits[tok] = logits[tok] / repetition_penalty
+             else:
+                 logits[tok] = logits[tok] * repetition_penalty
+     logits = logits / max(temperature, 1e-6)
+     if top_k and top_k > 0:
+         v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+         logits[logits < v[..., -1:]] = float("-inf")
+     if top_p and top_p < 1.0:
+         sorted_logits, sorted_idx = torch.sort(logits, descending=True)
+         probs = F.softmax(sorted_logits, dim=-1)
+         cum = probs.cumsum(dim=-1)
+         cutoff = cum > top_p
+         cutoff[..., 0] = False
+         sorted_logits[cutoff] = float("-inf")
+         logits = torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)
+     probs = F.softmax(logits, dim=-1)
+     return int(torch.multinomial(probs, num_samples=1).item())
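To make the two less obvious steps of `sample_next` concrete, here is a torch-free sketch of the repetition penalty and the nucleus (top-p) cutoff on plain lists. The helper names are invented for illustration. The cutoff mirrors the rule in the diff: a token is dropped once the cumulative probability including it exceeds `top_p`, except that the top token is always kept.

```python
import math
from typing import List


def apply_repetition_penalty(logits: List[float], prev_ids: List[int], penalty: float) -> List[float]:
    """Push already-seen tokens toward lower probability, whatever their sign."""
    out = list(logits)
    for tok in set(prev_ids):
        # Positive logits shrink, negative logits grow more negative.
        out[tok] = out[tok] / penalty if out[tok] > 0 else out[tok] * penalty
    return out


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_p_keep(logits: List[float], top_p: float) -> List[int]:
    """Return the token ids that survive nucleus filtering, most likely first."""
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    probs = softmax([logits[i] for i in order])
    kept, cum = [], 0.0
    for rank, tok in enumerate(order):
        cum += probs[rank]
        if rank > 0 and cum > top_p:
            break  # cumulative mass (including this token) passed top_p
        kept.append(tok)
    return kept


if __name__ == "__main__":
    print(apply_repetition_penalty([2.0, -1.0, 0.5], prev_ids=[0, 1], penalty=2.0))
    print(top_p_keep([2.0, 1.0, 0.0, -1.0], top_p=0.9))
```

Dividing positive logits and multiplying negative ones keeps the penalty pointing the same way in both cases; a plain division would raise the probability of a token whose logit is negative.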
+
+
+ def generate_until_end(
+     model,
+     prompt_ids: List[int],
+     *,
+     end_id: int,
+     max_new_tokens: int,
+     temperature: float,
+     top_k: int,
+     top_p: float,
+     repetition_penalty: float,
+ ) -> List[int]:
+     """Greedy-or-sampled generation that stops the moment ``end_id`` is sampled."""
+     ids = torch.tensor(prompt_ids, dtype=torch.long).unsqueeze(0)
+     new_ids: List[int] = []
+     ctx = model.config.context_length
+     with torch.no_grad():
+         for _ in range(max_new_tokens):
+             cond = ids[:, -ctx:]
+             logits, _ = model(cond)
+             next_logits = logits[:, -1, :].squeeze(0).clone()
+             tok = sample_next(
+                 next_logits,
+                 temperature=temperature, top_k=top_k, top_p=top_p,
+                 prev_ids=new_ids[-128:], repetition_penalty=repetition_penalty,
+             )
+             if tok == end_id:
+                 break
+             new_ids.append(tok)
+             ids = torch.cat([ids, torch.tensor([[tok]])], dim=1)
+     return new_ids
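The control flow of the loop (sliding context window, early stop on the end token, hard cap on new tokens) can be exercised without a model. The sketch below swaps the network and sampler for a caller-supplied `next_token` function; that callable and the `_sketch` name are assumptions for illustration only.

```python
from typing import Callable, List


def generate_until_end_sketch(
    next_token: Callable[[List[int]], int],
    prompt_ids: List[int],
    *,
    end_id: int,
    max_new_tokens: int,
    context_length: int,
) -> List[int]:
    """Collect new tokens until end_id appears or the cap is reached."""
    ids = list(prompt_ids)
    new_ids: List[int] = []
    for _ in range(max_new_tokens):
        window = ids[-context_length:]  # only the most recent tokens are visible
        tok = next_token(window)
        if tok == end_id:
            break  # the end token itself is not returned
        new_ids.append(tok)
        ids.append(tok)
    return new_ids


if __name__ == "__main__":
    scripted = iter([5, 6, 99, 7])
    out = generate_until_end_sketch(
        lambda window: next(scripted), [1, 2, 3],
        end_id=99, max_new_tokens=10, context_length=4,
    )
    print(out)
```

Because the end token is checked before appending, the decoded reply never contains the `<|ghost_end|>` marker.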
+
+
+ # ---------------------------------------------------------------------------
+ # Module-level state
+ # ---------------------------------------------------------------------------
+
+ CHECKPOINT_PATH = find_checkpoint()
+ MODEL, CONFIG, LOADED_FROM = load_model(CHECKPOINT_PATH)
+ TOKENIZER = GhostTokenizer()
+ END_ID = TOKENIZER._special_tokens[TOKENIZER.END]
+
+
+ # ---------------------------------------------------------------------------
+ # Chat handler
+ # ---------------------------------------------------------------------------
+
+
+ def chat_fn(message: str, history: list, temperature: float, top_k: int,
+             top_p: float, max_tokens: int, repetition_penalty: float) -> str:
+     """Generate one assistant turn given the prior history + new user message.
+
+     ``history`` may arrive in either Gradio-tuples format
+     ``[(user, bot), ...]`` (older) or messages format
+     ``[{"role", "content"}, ...]`` (newer). We coerce to messages.
+     """
+     turns: list = []
+     for h in history:
+         if isinstance(h, dict) and h.get("role") in ("user", "assistant"):
+             turns.append({"role": h["role"], "content": h["content"]})
+         elif isinstance(h, (list, tuple)) and len(h) == 2:
+             user_msg, bot_msg = h
+             if user_msg:
+                 turns.append({"role": "user", "content": user_msg})
+             if bot_msg:
+                 turns.append({"role": "assistant", "content": bot_msg})
+     turns.append({"role": "user", "content": message})
+
+     prompt_ids = TOKENIZER.format_chat_prompt(turns)
+     # Trim conversation if the prompt overflows the context budget.
+     ctx_budget = CONFIG.context_length - max_tokens - 8
+     while len(prompt_ids) > ctx_budget and len(turns) > 1:
+         # Drop the oldest user/assistant pair, but keep the just-asked turn.
+         if len(turns) >= 3:
+             del turns[:2]
+             prompt_ids = TOKENIZER.format_chat_prompt(turns)
+         else:
+             break
+
+     new_ids = generate_until_end(
+         MODEL, prompt_ids,
+         end_id=END_ID,
+         max_new_tokens=int(max_tokens),
+         temperature=float(temperature),
+         top_k=int(top_k),
+         top_p=float(top_p),
+         repetition_penalty=float(repetition_penalty),
+     )
+     return TOKENIZER.decode(new_ids).strip() or "(no response)"
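The history-trimming rule in `chat_fn` drops the oldest user/assistant pair until the prompt fits, and never removes the message that was just asked. A minimal sketch of that rule, with a word count standing in for the real tokenizer (`count_tokens` and `trim_history` are hypothetical names):

```python
from typing import Dict, List

Turn = Dict[str, str]


def count_tokens(turns: List[Turn]) -> int:
    """Hypothetical stand-in for len(TOKENIZER.format_chat_prompt(turns)): one token per word."""
    return sum(len(t["content"].split()) for t in turns)


def trim_history(turns: List[Turn], budget: int) -> List[Turn]:
    """Drop the oldest two turns at a time while the prompt is over budget."""
    turns = list(turns)
    while count_tokens(turns) > budget and len(turns) >= 3:
        del turns[:2]
    return turns


if __name__ == "__main__":
    history = [
        {"role": "user", "content": "a b c"},
        {"role": "assistant", "content": "d e f"},
        {"role": "user", "content": "g h"},
        {"role": "assistant", "content": "i j"},
        {"role": "user", "content": "k l"},
    ]
    print([t["content"] for t in trim_history(history, budget=7)])
```

If the final message alone exceeds the budget, nothing more can be trimmed; in the diff, the `ids[:, -ctx:]` slice in the generation loop still caps what the model sees.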
+
+
+ # ---------------------------------------------------------------------------
+ # UI
+ # ---------------------------------------------------------------------------
+
+ DESCRIPTION = f"""
+ # GhostLM — chat-v3 (v0.5.0)
+
+ A 45M-parameter cybersecurity language model **trained from scratch** on
+ 12.56M tokens of NVD / MITRE ATT&CK / Exploit-DB / CTFtime / arXiv cs.CR
+ text. The chat-tuned checkpoint here scored **36.9% on
+ [CTIBench MCQ](https://huggingface.co/datasets/AI4Sec/cti-bench)** — 1.48× random for a
+ 2,500-question security multiple-choice benchmark.
+
+ **Honest expectations:** identity, OOD-refusal, and chat shape work. Specific
+ facts (CVE numbers, CVSS scores, dates, technique IDs) are unreliable —
+ the model often confabulates plausible-looking but wrong specifics. Always
+ verify against authoritative sources. Outside cybersecurity, the model
+ politely declines and returns to its domain.
+
+ **Loaded checkpoint:** `{LOADED_FROM}`
+ """
+
+ EXAMPLES = [
+     "What is XSS?",
+     "Explain MITRE ATT&CK technique T1059.",
+     "What does SSRF stand for?",
+     "How does a buffer overflow work?",
+     "Walk me through a typical SQL injection attack.",
+     "What's the difference between CVE and CWE?",
+     "Where do I start learning cybersecurity?",
+     "Are you ChatGPT?",
+ ]
+
+
+ with gr.Blocks(title="GhostLM Chat") as demo:
+     gr.Markdown(DESCRIPTION)
+     with gr.Row():
+         with gr.Column(scale=3):
+             chat = gr.ChatInterface(
+                 fn=chat_fn,
+                 # Each example needs values for every additional_input when
+                 # they're configured below — list-of-lists [message, temp,
+                 # top_k, top_p, max_tokens, rep_penalty]. The defaults below
+                 # match the sliders so a user can click an example and get
+                 # consistent generation settings.
+                 examples=[[ex, 0.7, 40, 0.95, 200, 1.25] for ex in EXAMPLES],
+                 additional_inputs=[
+                     gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature"),
+                     gr.Slider(0, 100, value=40, step=1, label="Top-k"),
+                     gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"),
+                     gr.Slider(32, 400, value=200, step=8, label="Max tokens"),
+                     gr.Slider(1.0, 2.0, value=1.25, step=0.05, label="Repetition penalty"),
+                 ],
+             )
+     gr.Markdown(
+         "Source: [github.com/joemunene-by/GhostLM](https://github.com/joemunene-by/GhostLM)"
+         " · Weights: [Ghostgim/GhostLM](https://huggingface.co/Ghostgim/GhostLM)"
+         " · The model is small enough to run locally — see the GitHub README for instructions."
+     )
+
+
+ if __name__ == "__main__":
+     demo.queue().launch()
checkpoints/phase5_chat_v3/best_model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a1c2dbbb3f2559153953cdec8c0e8adbcdf0659fe4b61c3eb05a4e21c6b216f0
+ size 542187521
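What Git stores for the checkpoint is this three-line pointer, not the weights themselves. A small sketch of reading such a pointer is shown below; `parse_lfs_pointer` is a made-up helper, and the pointer text is copied from the diff above.

```python
def parse_lfs_pointer(text: str) -> dict:
    """Split a Git LFS pointer file into its version, hash algorithm, digest, and size."""
    entries = dict(line.split(" ", 1) for line in text.strip().splitlines())
    algo, digest = entries["oid"].split(":", 1)
    return {
        "version": entries["version"],
        "algo": algo,
        "digest": digest,
        "size": int(entries["size"]),  # size of the real file in bytes
    }


POINTER = """version https://git-lfs.github.com/spec/v1
oid sha256:a1c2dbbb3f2559153953cdec8c0e8adbcdf0659fe4b61c3eb05a4e21c6b216f0
size 542187521
"""

if __name__ == "__main__":
    info = parse_lfs_pointer(POINTER)
    print(info["algo"], info["digest"][:12], f"{info['size'] / 1e6:.0f} MB")
```

A checkpoint that loads as a tiny text file starting with `version https://git-lfs...` is a sign the LFS object was never pulled, which is the "broken pointer file" case the `.gitattributes` comment warns about.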
ghostlm/__init__.py ADDED
@@ -0,0 +1,19 @@
+ """GhostLM — open-source cybersecurity-focused language model."""
+
+ from ghostlm.config import GhostLMConfig
+ from ghostlm.model import GhostLM
+ from ghostlm.tokenizer import GhostTokenizer
+ from ghostlm.dataset import GhostDataset, build_dataloaders
+ from ghostlm.trainer import GhostTrainer
+
+ __version__ = "0.1.0"
+ __author__ = "Joe Munene"
+
+ __all__ = [
+     "GhostLMConfig",
+     "GhostLM",
+     "GhostTokenizer",
+     "GhostDataset",
+     "build_dataloaders",
+     "GhostTrainer",
+ ]
ghostlm/config.py ADDED
@@ -0,0 +1,162 @@
+ """GhostLM configuration — all model and training hyperparameters live here."""
+
+ from dataclasses import dataclass, field
+
+
+ @dataclass
+ class GhostLMConfig:
+     """Configuration dataclass for the GhostLM transformer language model.
+
+     Holds all hyperparameters for model architecture, training, data paths,
+     and system settings. Supports preset configurations and parameter counting.
+     """
+
+     # Model architecture
+     vocab_size: int = 50257
+     context_length: int = 1024
+     d_model: int = 512
+     n_heads: int = 8
+     n_layers: int = 6
+     d_ff: int = 2048
+     dropout: float = 0.1
+     bias: bool = True
+     use_rope: bool = False
+     use_flash_attention: bool = False
+
+     # Training
+     batch_size: int = 32
+     learning_rate: float = 3e-4
+     weight_decay: float = 0.1
+     beta1: float = 0.9
+     beta2: float = 0.95
+     grad_clip: float = 1.0
+     grad_accum_steps: int = 4
+     warmup_steps: int = 2000
+     max_steps: int = 100000
+     eval_interval: int = 500
+     save_interval: int = 1000
+
+     # Paths
+     data_dir: str = "data/processed"
+     checkpoint_dir: str = "checkpoints"
+     log_dir: str = "logs"
+
+     # System
+     device: str = "auto"
+     dtype: str = "float32"
+     seed: int = 42
+     use_wandb: bool = False
+
+     def model_size(self) -> str:
+         """Estimate total parameter count and return a human-readable string.
+
+         Computes the approximate number of trainable parameters based on
+         vocab_size, d_model, n_heads, n_layers, and d_ff.
+
+         Returns:
+             A string like "124M" or "1.2B" representing the estimated size.
+         """
+         embedding_params = self.vocab_size * self.d_model
+         attention_params = self.n_layers * (
+             4 * self.d_model * self.d_model + 2 * self.d_model
+         )
+         ffn_params = self.n_layers * (
+             2 * self.d_model * self.d_ff + self.d_model + self.d_ff
+         )
+         layer_norm_params = self.n_layers * 4 * self.d_model
+         output_head_params = self.d_model * self.vocab_size
+
+         total = embedding_params + attention_params + ffn_params + layer_norm_params + output_head_params
+
+         if total >= 1e9:
+             return f"{total / 1e9:.1f}B"
+         elif total >= 1e6:
+             return f"{total / 1e6:.0f}M"
+         else:
+             # Scale to thousands so the "K" suffix matches the number shown.
+             return f"{total / 1e3:.0f}K"
+
+     @classmethod
+     def from_preset(cls, preset: str) -> "GhostLMConfig":
+         """Return a GhostLMConfig instance from a named preset.
+
+         Args:
+             preset: One of "ghost-tiny", "ghost-small", or "ghost-medium".
+
+         Returns:
+             A GhostLMConfig configured with the preset's hyperparameters.
+
+         Raises:
+             ValueError: If the preset name is not recognized.
+         """
+         presets = {
+             "ghost-tiny": {
+                 "n_layers": 2,
+                 "d_model": 256,
+                 "n_heads": 4,
+                 "d_ff": 1024,
+             },
+             "ghost-small": {
+                 "n_layers": 6,
+                 "d_model": 512,
+                 "n_heads": 8,
+                 "d_ff": 2048,
+             },
+             "ghost-medium": {
+                 "n_layers": 12,
+                 "d_model": 768,
+                 "n_heads": 12,
+                 "d_ff": 3072,
+             },
+         }
+
+         if preset not in presets:
+             raise ValueError(
+                 f"Unknown preset '{preset}'. "
+                 f"Available presets: {', '.join(presets.keys())}"
+             )
+
+         return cls(**presets[preset])
+
+     def __repr__(self) -> str:
+         """Return a clean, grouped string summary of all config values.
+
+         Returns:
+             A formatted multi-line string with config values grouped by
+             category: Architecture, Training, Paths, and System.
+         """
+         lines = [
+             "GhostLMConfig",
+             "=" * 40,
+             "Architecture:",
+             f" vocab_size: {self.vocab_size}",
+             f" context_length: {self.context_length}",
+             f" d_model: {self.d_model}",
+             f" n_heads: {self.n_heads}",
+             f" n_layers: {self.n_layers}",
+             f" d_ff: {self.d_ff}",
+             f" dropout: {self.dropout}",
+             f" bias: {self.bias}",
+             "Training:",
+             f" batch_size: {self.batch_size}",
+             f" learning_rate: {self.learning_rate}",
+             f" weight_decay: {self.weight_decay}",
+             f" beta1: {self.beta1}",
+             f" beta2: {self.beta2}",
+             f" grad_clip: {self.grad_clip}",
+             f" warmup_steps: {self.warmup_steps}",
+             f" max_steps: {self.max_steps}",
+             f" eval_interval: {self.eval_interval}",
+             f" save_interval: {self.save_interval}",
+             "Paths:",
+             f" data_dir: {self.data_dir}",
+             f" checkpoint_dir: {self.checkpoint_dir}",
+             f" log_dir: {self.log_dir}",
+             "System:",
+             f" device: {self.device}",
+             f" dtype: {self.dtype}",
+             f" seed: {self.seed}",
+             f" use_wandb: {self.use_wandb}",
+             "=" * 40,
+             f"Estimated size: {self.model_size()}",
+         ]
+         return "\n".join(lines)
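To see what the `model_size` estimate adds up to, the sketch below reproduces the same arithmetic as a free function and evaluates it for the `ghost-tiny` preset with the default `vocab_size` of 50257. The function name `estimate_params` is an assumption for this example. Note that the formula counts the token embedding and the output head separately, so it would overcount if the model shares those weights (the model code shown here does not settle that); the README quotes 14.7M for ghost-tiny, so the two figures are evidently counted differently.

```python
def estimate_params(vocab_size: int, d_model: int, n_layers: int, d_ff: int) -> int:
    """Same terms as GhostLMConfig.model_size, returned as a raw integer."""
    embedding = vocab_size * d_model
    attention = n_layers * (4 * d_model * d_model + 2 * d_model)
    ffn = n_layers * (2 * d_model * d_ff + d_model + d_ff)
    layer_norm = n_layers * 4 * d_model
    output_head = d_model * vocab_size
    return embedding + attention + ffn + layer_norm + output_head


def human_readable(total: int) -> str:
    """Format a parameter count with a B / M / K suffix."""
    if total >= 1e9:
        return f"{total / 1e9:.1f}B"
    if total >= 1e6:
        return f"{total / 1e6:.0f}M"
    return f"{total / 1e3:.0f}K"


if __name__ == "__main__":
    # ghost-tiny preset: n_layers=2, d_model=256, d_ff=1024
    total = estimate_params(vocab_size=50257, d_model=256, n_layers=2, d_ff=1024)
    print(total, human_readable(total))  # 27310080 27M
```

At this scale the two vocabulary-sized matrices (12,865,792 parameters each) dominate the total; the transformer blocks contribute under 1.6M.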
ghostlm/dataset.py ADDED
@@ -0,0 +1,125 @@
+ """GhostLM dataset — converts processed JSONL data into PyTorch DataLoader-ready tensors."""
+
+ import json
+ from pathlib import Path
+ from typing import Dict, List, Optional, Tuple
+
+ import torch
+ from torch.utils.data import DataLoader, Dataset
+
+ from ghostlm.config import GhostLMConfig
+ from ghostlm.tokenizer import GhostTokenizer
+
+
+ class GhostDataset(Dataset):
+     """PyTorch Dataset for GhostLM language model training.
+
+     Loads tokenized text from a JSONL file, concatenates all tokens
+     into a single flat sequence, and yields fixed-length chunks for
+     autoregressive language modeling (x, y shifted by one token).
+     """
+
+     def __init__(self, jsonl_path: str, tokenizer: GhostTokenizer, config: GhostLMConfig):
+         """Initialize the dataset from a JSONL file.
+
+         Reads all records, tokenizes the "text" field of each, and
+         concatenates them into one continuous token stream.
+
+         Args:
+             jsonl_path: Path to the processed JSONL file.
+             tokenizer: GhostTokenizer instance for encoding text.
+             config: GhostLMConfig containing context_length.
+         """
+         self.context_length = config.context_length
+         self.tokens: List[int] = []
+
+         with open(jsonl_path, "r", encoding="utf-8") as f:
+             for line in f:
+                 line = line.strip()
+                 if not line:
+                     continue
+                 record = json.loads(line)
+                 text = record.get("text", "")
+                 if text:
+                     self.tokens.extend(tokenizer.encode(text))
+
+         print(f" Loaded {len(self.tokens):,} tokens from {jsonl_path}")
+
+     def __len__(self) -> int:
+         """Return the number of non-overlapping context-length chunks.
+
+         Returns:
+             Integer count of available training samples.
+         """
+         return len(self.tokens) // self.context_length
+
+     def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Retrieve a single (input, target) token chunk.
+
+         The target sequence is the input sequence shifted left by one
+         token, enabling next-token prediction training.
+
+         Args:
+             idx: Index of the chunk to retrieve.
+
+         Returns:
+             Tuple of (x, y) tensors, each of shape (context_length,).
+         """
+         start = idx * self.context_length
+         end = start + self.context_length
+
+         x = self.tokens[start:end]
+         y = self.tokens[start + 1 : end + 1]
+
+         # Pad target with -1 if we hit the end of data. These positions are only
+         # skipped if the loss uses ignore_index=-1 (PyTorch's default is -100).
+         if len(y) < len(x):
+             y = y + [-1] * (len(x) - len(y))
+
+         return (
+             torch.tensor(x, dtype=torch.long),
+             torch.tensor(y, dtype=torch.long),
+         )
+
+
+ def build_dataloaders(
+     train_path: str,
+     val_path: str,
+     tokenizer: GhostTokenizer,
+     config: GhostLMConfig,
+ ) -> Tuple[DataLoader, DataLoader]:
+     """Build train and validation DataLoaders from JSONL files.
+
+     Creates GhostDataset instances for both splits and wraps them
+     in PyTorch DataLoaders with appropriate batching and shuffling.
+
+     Args:
+         train_path: Path to the training JSONL file.
+         val_path: Path to the validation JSONL file.
+         tokenizer: GhostTokenizer instance for encoding.
+         config: GhostLMConfig with batch_size and context_length.
+
+     Returns:
+         Tuple of (train_loader, val_loader).
+     """
+     train_dataset = GhostDataset(train_path, tokenizer, config)
+     val_dataset = GhostDataset(val_path, tokenizer, config)
+
+     train_loader = DataLoader(
+         train_dataset,
+         batch_size=config.batch_size,
+         shuffle=True,
+         drop_last=True,
+         num_workers=0,
+         pin_memory=True,
+     )
+
+     val_loader = DataLoader(
+         val_dataset,
+         batch_size=config.batch_size,
+         shuffle=False,
+         drop_last=False,
+         num_workers=0,
+         pin_memory=True,
+     )
+
+     return train_loader, val_loader
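The chunking in `GhostDataset.__getitem__` is easiest to follow on a toy stream. The sketch below applies the same slicing to a plain list of eight token ids with a context length of four; `get_chunk` is a hypothetical helper that returns lists instead of tensors.

```python
from typing import List, Tuple


def get_chunk(tokens: List[int], context_length: int, idx: int) -> Tuple[List[int], List[int]]:
    """Return (input, target) for chunk idx; the target is the input shifted by one."""
    start = idx * context_length
    end = start + context_length
    x = tokens[start:end]
    y = tokens[start + 1 : end + 1]
    if len(y) < len(x):
        # The last chunk has no "next token" for its final position.
        y = y + [-1] * (len(x) - len(y))
    return x, y


if __name__ == "__main__":
    tokens = list(range(8))
    n_chunks = len(tokens) // 4
    for i in range(n_chunks):
        print(i, get_chunk(tokens, 4, i))
    # 0 ([0, 1, 2, 3], [1, 2, 3, 4])
    # 1 ([4, 5, 6, 7], [5, 6, 7, -1])
```

The padding only triggers when the token count is an exact multiple of the context length; otherwise the leftover tokens supply the final target and are themselves never used as input.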
ghostlm/model.py ADDED
@@ -0,0 +1,428 @@
+ """GhostLM transformer model — decoder-only architecture built from scratch in PyTorch."""
+
+ import math
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from ghostlm.config import GhostLMConfig
+
+
+ class RotaryEmbedding(nn.Module):
+     """Rotary Position Embedding (RoPE).
+
+     Encodes relative position information directly into the attention
+     computation by rotating query and key vectors. Used by LLaMA, Mistral,
+     and most modern transformer architectures.
+     """
+
+     def __init__(self, head_dim: int, context_length: int, base: float = 10000.0):
+         super().__init__()
+         inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+         # Precompute cos/sin for all positions
+         t = torch.arange(context_length).float()
+         freqs = torch.outer(t, inv_freq)
+         emb = torch.cat((freqs, freqs), dim=-1)
+         self.register_buffer("cos_cached", emb.cos(), persistent=False)
+         self.register_buffer("sin_cached", emb.sin(), persistent=False)
+
+     def forward(self, seq_len: int):
+         return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
+
+
+ def _rotate_half(x):
+     """Swap the two halves of the last dimension, negating the second: (x1, x2) -> (-x2, x1)."""
+     x1, x2 = x.chunk(2, dim=-1)
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(q, k, cos, sin):
+     """Apply rotary position embeddings to query and key tensors.
+
+     Args:
+         q: Query tensor of shape (B, n_heads, T, head_dim).
+         k: Key tensor of shape (B, n_heads, T, head_dim).
+         cos: Cosine frequencies of shape (T, head_dim).
+         sin: Sine frequencies of shape (T, head_dim).
+
+     Returns:
+         Tuple of (rotated_q, rotated_k).
+     """
+     cos = cos.unsqueeze(0).unsqueeze(0)  # (1, 1, T, head_dim)
+     sin = sin.unsqueeze(0).unsqueeze(0)
+     q_embed = (q * cos) + (_rotate_half(q) * sin)
+     k_embed = (k * cos) + (_rotate_half(k) * sin)
+     return q_embed, k_embed
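The reason RoPE encodes relative position is that each pair of coordinates is rotated by an angle proportional to the token's position, so the query-key dot product depends only on the position difference. The torch-free sketch below follows the same half-split layout as `_rotate_half` and checks that property numerically; the list-based helper names are invented for this example.

```python
import math
from typing import List


def rope_angles(pos: int, head_dim: int, base: float = 10000.0) -> List[float]:
    """Angles for one position, duplicated across both halves like cat((freqs, freqs))."""
    half = [pos / (base ** (i / head_dim)) for i in range(0, head_dim, 2)]
    return half + half


def rotate_half(x: List[float]) -> List[float]:
    """(x1, x2) -> (-x2, x1), where x1 and x2 are the two halves of x."""
    h = len(x) // 2
    return [-v for v in x[h:]] + x[:h]


def apply_rope(x: List[float], pos: int) -> List[float]:
    """Rotate one head vector according to its position."""
    angles = rope_angles(pos, len(x))
    rot = rotate_half(x)
    return [xi * math.cos(a) + ri * math.sin(a) for xi, ri, a in zip(x, rot, angles)]


def dot(a: List[float], b: List[float]) -> float:
    return sum(p * q for p, q in zip(a, b))


if __name__ == "__main__":
    q = [1.0, 2.0, 3.0, 4.0]
    k = [0.5, -1.0, 2.0, 0.0]
    # Both pairs of positions are two apart, so the scores should agree.
    near = dot(apply_rope(q, 3), apply_rope(k, 1))
    far = dot(apply_rope(q, 7), apply_rope(k, 5))
    print(abs(near - far) < 1e-9)
```

Since each step is a pure rotation, vector norms are unchanged, which is why RoPE can be applied to queries and keys without any rescaling.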
+
+
+ class CausalSelfAttention(nn.Module):
+     """Multi-head causal self-attention with autoregressive masking.
+
+     Uses a single combined QKV projection for efficiency, then splits
+     the result into separate query, key, and value tensors. Supports
+     optional RoPE (Rotary Position Embeddings) and Flash Attention.
+     """
+
+     def __init__(self, config: GhostLMConfig):
+         """Initialize causal self-attention.
+
+         Args:
+             config: GhostLMConfig containing d_model, n_heads, dropout,
+                 context_length, bias, use_rope, and use_flash_attention.
+         """
+         super().__init__()
+         assert config.d_model % config.n_heads == 0, "d_model must be divisible by n_heads"
+
+         self.n_heads = config.n_heads
+         self.head_dim = config.d_model // config.n_heads
+         self.context_length = config.context_length
+         self.use_rope = config.use_rope
+         self.use_flash_attention = config.use_flash_attention
+         self.dropout_p = config.dropout
+
+         # Single combined QKV projection
+         self.c_qkv = nn.Linear(config.d_model, 3 * config.d_model, bias=config.bias)
+         self.proj = nn.Linear(config.d_model, config.d_model, bias=config.bias)
+
+         # Dropout applied to attention weights (manual path only)
+         self.attn_dropout = nn.Dropout(config.dropout)
+         self.resid_dropout = nn.Dropout(config.dropout)
+
+         # RoPE
+         if self.use_rope:
+             self.rope = RotaryEmbedding(self.head_dim, config.context_length)
+
+         # Causal mask buffer (only needed for manual attention path)
+         if not self.use_flash_attention:
+             self.register_buffer(
+                 "causal_mask",
+                 torch.tril(torch.ones(config.context_length, config.context_length))
+                 .view(1, 1, config.context_length, config.context_length),
+                 persistent=False,
+             )
+
+     def forward(self, x):
+         """Forward pass through causal self-attention.
+
+         Args:
+             x: Input tensor of shape (B, T, d_model).
+
+         Returns:
+             Output tensor of shape (B, T, d_model).
+         """
+         B, T, C = x.size()
+
+         # Combined QKV projection and split
+         qkv = self.c_qkv(x)
+         q, k, v = qkv.split(self.n_heads * self.head_dim, dim=-1)
+
+         # Reshape to (B, n_heads, T, head_dim)
+         q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+
+         # Apply RoPE to Q and K (not V)
+         if self.use_rope:
+             cos, sin = self.rope(T)
+             q, k = apply_rotary_pos_emb(q, k, cos, sin)
+
+         if self.use_flash_attention:
+             # PyTorch 2.0+ scaled_dot_product_attention with automatic backend selection
+             y = F.scaled_dot_product_attention(
+                 q, k, v,
+                 attn_mask=None,
+                 dropout_p=self.dropout_p if self.training else 0.0,
+                 is_causal=True,
+             )
+         else:
+             # Manual attention path
+             scale = 1.0 / math.sqrt(self.head_dim)
+             att = (q @ k.transpose(-2, -1)) * scale
+             att = att.masked_fill(self.causal_mask[:, :, :T, :T] == 0, float("-inf"))
+             att = F.softmax(att, dim=-1)
+             att = self.attn_dropout(att)
+             y = att @ v
+
+         # Reassemble heads and project
+         y = y.transpose(1, 2).contiguous().view(B, T, C)
+         y = self.resid_dropout(self.proj(y))
+
+         return y
+
+
+ class FeedForward(nn.Module):
+     """Position-wise feed-forward network with GELU activation.
+
+     Two linear layers with an intermediate GELU non-linearity:
+     d_model -> d_ff -> d_model, with dropout after the second layer.
+     """
+
+     def __init__(self, config: GhostLMConfig):
+         """Initialize the feed-forward network.
+
+         Args:
+             config: GhostLMConfig containing d_model, d_ff, dropout, and bias.
+         """
+         super().__init__()
+         self.fc1 = nn.Linear(config.d_model, config.d_ff, bias=config.bias)
171
+ self.fc2 = nn.Linear(config.d_ff, config.d_model, bias=config.bias)
172
+ self.dropout = nn.Dropout(config.dropout)
173
+
174
+ def forward(self, x):
175
+ """Forward pass through the feed-forward network.
176
+
177
+ Args:
178
+ x: Input tensor of shape (B, T, d_model).
179
+
180
+ Returns:
181
+ Output tensor of shape (B, T, d_model).
182
+ """
183
+ x = self.fc1(x)
184
+ x = F.gelu(x)
185
+ x = self.fc2(x)
186
+ x = self.dropout(x)
187
+ return x
188
+
189
+
190
+ class TransformerBlock(nn.Module):
191
+ """Single transformer decoder block with pre-normalization.
192
+
193
+ Applies LayerNorm before both the self-attention and feed-forward
194
+ sub-layers (pre-norm architecture), with residual connections
195
+ around each sub-layer.
196
+ """
197
+
198
+ def __init__(self, config: GhostLMConfig):
199
+ """Initialize the transformer block.
200
+
201
+ Args:
202
+ config: GhostLMConfig passed to sub-modules.
203
+ """
204
+ super().__init__()
205
+ self.ln_1 = nn.LayerNorm(config.d_model)
206
+ self.attn = CausalSelfAttention(config)
207
+ self.ln_2 = nn.LayerNorm(config.d_model)
208
+ self.ffn = FeedForward(config)
209
+
210
+ def forward(self, x):
211
+ """Forward pass through the transformer block.
212
+
213
+ Args:
214
+ x: Input tensor of shape (B, T, d_model).
215
+
216
+ Returns:
217
+ Output tensor of shape (B, T, d_model).
218
+ """
219
+ # Pre-norm + self-attention with residual
220
+ x = x + self.attn(self.ln_1(x))
221
+ # Pre-norm + feed-forward with residual
222
+ x = x + self.ffn(self.ln_2(x))
223
+ return x
224
+
225
+
226
+ class GhostLM(nn.Module):
227
+ """GhostLM decoder-only transformer language model.
228
+
229
+ Built from scratch in PyTorch with learned positional embeddings (or RoPE
230
+ when config.use_rope is set), stacked transformer blocks, and a weight-tied output projection.
231
+ """
232
+
233
+ def __init__(self, config: GhostLMConfig):
234
+ """Initialize the GhostLM model.
235
+
236
+ Args:
237
+ config: GhostLMConfig with all model hyperparameters.
238
+ """
239
+ super().__init__()
240
+ self.config = config
241
+
242
+ # Embeddings
243
+ self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)
244
+ if not config.use_rope:
245
+ self.pos_embedding = nn.Embedding(config.context_length, config.d_model)
246
+ self.dropout = nn.Dropout(config.dropout)
247
+
248
+ # Transformer blocks
249
+ self.blocks = nn.ModuleList(
250
+ [TransformerBlock(config) for _ in range(config.n_layers)]
251
+ )
252
+
253
+ # Final layer norm
254
+ self.ln_f = nn.LayerNorm(config.d_model)
255
+
256
+ # Output head with weight tying (no bias)
257
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
258
+ self.lm_head.weight = self.token_embedding.weight
259
+
260
+ # Initialize weights
261
+ self.apply(self._init_weights)
262
+
263
+ # Apply scaled residual initialization for deeper models
264
+ for pn, p in self.named_parameters():
265
+ if pn.endswith("proj.weight") or pn.endswith("fc2.weight"):
266
+ torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layers))
267
+
268
+ def _init_weights(self, module):
269
+ """Initialize module weights with a normal distribution.
270
+
271
+ Args:
272
+ module: nn.Module to initialize.
273
+ """
274
+ if isinstance(module, nn.Linear):
275
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
276
+ if module.bias is not None:
277
+ torch.nn.init.zeros_(module.bias)
278
+ elif isinstance(module, nn.Embedding):
279
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
280
+
281
+ def forward(self, idx, targets=None):
282
+ """Forward pass of the model.
283
+
284
+ Args:
285
+ idx: Input token ids of shape (B, T).
286
+ targets: Optional target token ids of shape (B, T) for loss computation.
287
+
288
+ Returns:
289
+ Tuple of (logits, loss). Logits have shape (B, T, vocab_size).
290
+ Loss is returned only if targets are provided.
291
+
292
+ Raises:
293
+ AssertionError: If sequence length exceeds context_length.
294
+ """
295
+ B, T = idx.size()
296
+ assert T <= self.config.context_length, (
297
+ f"Sequence length {T} exceeds context length {self.config.context_length}"
298
+ )
299
+
300
+ # Token + positional embeddings
301
+ tok_emb = self.token_embedding(idx)
302
+ if self.config.use_rope:
303
+ x = self.dropout(tok_emb)
304
+ else:
305
+ pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
306
+ pos_emb = self.pos_embedding(pos)
307
+ x = self.dropout(tok_emb + pos_emb)
308
+
309
+ # Transformer blocks
310
+ for block in self.blocks:
311
+ x = block(x)
312
+
313
+ # Final layer norm
314
+ x = self.ln_f(x)
315
+
316
+ # Output logits
317
+ logits = self.lm_head(x)
318
+
319
+ loss = None
320
+ if targets is not None:
321
+ loss = F.cross_entropy(
322
+ logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
323
+ )
324
+
325
+ return logits, loss
326
+
327
+ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
328
+ """Autoregressively generate new tokens.
329
+
330
+ Args:
331
+ idx: Input token ids of shape (B, T) serving as the prompt.
332
+ max_new_tokens: Number of tokens to generate.
333
+ temperature: Sampling temperature (higher = more random).
334
+ top_k: If set, only sample from the top-k most likely tokens.
335
+
336
+ Returns:
337
+ Tensor of shape (B, T + max_new_tokens) with generated tokens.
338
+ """
339
+ for _ in range(max_new_tokens):
340
+ # Crop context if needed
341
+ idx_cond = idx[:, -self.config.context_length:]
342
+
343
+ # Forward pass
344
+ logits, _ = self(idx_cond)
345
+
346
+ # Take logits at the last position
347
+ logits = logits[:, -1, :] / temperature
348
+
349
+ # Optional top-k filtering
350
+ if top_k is not None:
351
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
352
+ logits[logits < v[:, [-1]]] = float("-inf")
353
+
354
+ # Apply softmax and sample
355
+ probs = F.softmax(logits, dim=-1)
356
+ idx_next = torch.multinomial(probs, num_samples=1)
357
+
358
+ # Append to sequence
359
+ idx = torch.cat((idx, idx_next), dim=1)
360
+
361
+ return idx
362
+
363
+ def num_params(self) -> int:
364
+ """Return the total number of trainable parameters.
365
+
366
+ Returns:
367
+ Integer count of trainable parameters in the model.
368
+ """
369
+ return sum(p.numel() for p in self.parameters() if p.requires_grad)
370
+
371
+ def configure_optimizers(self, config: GhostLMConfig):
372
+ """Create an AdamW optimizer with weight decay separation.
373
+
374
+ Separates parameters into two groups: those that should receive
375
+ weight decay (linear weights) and those that should not
376
+ (biases, LayerNorm weights, embeddings).
377
+
378
+ Args:
379
+ config: GhostLMConfig containing learning_rate, betas, and weight_decay.
380
+
381
+ Returns:
382
+ torch.optim.AdamW optimizer with properly configured parameter groups.
383
+ """
384
+ decay = set()
385
+ no_decay = set()
386
+
387
+ whitelist = (nn.Linear,)
388
+ blacklist = (nn.LayerNorm, nn.Embedding)
389
+
390
+ for mn, m in self.named_modules():
391
+ for pn, p in m.named_parameters():
392
+ fpn = f"{mn}.{pn}" if mn else pn
393
+
394
+ if pn.endswith("bias"):
395
+ no_decay.add(fpn)
396
+ elif pn.endswith("weight") and isinstance(m, whitelist):
397
+ decay.add(fpn)
398
+ elif pn.endswith("weight") and isinstance(m, blacklist):
399
+ no_decay.add(fpn)
400
+
401
+ # Remove lm_head.weight from decay if present — it is tied to token_embedding.weight
402
+ decay.discard("lm_head.weight")
403
+ no_decay.discard("lm_head.weight")
404
+
405
+ # Validate all parameters are accounted for (excluding tied weight)
406
+ param_dict = {pn: p for pn, p in self.named_parameters()}
407
+ all_params = decay | no_decay
408
+ uncategorized = {k for k in param_dict.keys() if k not in all_params and k != "lm_head.weight"}
409
+ assert len(uncategorized) == 0, f"Parameters {uncategorized} not categorized"
410
+
411
+ optim_groups = [
412
+ {
413
+ "params": [param_dict[pn] for pn in sorted(decay)],
414
+ "weight_decay": config.weight_decay,
415
+ },
416
+ {
417
+ "params": [param_dict[pn] for pn in sorted(no_decay)],
418
+ "weight_decay": 0.0,
419
+ },
420
+ ]
421
+
422
+ optimizer = torch.optim.AdamW(
423
+ optim_groups,
424
+ lr=config.learning_rate,
425
+ betas=(config.beta1, config.beta2),
426
+ )
427
+
428
+ return optimizer
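
The sampling step inside `generate` above (temperature scaling, optional top-k cut, softmax) can be sketched without PyTorch. This is a minimal stdlib illustration and not the repository's code: `top_k_probs` is a hypothetical helper name, it works on one flat list of logits rather than a `(B, vocab_size)` tensor, and the explicit checks on `temperature` and `top_k` are additions that the method in the diff does not perform.

```python
import math
from typing import List, Optional


def top_k_probs(logits: List[float], temperature: float = 1.0,
                top_k: Optional[int] = None) -> List[float]:
    """Temperature-scale logits, mask everything below the k-th largest, softmax."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if top_k is not None and top_k < 1:
        raise ValueError("top_k must be at least 1")

    scaled = [x / temperature for x in logits]

    if top_k is not None:
        k = min(top_k, len(scaled))
        # The k-th largest value is the cutoff. Values strictly below it are
        # masked, so ties at the cutoff are all kept (same rule as the diff).
        cutoff = sorted(scaled, reverse=True)[k - 1]
        scaled = [x if x >= cutoff else float("-inf") for x in scaled]

    # Numerically stable softmax; math.exp(-inf) evaluates to 0.0.
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


probs = top_k_probs([2.0, 1.0, 0.0, -1.0], temperature=1.0, top_k=2)
# Only the two largest logits keep non-zero probability.
```

Sampling from `probs` (for example with `random.choices`) would then play the role of `torch.multinomial` in the original method.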
ghostlm/tokenizer.py ADDED
@@ -0,0 +1,321 @@
1
+ """GhostLM tokenizer — wraps tiktoken's GPT-2 BPE tokenizer with cybersecurity-aware utilities."""
2
+
3
+ import json
4
+ import os
5
+ from pathlib import Path
6
+ from typing import List, Optional, Union
7
+
8
+ import tiktoken
9
+ import torch
10
+
11
+
12
+ class GhostTokenizer:
13
+ """Wrapper around tiktoken GPT-2 BPE tokenizer with GhostLM utilities.
14
+
15
+ Provides encoding, decoding, batching, padding, and text chunking
16
+ utilities tailored for cybersecurity document processing.
17
+ """
18
+
19
+ # Special token strings
20
+ BOS = "<|ghost_bos|>"
21
+ EOS = "<|ghost_eos|>"
22
+ PAD = "<|ghost_pad|>"
23
+ UNK = "<|ghost_unk|>"
24
+ # Chat role markers (added in v0.5 chat-tuning) — IDs appended after the
25
+ # original four so pre-chat checkpoints can be expanded by 3 rows rather
26
+ # than reshuffled.
27
+ USER = "<|ghost_user|>"
28
+ ASSISTANT = "<|ghost_assistant|>"
29
+ END = "<|ghost_end|>"
30
+
31
+ def __init__(self):
32
+ """Initialize the GhostTokenizer with the GPT-2 BPE encoding.
33
+
34
+ Loads the tiktoken gpt2 encoding and assigns special token IDs
35
+ beyond the standard vocabulary for begin-of-sequence, end-of-sequence,
36
+ padding, unknown, and chat role markers.
37
+ """
38
+ self._encoder = tiktoken.get_encoding("gpt2")
39
+ self._vocab_size = self._encoder.n_vocab
40
+
41
+ # Assign special token IDs beyond the base vocabulary
42
+ self._special_tokens = {
43
+ self.BOS: self._vocab_size,
44
+ self.EOS: self._vocab_size + 1,
45
+ self.PAD: self._vocab_size + 2,
46
+ self.UNK: self._vocab_size + 3,
47
+ self.USER: self._vocab_size + 4,
48
+ self.ASSISTANT: self._vocab_size + 5,
49
+ self.END: self._vocab_size + 6,
50
+ }
51
+
52
+ # Reverse mapping for quick lookup
53
+ self._id_to_special = {v: k for k, v in self._special_tokens.items()}
54
+
55
+ @property
56
+ def vocab_size(self) -> int:
57
+ """Return the effective vocabulary size including special tokens.
58
+
59
+ Returns:
60
+ Total vocabulary size (base vocab + 7 special tokens).
61
+ """
62
+ return self._vocab_size + len(self._special_tokens)
63
+
64
+ def _special_token_ids(self) -> set:
65
+ """Return a set of all special token IDs.
66
+
67
+ Returns:
68
+ Set of integer token IDs reserved for special tokens.
69
+ """
70
+ return set(self._special_tokens.values())
71
+
72
+ def encode(self, text: str, add_bos: bool = False, add_eos: bool = False) -> List[int]:
73
+ """Encode a text string into a list of token IDs.
74
+
75
+ Args:
76
+ text: Input text to encode.
77
+ add_bos: If True, prepend the BOS token ID.
78
+ add_eos: If True, append the EOS token ID.
79
+
80
+ Returns:
81
+ List of integer token IDs.
82
+ """
83
+ ids = self._encoder.encode(text, allowed_special="all")
84
+
85
+ if add_bos:
86
+ ids = [self._special_tokens[self.BOS]] + ids
87
+ if add_eos:
88
+ ids = ids + [self._special_tokens[self.EOS]]
89
+
90
+ return ids
91
+
92
+ def decode(self, ids: List[int], skip_special: bool = True) -> str:
93
+ """Decode a list of token IDs back into a text string.
94
+
95
+ Args:
96
+ ids: List of integer token IDs to decode.
97
+ skip_special: If True, filter out special token IDs before decoding.
98
+
99
+ Returns:
100
+ Decoded text string.
101
+ """
102
+ if skip_special:
103
+ special_ids = self._special_token_ids()
104
+ ids = [i for i in ids if i not in special_ids]
105
+
106
+ return self._encoder.decode(ids)
107
+
108
+ def encode_chat(self, turns: List[dict]) -> tuple:
109
+ """Encode a multi-turn chat conversation with role markers and a loss mask.
110
+
111
+ Format: <|ghost_user|>{content}<|ghost_end|><|ghost_assistant|>{content}<|ghost_end|>...
112
+ The loss mask is 1 on assistant content tokens and the assistant's trailing
113
+ <|ghost_end|> (so the model learns to stop), and 0 everywhere else (user
114
+ prompts and role markers themselves).
115
+
116
+ Args:
117
+ turns: List of {"role": "user"|"assistant", "content": str} dicts,
118
+ strictly alternating starting with "user".
119
+
120
+ Returns:
121
+ Tuple (token_ids, loss_mask) — same length, both lists of int.
122
+ """
123
+ user_id = self._special_tokens[self.USER]
124
+ assistant_id = self._special_tokens[self.ASSISTANT]
125
+ end_id = self._special_tokens[self.END]
126
+
127
+ ids: List[int] = []
128
+ mask: List[int] = []
129
+
130
+ for turn in turns:
131
+ role = turn["role"]
132
+ content_ids = self._encoder.encode(turn["content"], allowed_special="all")
133
+ if role == "user":
134
+ ids.append(user_id)
135
+ mask.append(0)
136
+ ids.extend(content_ids)
137
+ mask.extend([0] * len(content_ids))
138
+ ids.append(end_id)
139
+ mask.append(0)
140
+ elif role == "assistant":
141
+ ids.append(assistant_id)
142
+ mask.append(0)
143
+ ids.extend(content_ids)
144
+ mask.extend([1] * len(content_ids))
145
+ ids.append(end_id)
146
+ mask.append(1)
147
+ else:
148
+ raise ValueError(f"Unknown role: {role!r}")
149
+
150
+ return ids, mask
151
+
152
+ def format_chat_prompt(self, turns: List[dict]) -> List[int]:
153
+ """Encode a chat history and append <|ghost_assistant|> ready for generation.
154
+
155
+ Used at inference: feed the resulting token ids to the model; it should
156
+ generate the assistant's reply followed by <|ghost_end|>.
157
+
158
+ Args:
159
+ turns: List of {"role": "user"|"assistant", "content": str}, ending
160
+ with a "user" turn (the prompt awaiting a reply).
161
+
162
+ Returns:
163
+ List of token IDs ending in the assistant role marker.
164
+ """
165
+ ids, _ = self.encode_chat(turns)
166
+ ids.append(self._special_tokens[self.ASSISTANT])
167
+ return ids
168
+
169
+ def encode_batch(self, texts: List[str], add_bos: bool = False, add_eos: bool = False) -> List[List[int]]:
170
+ """Encode a list of text strings into lists of token IDs.
171
+
172
+ Args:
173
+ texts: List of input text strings to encode.
174
+ add_bos: If True, prepend BOS token ID to each sequence.
175
+ add_eos: If True, append EOS token ID to each sequence.
176
+
177
+ Returns:
178
+ List of lists of integer token IDs, one per input text.
179
+ """
180
+ return [self.encode(text, add_bos=add_bos, add_eos=add_eos) for text in texts]
181
+
182
+ def to_tensor(self, ids: List[int], device: str = "cpu") -> torch.Tensor:
183
+ """Convert a list of token IDs to a PyTorch tensor.
184
+
185
+ Args:
186
+ ids: List of integer token IDs.
187
+ device: Target device for the tensor (default: "cpu").
188
+
189
+ Returns:
190
+ torch.LongTensor of shape (1, len(ids)).
191
+ """
192
+ return torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)
193
+
194
+ def pad_batch(self, batch: List[List[int]], pad_left: bool = False) -> tuple:
195
+ """Pad a batch of token ID lists to the same length.
196
+
197
+ Pads all sequences in the batch to the length of the longest sequence
198
+ using the PAD token ID. Returns both the padded tensor and an attention
199
+ mask indicating real tokens (1) vs padding (0).
200
+
201
+ Args:
202
+ batch: List of token ID lists, each potentially different length.
203
+ pad_left: If True, pad on the left side (useful for generation).
204
+ If False, pad on the right side (default).
205
+
206
+ Returns:
207
+ Tuple of (padded_tensor, attention_mask) where:
208
+ - padded_tensor: torch.LongTensor of shape (batch_size, max_len)
209
+ - attention_mask: torch.LongTensor of shape (batch_size, max_len)
210
+ """
211
+ max_len = max(len(seq) for seq in batch)
212
+ pad_id = self._special_tokens[self.PAD]
213
+
214
+ padded = []
215
+ masks = []
216
+
217
+ for seq in batch:
218
+ pad_count = max_len - len(seq)
219
+ if pad_left:
220
+ padded_seq = [pad_id] * pad_count + seq
221
+ mask = [0] * pad_count + [1] * len(seq)
222
+ else:
223
+ padded_seq = seq + [pad_id] * pad_count
224
+ mask = [1] * len(seq) + [0] * pad_count
225
+
226
+ padded.append(padded_seq)
227
+ masks.append(mask)
228
+
229
+ padded_tensor = torch.tensor(padded, dtype=torch.long)
230
+ mask_tensor = torch.tensor(masks, dtype=torch.long)
231
+
232
+ return padded_tensor, mask_tensor
233
+
234
+ def chunk_text(self, text: str, chunk_size: int = 1024, overlap: int = 64) -> List[List[int]]:
235
+ """Encode text and split into overlapping token chunks.
236
+
237
+ Useful for processing long cybersecurity documents that exceed
238
+ the model's context length. Overlapping chunks preserve context
239
+ continuity across boundaries.
240
+
241
+ Args:
242
+ text: Input text string to chunk.
243
+ chunk_size: Maximum number of tokens per chunk.
244
+ overlap: Number of overlapping tokens between consecutive chunks.
245
+
246
+ Returns:
247
+ List of token ID lists, each of length at most chunk_size.
248
+ """
249
+ ids = self.encode(text)
250
+
251
+ if len(ids) <= chunk_size:
252
+ return [ids]
253
+
254
+ chunks = []
255
+ stride = chunk_size - overlap
256
+
257
+ for i in range(0, len(ids), stride):
258
+ chunk = ids[i : i + chunk_size]
259
+ chunks.append(chunk)
260
+ if i + chunk_size >= len(ids):
261
+ break
262
+
263
+ return chunks
264
+
265
+ def save(self, path: str) -> None:
266
+ """Save tokenizer metadata to a JSON file.
267
+
268
+ Stores vocabulary size, special token strings, and their assigned
269
+ IDs so the tokenizer can be reconstructed later.
270
+
271
+ Args:
272
+ path: File path to save the JSON metadata.
273
+ """
274
+ metadata = {
275
+ "vocab_size": self._vocab_size,
276
+ "special_tokens": self._special_tokens,
277
+ }
278
+
279
+ Path(path).parent.mkdir(parents=True, exist_ok=True)
280
+ with open(path, "w") as f:
281
+ json.dump(metadata, f, indent=2)
282
+
283
+ @classmethod
284
+ def load(cls, path: str) -> "GhostTokenizer":
285
+ """Load a GhostTokenizer from saved metadata JSON.
286
+
287
+ Reconstructs the tokenizer by reading special token assignments
288
+ from the saved metadata file.
289
+
290
+ Args:
291
+ path: File path to the saved JSON metadata.
292
+
293
+ Returns:
294
+ GhostTokenizer instance loaded with the saved configuration.
295
+ """
296
+ with open(path, "r") as f:
297
+ metadata = json.load(f)
298
+
299
+ tokenizer = cls()
300
+
301
+ # Restore special token mappings
302
+ tokenizer._special_tokens = {k: int(v) for k, v in metadata["special_tokens"].items()}
303
+ tokenizer._id_to_special = {v: k for k, v in tokenizer._special_tokens.items()}
304
+
305
+ return tokenizer
306
+
307
+ def __len__(self) -> int:
308
+ """Return the effective vocabulary size.
309
+
310
+ Returns:
311
+ Integer count of tokens including special tokens.
312
+ """
313
+ return self.vocab_size
314
+
315
+ def __repr__(self) -> str:
316
+ """Return a concise string representation of the tokenizer.
317
+
318
+ Returns:
319
+ String like: GhostTokenizer(vocab_size=50264, special_tokens=7)
320
+ """
321
+ return f"GhostTokenizer(vocab_size={self.vocab_size}, special_tokens={len(self._special_tokens)})"
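
The `(token_ids, loss_mask)` pair returned by `encode_chat` above still has to be turned into next-token targets before it reaches the model. A minimal sketch of one way to do that follows. It is an assumption about the training script, which is not part of this diff: `build_masked_targets` is a hypothetical helper, and it relies only on the `ignore_index=-1` argument that `GhostLM.forward` passes to `cross_entropy`. The special-token IDs in the usage lines assume the 50257-entry GPT-2 base vocabulary, and the content IDs are made up.

```python
from typing import List, Tuple

IGNORE_INDEX = -1  # matches ignore_index=-1 in GhostLM.forward's cross_entropy call


def build_masked_targets(ids: List[int], mask: List[int]) -> Tuple[List[int], List[int]]:
    """Shift (ids, mask) into next-token (inputs, targets), hiding masked-out targets."""
    if len(ids) != len(mask):
        raise ValueError("ids and mask must have the same length")
    inputs = ids[:-1]
    # The target at position t is token t+1, so the mask is shifted the same way.
    targets = [tok if keep else IGNORE_INDEX for tok, keep in zip(ids[1:], mask[1:])]
    return inputs, targets


# Hypothetical conversation: <|ghost_user|> 11 12 <|ghost_end|> <|ghost_assistant|> 21 22 <|ghost_end|>
USER, ASSISTANT, END = 50261, 50262, 50263
ids = [USER, 11, 12, END, ASSISTANT, 21, 22, END]
mask = [0, 0, 0, 0, 0, 1, 1, 1]
inputs, targets = build_masked_targets(ids, mask)
# targets is -1 everywhere except the assistant reply and its closing <|ghost_end|>.
```

With this shift, the position holding the `<|ghost_assistant|>` marker is the first one that contributes to the loss, since its target is the first assistant content token.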
ghostlm/trainer.py ADDED
@@ -0,0 +1,317 @@
1
+ """GhostLM trainer — handles the full training loop, evaluation, checkpointing, and logging."""
2
+
3
+ import json
4
+ import math
5
+ import os
6
+ import time
7
+ from dataclasses import asdict
8
+ from pathlib import Path
9
+ from typing import Dict, Optional, Tuple
10
+
11
+ import torch
12
+ from tqdm import tqdm
13
+
14
+ from ghostlm.config import GhostLMConfig
15
+ from ghostlm.model import GhostLM
16
+
17
+
18
+ class GhostTrainer:
19
+ """Manages the GhostLM training loop with evaluation, checkpointing, and logging.
20
+
21
+ Handles device placement, optimizer setup, cosine learning rate scheduling
22
+ with warmup, gradient clipping, periodic evaluation, checkpoint saving,
23
+ and JSON-based training log persistence. Supports mixed precision (AMP)
24
+ training on CUDA devices for faster throughput and lower memory usage.
25
+ """
26
+
27
+ def __init__(self, model: GhostLM, config: GhostLMConfig, use_amp: Optional[bool] = None):
28
+ """Initialize the trainer.
29
+
30
+ Args:
31
+ model: GhostLM model instance to train.
32
+ config: GhostLMConfig with training hyperparameters and paths.
33
+ use_amp: Enable mixed precision (AMP) training. Defaults to True
34
+ when running on CUDA, False otherwise. AMP is only supported
35
+ on CUDA devices — setting True on CPU/MPS will be ignored.
36
+ """
37
+ self.model = model
38
+ self.config = config
39
+
40
+ # Resolve device
41
+ if config.device == "auto":
42
+ if torch.cuda.is_available():
43
+ self.device = "cuda"
44
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
45
+ self.device = "mps"
46
+ else:
47
+ self.device = "cpu"
48
+ else:
49
+ self.device = config.device
50
+
51
+ self.model = self.model.to(self.device)
52
+
53
+ # Mixed precision (AMP) — only effective on CUDA
54
+ if use_amp is None:
55
+ self.use_amp = self.device == "cuda"
56
+ else:
57
+ self.use_amp = use_amp and self.device == "cuda"
58
+
59
+ self.grad_scaler = torch.amp.GradScaler("cuda", enabled=self.use_amp)
60
+
61
+ # Optimizer
62
+ self.optimizer = self.model.configure_optimizers(config)
63
+
64
+ # Create directories
65
+ self.checkpoint_dir = Path(config.checkpoint_dir)
66
+ self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
67
+
68
+ self.log_dir = Path(config.log_dir)
69
+ self.log_dir.mkdir(parents=True, exist_ok=True)
70
+
71
+ # State
72
+ self.step = 0
73
+ self.accum_steps = getattr(config, 'grad_accum_steps', 4)
74
+ self.best_val_loss = float("inf")
75
+ self.log: list = []
76
+
77
+ def get_lr(self) -> float:
78
+ """Compute the current learning rate using cosine decay with linear warmup.
79
+
80
+ During the warmup phase (step < warmup_steps), the learning rate scales
81
+ linearly from learning_rate / warmup_steps up to config.learning_rate. After warmup, it follows a
82
+ cosine decay schedule down to a minimum of 1e-5.
83
+
84
+ Returns:
85
+ Current learning rate as a float.
86
+ """
87
+ step = self.step
88
+ warmup = self.config.warmup_steps
89
+ max_steps = self.config.max_steps
90
+ base_lr = self.config.learning_rate
91
+ min_lr = 1e-5
92
+
93
+ if step < warmup:
94
+ return base_lr * (step + 1) / warmup
95
+
96
+ decay_ratio = (step - warmup) / max(1, max_steps - warmup)
97
+ decay_ratio = min(decay_ratio, 1.0)
98
+
99
+ cosine_decay = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
100
+ return min_lr + (base_lr - min_lr) * cosine_decay
101
+
102
+ def _set_lr(self) -> None:
103
+ """Apply the current learning rate from get_lr() to all optimizer parameter groups."""
104
+ lr = self.get_lr()
105
+ for group in self.optimizer.param_groups:
106
+ group["lr"] = lr
107
+
108
+ def train_step(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> float:
109
+ """Execute a single training step with gradient accumulation and optional AMP.
110
+
111
+ Splits the incoming batch into micro-batches (about self.accum_steps of
112
+ them), accumulates their gradients, and then applies a single optimizer
113
+ update, which lowers peak memory for a given batch size. When AMP is enabled, the forward pass runs
114
+ in float16 and the GradScaler handles loss scaling for stable training.
115
+
116
+ Args:
117
+ batch: Tuple of (input_ids, target_ids) tensors.
118
+
119
+ Returns:
120
+ Training loss as a float.
121
+ """
122
+ x, y = batch
123
+ x = x.to(self.device)
124
+ y = y.to(self.device)
125
+
126
+ self.model.train()
127
+
128
+ # Split batch into micro-batches for gradient accumulation
129
+ micro_x = x.split(max(1, x.size(0) // self.accum_steps), dim=0)
130
+ micro_y = y.split(max(1, y.size(0) // self.accum_steps), dim=0)
131
+
132
+ total_loss = 0.0
133
+
134
+ for mx, my in zip(micro_x, micro_y):
135
+ with torch.amp.autocast("cuda", enabled=self.use_amp):
136
+ _, loss = self.model(mx, targets=my)
137
+ # Scale loss by number of accumulation steps
138
+ scaled_loss = loss / len(micro_x)
139
+
140
+ self.grad_scaler.scale(scaled_loss).backward()
141
+ total_loss += loss.item()
142
+
143
+ # Gradient clipping and optimizer step after accumulation
144
+ self.grad_scaler.unscale_(self.optimizer)
145
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)
146
+ self.grad_scaler.step(self.optimizer)
147
+ self.grad_scaler.update()
148
+ self.optimizer.zero_grad(set_to_none=True)
149
+
150
+ self.step += 1
151
+ self._set_lr()
152
+
153
+ return total_loss / len(micro_x)
154
+
155
+ def eval_step(self, val_loader, num_batches: int = 20) -> float:
156
+ """Run evaluation over a number of validation batches.
157
+
158
+ Args:
159
+ val_loader: DataLoader yielding (input_ids, target_ids) batches.
160
+ num_batches: Maximum number of batches to evaluate over.
161
+
162
+ Returns:
163
+ Average validation loss as a float.
164
+ """
165
+ self.model.eval()
166
+ total_loss = 0.0
167
+ count = 0
168
+
169
+ with torch.no_grad():
170
+ for i, batch in enumerate(val_loader):
171
+ if i >= num_batches:
172
+ break
173
+ x, y = batch
174
+ x = x.to(self.device)
175
+ y = y.to(self.device)
176
+
177
+ with torch.amp.autocast("cuda", enabled=self.use_amp):
178
+ _, loss = self.model(x, targets=y)
179
+ total_loss += loss.item()
180
+ count += 1
181
+
182
+ return total_loss / max(count, 1)
183
+
184
+ def save_checkpoint(self, val_loss: float) -> None:
185
+ """Save a model checkpoint to disk.
186
+
187
+ Saves the current step, validation loss, model state dict, optimizer
188
+ state dict, GradScaler state dict, and config. Also saves as "best_model.pt" if the current
189
+ validation loss is the best seen so far.
190
+
191
+ Args:
192
+ val_loss: Current validation loss for comparison.
193
+ """
194
+ checkpoint = {
195
+ "step": self.step,
196
+ "val_loss": val_loss,
197
+ "model_state_dict": self.model.state_dict(),
198
+ "optimizer_state_dict": self.optimizer.state_dict(),
199
+ "grad_scaler_state_dict": self.grad_scaler.state_dict(),
200
+ "config": asdict(self.config),
201
+ }
202
+
203
+ filename = f"checkpoint_step_{self.step}.pt"
204
+ path = self.checkpoint_dir / filename
205
+ torch.save(checkpoint, path)
206
+ print(f" Saved checkpoint: {path}")
207
+
208
+ if val_loss < self.best_val_loss:
209
+ self.best_val_loss = val_loss
210
+ best_path = self.checkpoint_dir / "best_model.pt"
211
+ torch.save(checkpoint, best_path)
212
+ print(f" New best model saved: {best_path} (val_loss={val_loss:.4f})")
213
+
214
+ def load_checkpoint(self, path: str) -> None:
215
+ """Load a model checkpoint from disk.
216
+
217
+ Restores the model state dict, optimizer state dict, GradScaler state (when present), training step,
218
+ and best validation loss (set to the checkpoint's recorded val_loss) from the saved checkpoint file.
219
+
220
+ Args:
221
+ path: File path to the checkpoint .pt file.
222
+ """
223
+ checkpoint = torch.load(path, map_location=self.device, weights_only=False)
224
+
225
+ self.model.load_state_dict(checkpoint["model_state_dict"])
226
+ self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
227
+ if "grad_scaler_state_dict" in checkpoint:
228
+ self.grad_scaler.load_state_dict(checkpoint["grad_scaler_state_dict"])
229
+ self.step = checkpoint["step"]
230
+ self.best_val_loss = checkpoint["val_loss"]
231
+
232
+ print(f"Loaded checkpoint from step {self.step} (val_loss={self.best_val_loss:.4f})")
233
+
234
+ def _log(self, data: dict) -> None:
235
+ """Append a data dict to the training log and persist as JSON.
236
+
237
+ Args:
238
+ data: Dictionary of metrics and metadata to log.
239
+ """
240
+ self.log.append(data)
241
+ log_path = self.log_dir / "training_log.json"
242
+ with open(log_path, "w") as f:
243
+ json.dump(self.log, f, indent=2)
244
+
245
+     def train(self, train_loader, val_loader) -> None:
+         """Run the main training loop.
+
+         Iterates from the current step to config.max_steps, performing training
+         steps with a tqdm progress bar. Evaluates periodically at config.eval_interval
+         and saves checkpoints at config.save_interval. Performs a final evaluation
+         and saves the final checkpoint at the end of training.
+
+         Args:
+             train_loader: DataLoader yielding (input_ids, target_ids) training batches.
+             val_loader: DataLoader yielding (input_ids, target_ids) validation batches.
+         """
+         print(f"Training on device: {self.device}")
+         print(f"Mixed precision (AMP): {'enabled' if self.use_amp else 'disabled'}")
+         print(f"Model size: {self.model.num_params():,} parameters")
+         print(f"Training from step {self.step} to {self.config.max_steps}")
+
+         # Restart train_loader from the beginning whenever it is exhausted
+         def cycle(loader):
+             while True:
+                 for batch in loader:
+                     yield batch
+
+         train_iter = cycle(train_loader)
+
+         # Defaults for the final log entry in case the loop body never runs
+         # (e.g. resuming from a checkpoint that is already at max_steps)
+         loss, lr, dt = None, self.get_lr(), 0.0
+
+         with tqdm(initial=self.step, total=self.config.max_steps, desc="Training") as pbar:
+             while self.step < self.config.max_steps:
+                 t0 = time.time()
+
+                 # Training step
+                 batch = next(train_iter)
+                 loss = self.train_step(batch)
+
+                 dt = time.time() - t0
+                 lr = self.get_lr()
+
+                 pbar.set_postfix(loss=f"{loss:.4f}", lr=f"{lr:.2e}", dt=f"{dt:.3f}s")
+                 pbar.update(1)
+
+                 should_eval = self.step % self.config.eval_interval == 0
+                 should_save = self.step % self.config.save_interval == 0
+
+                 # Evaluate once per step, even when both intervals coincide
+                 if should_eval or should_save:
+                     val_loss = self.eval_step(val_loader)
+
+                 # Periodic evaluation
+                 if should_eval:
+                     print(f"\n  Step {self.step} | val_loss={val_loss:.4f} | train_loss={loss:.4f}")
+
+                     self._log({
+                         "step": self.step,
+                         "train_loss": loss,
+                         "val_loss": val_loss,
+                         "lr": lr,
+                         "time": dt,
+                     })
+
+                 # Periodic checkpoint
+                 if should_save:
+                     self.save_checkpoint(val_loss)
+
+         # Final evaluation and checkpoint
+         print("\nTraining complete. Running final evaluation...")
+         val_loss = self.eval_step(val_loader)
+         print(f"Final val_loss: {val_loss:.4f}")
+         self.save_checkpoint(val_loss)
+
+         self._log({
+             "step": self.step,
+             "train_loss": loss,
+             "val_loss": val_loss,
+             "lr": lr,
+             "time": dt,
+             "status": "complete",
+         })
+
+         print(f"Training log saved to {self.log_dir / 'training_log.json'}")
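The step-based loop in `train` can be sketched without torch to show when evaluation and checkpointing fire. This is a minimal illustration, not the trainer itself: `run_schedule` and its event tuples are hypothetical names, and it assumes that `train_step` advances the step counter by one per batch. A hand-written `cycle` re-iterates the loader on each pass, whereas `itertools.cycle` would replay cached items, which matters for a shuffling DataLoader.

```python
from itertools import islice


def cycle(loader):
    """Yield batches forever, restarting the loader when it is exhausted."""
    while True:
        for batch in loader:
            yield batch


def run_schedule(loader, start_step, max_steps, eval_interval, save_interval):
    """Return (step, action) events for a step-based training loop.

    Hypothetical stand-in: each iteration consumes one batch; "eval" and
    "save" mark where the real loop would evaluate or write a checkpoint.
    """
    events = []
    batches = cycle(loader)
    step = start_step
    while step < max_steps:
        next(batches)  # stands in for train_step(batch)
        step += 1      # assumption: train_step increments the step counter
        if step % eval_interval == 0:
            events.append((step, "eval"))
        if step % save_interval == 0:
            events.append((step, "save"))
    return events


# A two-item loader is restarted as often as needed.
first_five = list(islice(cycle([1, 2]), 5))

# Six steps over a three-batch loader, evaluating every 2 and saving every 3.
events = run_schedule([0, 1, 2], start_step=0, max_steps=6,
                      eval_interval=2, save_interval=3)
```

Resuming at `start_step == max_steps` produces no events, which is the case where the loop body never runs.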
requirements.txt ADDED
@@ -0,0 +1,22 @@
+ # Hugging Face Spaces installs from this file at build time.
+ # Version constraints are lower bounds only, so a future breaking release
+ # of a dep can still reach a Space build; pin exact versions if that happens.
+
+ # Note: gradio is intentionally NOT listed here. HF Spaces auto-installs
+ # `gradio[oauth,mcp]==<sdk_version>` on top of this file based on the SDK
+ # selection in README.md frontmatter. Listing it here causes a pip
+ # version conflict at build time when our pin disagrees with HF's.
+
+ # torch >= 2.0 for the scaled_dot_product_attention path. CPU-only is
+ # fine on free Spaces.
+ torch>=2.0.0
+
+ # tiktoken is the GPT-2 BPE backend the GhostTokenizer wraps.
+ tiktoken>=0.5.0
+
+ # Python 3.13 removed the stdlib audioop module (PEP 594) that gradio's
+ # transitive pydub dep imports at module-load time. Without this the entire
+ # gradio import chain fails with ModuleNotFoundError: No module named
+ # 'pyaudioop'. audioop-lts is a maintained backport of the removed module.
+ # Conditional so 3.12 and earlier (where stdlib audioop still exists) skip it.
+ audioop-lts; python_version >= '3.13'