fix(model): sync ghostlm package to v0.9 (SwiGLU + RMSNorm + RoPE)
Browse files
The Space's bundled ghostlm/ package was the v0.4-era model.py without
SwiGLU / RMSNorm support, so loading the v0.9 chat checkpoint failed at
startup with size mismatches on every blocks.*.ffn.* weight: the v0.9
checkpoint stores SwiGLU-compressed FFN weights at hidden=2048
(int(d_ff * 2/3) rounded to a multiple of 64), but the old GELU
FeedForward in the Space's model.py allocated full d_ff=3072.
Syncs five files from the GhostLM main branch:
- ghostlm/__init__.py (re-exports + version)
- ghostlm/config.py (use_rope / use_swiglu / use_rmsnorm flags)
- ghostlm/model.py (SwiGLU class, RMSNorm class, RoPE rotary embed)
- ghostlm/tokenizer.py (50264-token vocab including the 7 special tokens, chat roles)
- ghostlm/trainer.py (kept aligned even though Space does not train)
- ghostlm/__init__.py +1 -1
- ghostlm/config.py +14 -0
- ghostlm/model.py +75 -7
- ghostlm/tokenizer.py +127 -0
- ghostlm/trainer.py +53 -6
|
@@ -6,7 +6,7 @@ from ghostlm.tokenizer import GhostTokenizer
|
|
| 6 |
from ghostlm.dataset import GhostDataset, build_dataloaders
|
| 7 |
from ghostlm.trainer import GhostTrainer
|
| 8 |
|
| 9 |
-
__version__ = "0.
|
| 10 |
__author__ = "Joe Munene"
|
| 11 |
|
| 12 |
__all__ = [
|
|
|
|
| 6 |
from ghostlm.dataset import GhostDataset, build_dataloaders
|
| 7 |
from ghostlm.trainer import GhostTrainer
|
| 8 |
|
| 9 |
+
__version__ = "0.5.0"
|
| 10 |
__author__ = "Joe Munene"
|
| 11 |
|
| 12 |
__all__ = [
|
|
@@ -21,6 +21,9 @@ class GhostLMConfig:
|
|
| 21 |
dropout: float = 0.1
|
| 22 |
bias: bool = True
|
| 23 |
use_rope: bool = False
|
|
|
|
|
|
|
|
|
|
| 24 |
use_flash_attention: bool = False
|
| 25 |
|
| 26 |
# Training
|
|
@@ -107,6 +110,17 @@ class GhostLMConfig:
|
|
| 107 |
"n_heads": 12,
|
| 108 |
"d_ff": 3072,
|
| 109 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
}
|
| 111 |
|
| 112 |
if preset not in presets:
|
|
|
|
| 21 |
dropout: float = 0.1
|
| 22 |
bias: bool = True
|
| 23 |
use_rope: bool = False
|
| 24 |
+
rope_base: float = 10000.0
|
| 25 |
+
use_swiglu: bool = False
|
| 26 |
+
use_rmsnorm: bool = False
|
| 27 |
use_flash_attention: bool = False
|
| 28 |
|
| 29 |
# Training
|
|
|
|
| 110 |
"n_heads": 12,
|
| 111 |
"d_ff": 3072,
|
| 112 |
},
|
| 113 |
+
# v0.5 preset — same param shape as ghost-small but flips on
|
| 114 |
+
# the modern-arch switches. Use this for the v0.4.2 retrain.
|
| 115 |
+
"ghost-small-v0.5": {
|
| 116 |
+
"n_layers": 6,
|
| 117 |
+
"d_model": 512,
|
| 118 |
+
"n_heads": 8,
|
| 119 |
+
"d_ff": 2048,
|
| 120 |
+
"use_rope": True,
|
| 121 |
+
"use_swiglu": True,
|
| 122 |
+
"use_rmsnorm": True,
|
| 123 |
+
},
|
| 124 |
}
|
| 125 |
|
| 126 |
if preset not in presets:
|
|
@@ -58,6 +58,33 @@ def apply_rotary_pos_emb(q, k, cos, sin):
|
|
| 58 |
return q_embed, k_embed
|
| 59 |
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
class CausalSelfAttention(nn.Module):
|
| 62 |
"""Multi-head causal self-attention with autoregressive masking.
|
| 63 |
|
|
@@ -187,6 +214,40 @@ class FeedForward(nn.Module):
|
|
| 187 |
return x
|
| 188 |
|
| 189 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
class TransformerBlock(nn.Module):
|
| 191 |
"""Single transformer decoder block with pre-normalization.
|
| 192 |
|
|
@@ -199,13 +260,15 @@ class TransformerBlock(nn.Module):
|
|
| 199 |
"""Initialize the transformer block.
|
| 200 |
|
| 201 |
Args:
|
| 202 |
-
config: GhostLMConfig passed to sub-modules.
|
|
|
|
|
|
|
| 203 |
"""
|
| 204 |
super().__init__()
|
| 205 |
-
self.ln_1 =
|
| 206 |
self.attn = CausalSelfAttention(config)
|
| 207 |
-
self.ln_2 =
|
| 208 |
-
self.ffn =
|
| 209 |
|
| 210 |
def forward(self, x):
|
| 211 |
"""Forward pass through the transformer block.
|
|
@@ -250,8 +313,8 @@ class GhostLM(nn.Module):
|
|
| 250 |
[TransformerBlock(config) for _ in range(config.n_layers)]
|
| 251 |
)
|
| 252 |
|
| 253 |
-
# Final layer norm
|
| 254 |
-
self.ln_f =
|
| 255 |
|
| 256 |
# Output head with weight tying (no bias)
|
| 257 |
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
|
@@ -385,7 +448,12 @@ class GhostLM(nn.Module):
|
|
| 385 |
no_decay = set()
|
| 386 |
|
| 387 |
whitelist = (nn.Linear,)
|
| 388 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
|
| 390 |
for mn, m in self.named_modules():
|
| 391 |
for pn, p in m.named_parameters():
|
|
|
|
| 58 |
return q_embed, k_embed
|
| 59 |
|
| 60 |
|
| 61 |
+
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization without mean subtraction.

    Each vector along the last dimension is multiplied by the reciprocal of
    its root-mean-square and then by a learned per-feature weight. There is
    no bias term, so the module owns a single ``(dim,)`` parameter.
    Selected through ``GhostLMConfig.use_rmsnorm``.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        """Create the learned scale vector, initialised to ones.

        Args:
            dim: Size of the last (feature) dimension being normalized.
            eps: Constant added to the mean square before ``rsqrt`` so an
                all-zero input does not divide by zero.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` by its RMS over the last dim, then apply the scale.

        The arithmetic runs in float32 and the result is cast back to the
        dtype of ``x``.
        """
        # Upcast once and reuse the result: calling ``x.float()`` twice
        # materialises two float32 copies when ``x`` is half precision.
        x_fp32 = x.float()
        inv_rms = torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        return (x_fp32 * inv_rms * self.weight).to(x.dtype)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def make_norm(config: GhostLMConfig, dim: int) -> nn.Module:
    """Build the normalization layer selected by ``config.use_rmsnorm``.

    Args:
        config: Model configuration; a missing ``use_rmsnorm`` attribute is
            treated as ``False``.
        dim: Feature dimension handed to the normalization layer.

    Returns:
        An ``RMSNorm(dim)`` when the flag is truthy, otherwise
        ``nn.LayerNorm(dim)``.
    """
    wants_rms = getattr(config, "use_rmsnorm", False)
    return RMSNorm(dim) if wants_rms else nn.LayerNorm(dim)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
class CausalSelfAttention(nn.Module):
|
| 89 |
"""Multi-head causal self-attention with autoregressive masking.
|
| 90 |
|
|
|
|
| 214 |
return x
|
| 215 |
|
| 216 |
|
| 217 |
+
class SwiGLU(nn.Module):
    """Gated feed-forward layer computing ``fc3(SiLU(fc1(x)) * fc2(x))``.

    Two bias-free projections map ``d_model`` to a hidden width of roughly
    two thirds of ``d_ff``; one is passed through SiLU and gates the other,
    and a third bias-free projection maps back to ``d_model``. The hidden
    width is shrunk because this layer has three projections where the GELU
    ``FeedForward`` has two. Selected through ``GhostLMConfig.use_swiglu``.
    """

    def __init__(self, config: GhostLMConfig):
        """Allocate the three projections and the output dropout.

        Args:
            config: Model configuration; reads ``d_model``, ``d_ff`` and
                ``dropout``.
        """
        super().__init__()
        d_model = config.d_model
        # Two thirds of d_ff, truncated to an integer ...
        two_thirds = int(config.d_ff * 2 / 3)
        # ... then rounded UP to the next multiple of 64.
        hidden = ((two_thirds + 63) // 64) * 64
        self.fc1 = nn.Linear(d_model, hidden, bias=False)
        self.fc2 = nn.Linear(d_model, hidden, bias=False)
        self.fc3 = nn.Linear(hidden, d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply the gated projection followed by dropout."""
        gate = F.silu(self.fc1(x))
        value = self.fc2(x)
        projected = self.fc3(gate * value)
        return self.dropout(projected)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def make_ffn(config: GhostLMConfig) -> nn.Module:
    """Build the feed-forward sub-layer selected by ``config.use_swiglu``.

    Args:
        config: Model configuration; a missing ``use_swiglu`` attribute is
            treated as ``False``.

    Returns:
        A ``SwiGLU`` when the flag is truthy, otherwise a ``FeedForward``.
    """
    ffn_cls = SwiGLU if getattr(config, "use_swiglu", False) else FeedForward
    return ffn_cls(config)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
class TransformerBlock(nn.Module):
|
| 252 |
"""Single transformer decoder block with pre-normalization.
|
| 253 |
|
|
|
|
| 260 |
"""Initialize the transformer block.
|
| 261 |
|
| 262 |
Args:
|
| 263 |
+
config: GhostLMConfig passed to sub-modules. Switches between
|
| 264 |
+
LayerNorm / RMSNorm and FeedForward / SwiGLU based on the
|
| 265 |
+
``use_rmsnorm`` and ``use_swiglu`` flags.
|
| 266 |
"""
|
| 267 |
super().__init__()
|
| 268 |
+
self.ln_1 = make_norm(config, config.d_model)
|
| 269 |
self.attn = CausalSelfAttention(config)
|
| 270 |
+
self.ln_2 = make_norm(config, config.d_model)
|
| 271 |
+
self.ffn = make_ffn(config)
|
| 272 |
|
| 273 |
def forward(self, x):
|
| 274 |
"""Forward pass through the transformer block.
|
|
|
|
| 313 |
[TransformerBlock(config) for _ in range(config.n_layers)]
|
| 314 |
)
|
| 315 |
|
| 316 |
+
# Final layer norm — RMSNorm or LayerNorm depending on config.
|
| 317 |
+
self.ln_f = make_norm(config, config.d_model)
|
| 318 |
|
| 319 |
# Output head with weight tying (no bias)
|
| 320 |
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
|
|
|
| 448 |
no_decay = set()
|
| 449 |
|
| 450 |
whitelist = (nn.Linear,)
|
| 451 |
+
# RMSNorm is custom (defined in this module), so we include it in the
|
| 452 |
+
# blacklist by class — its `.weight` should be no-decay just like
|
| 453 |
+
# LayerNorm's. Without this, the v0.5 ghost-small-v0.5 preset
|
| 454 |
+
# crashes at optimizer setup with every block's ln_*.weight
|
| 455 |
+
# uncategorized.
|
| 456 |
+
blacklist = (nn.LayerNorm, nn.Embedding, RMSNorm)
|
| 457 |
|
| 458 |
for mn, m in self.named_modules():
|
| 459 |
for pn, p in m.named_parameters():
|
|
@@ -319,3 +319,130 @@ class GhostTokenizer:
|
|
| 319 |
String like: GhostTokenizer(vocab_size=50261, special_tokens=4)
|
| 320 |
"""
|
| 321 |
return f"GhostTokenizer(vocab_size={self.vocab_size}, special_tokens={len(self._special_tokens)})"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 319 |
String like: GhostTokenizer(vocab_size=50261, special_tokens=4)
|
| 320 |
"""
|
| 321 |
return f"GhostTokenizer(vocab_size={self.vocab_size}, special_tokens={len(self._special_tokens)})"
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
class GhostTokenizerV05:
    """v0.5 tokenizer backed by a HuggingFace ``tokenizers`` JSON file.

    Exposes the same surface as ``GhostTokenizer`` (``encode``, ``decode``,
    ``encode_chat``, ``format_chat_prompt``, ``vocab_size``,
    ``_special_tokens``) so dataset / trainer / chat code can use either
    backend. Special-token IDs are looked up by name from the loaded file,
    so callers must not assume the absolute IDs used by the legacy
    tokenizer.

    Load via the constructor or ``from_file``.
    """

    # Same special-token names as the legacy tokenizer.
    BOS = GhostTokenizer.BOS
    EOS = GhostTokenizer.EOS
    PAD = GhostTokenizer.PAD
    UNK = GhostTokenizer.UNK
    USER = GhostTokenizer.USER
    ASSISTANT = GhostTokenizer.ASSISTANT
    END = GhostTokenizer.END

    def __init__(self, path: str):
        """Load a trained ``tokenizer.json`` and resolve special-token IDs.

        Args:
            path: Path to the tokenizer JSON file.

        Raises:
            ValueError: If the file does not define every GhostLM special
                token.
        """
        from tokenizers import Tokenizer

        self._tok = Tokenizer.from_file(path)
        self._vocab_size = self._tok.get_vocab_size()
        self._special_tokens = {
            name: self._tok.token_to_id(name)
            for name in (
                self.BOS, self.EOS, self.PAD, self.UNK,
                self.USER, self.ASSISTANT, self.END,
            )
        }
        # ``token_to_id`` yields None for an unknown token. Fail here with a
        # clear message rather than letting None leak into id sequences.
        missing = [name for name, tid in self._special_tokens.items() if tid is None]
        if missing:
            raise ValueError(
                f"Tokenizer file {path!r} is missing special tokens: {missing}"
            )
        self._id_to_special = {v: k for k, v in self._special_tokens.items()}

    @classmethod
    def from_file(cls, path: str) -> "GhostTokenizerV05":
        """Alternate constructor; equivalent to ``GhostTokenizerV05(path)``."""
        return cls(path)

    @property
    def vocab_size(self) -> int:
        """Return the vocabulary size reported by the loaded tokenizer."""
        return self._vocab_size

    def _special_token_ids(self) -> set:
        """Return a new set holding the special-token integer IDs."""
        return set(self._special_tokens.values())

    def encode(self, text: str, add_bos: bool = False, add_eos: bool = False) -> List[int]:
        """Encode ``text`` to token IDs.

        Args:
            text: Input string.
            add_bos: Prepend the BOS id when true.
            add_eos: Append the EOS id when true.

        Returns:
            The list of token IDs.
        """
        ids = self._tok.encode(text).ids
        if add_bos:
            ids = [self._special_tokens[self.BOS]] + ids
        if add_eos:
            ids = ids + [self._special_tokens[self.EOS]]
        return ids

    def decode(self, ids: List[int], skip_special: bool = True) -> str:
        """Decode token IDs back to text.

        Args:
            ids: Token IDs to decode.
            skip_special: Drop GhostLM special-token IDs before decoding.

        Returns:
            The decoded string.
        """
        if skip_special:
            # ``_id_to_special`` is keyed by exactly the special IDs, so it
            # serves as the membership table without rebuilding a set.
            ids = [i for i in ids if i not in self._id_to_special]
        return self._tok.decode(ids)

    def encode_batch(self, texts: List[str], add_bos: bool = False, add_eos: bool = False) -> List[List[int]]:
        """Encode each text in ``texts`` with ``encode``."""
        return [self.encode(t, add_bos=add_bos, add_eos=add_eos) for t in texts]

    def encode_chat(self, turns: List[dict]) -> tuple:
        """Build ``(token_ids, loss_mask)`` for a chat conversation.

        Each turn is emitted as ``<role tag> content <end tag>``. The mask
        is 0 for every token of a user turn and for the assistant role tag;
        it is 1 for assistant content and the assistant turn's end tag.

        Args:
            turns: Dicts with ``"role"`` (``"user"`` or ``"assistant"``) and
                ``"content"`` keys.

        Returns:
            A ``(ids, mask)`` pair of equal-length integer lists.

        Raises:
            ValueError: If a turn has a role other than user / assistant.
        """
        user_id = self._special_tokens[self.USER]
        assistant_id = self._special_tokens[self.ASSISTANT]
        end_id = self._special_tokens[self.END]
        ids: List[int] = []
        mask: List[int] = []
        for turn in turns:
            role = turn["role"]
            content_ids = self._tok.encode(turn["content"]).ids
            if role == "user":
                role_id, learn = user_id, 0
            elif role == "assistant":
                role_id, learn = assistant_id, 1
            else:
                raise ValueError(f"Unknown role: {role!r}")
            ids.append(role_id)
            ids.extend(content_ids)
            ids.append(end_id)
            # The role tag is never a training target; content and the end
            # tag are targets only for assistant turns.
            mask.append(0)
            mask.extend([learn] * len(content_ids))
            mask.append(learn)
        return ids, mask

    def format_chat_prompt(self, turns: List[dict]) -> List[int]:
        """Build an inference prompt: the encoded turns plus a trailing assistant tag."""
        ids, _ = self.encode_chat(turns)
        ids.append(self._special_tokens[self.ASSISTANT])
        return ids

    def to_tensor(self, ids: List[int], device: str = "cpu") -> torch.Tensor:
        """Convert ``ids`` to a ``(1, T)`` ``torch.long`` tensor on ``device``."""
        return torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)

    def __len__(self) -> int:
        """Return the vocab size."""
        return self._vocab_size

    def __repr__(self) -> str:
        """Return a concise summary of vocab size and special-token count."""
        return f"GhostTokenizerV05(vocab_size={self._vocab_size}, special_tokens={len(self._special_tokens)})"
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
def load_tokenizer(path: Optional[str] = None):
    """Return the v0.5 tokenizer when ``path`` exists, else the legacy one.

    Training code can call this with ``config.tokenizer_path`` without
    caring which backend it gets.

    Args:
        path: Optional path to a ``tokenizer.json`` file. ``None`` or an
            empty string selects the legacy tokenizer.

    Returns:
        A ``GhostTokenizerV05`` when ``path`` is given and exists on disk,
        otherwise a ``GhostTokenizer``.
    """
    if path:
        if Path(path).exists():
            return GhostTokenizerV05(path)
        # A path was requested but is missing. The fallback is kept, but the
        # legacy tokenizer has a different vocabulary, so make it visible.
        import warnings

        warnings.warn(
            f"Tokenizer file {path!r} not found; falling back to the legacy GhostTokenizer.",
            RuntimeWarning,
            stacklevel=2,
        )
    return GhostTokenizer()
|
|
@@ -48,19 +48,54 @@ class GhostTrainer:
|
|
| 48 |
else:
|
| 49 |
self.device = config.device
|
| 50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
self.model = self.model.to(self.device)
|
| 52 |
|
| 53 |
-
# Mixed precision (AMP)
|
| 54 |
if use_amp is None:
|
| 55 |
-
self.use_amp = self.device
|
| 56 |
else:
|
| 57 |
-
self.use_amp = use_amp and self.device
|
| 58 |
|
| 59 |
self.grad_scaler = torch.amp.GradScaler("cuda", enabled=self.use_amp)
|
| 60 |
|
| 61 |
-
# Optimizer
|
| 62 |
self.optimizer = self.model.configure_optimizers(config)
|
| 63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
# Create directories
|
| 65 |
self.checkpoint_dir = Path(config.checkpoint_dir)
|
| 66 |
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
|
@@ -188,13 +223,23 @@ class GhostTrainer:
|
|
| 188 |
state dict, and config. Also saves as "best_model.pt" if the current
|
| 189 |
validation loss is the best seen so far.
|
| 190 |
|
|
|
|
|
|
|
|
|
|
| 191 |
Args:
|
| 192 |
val_loss: Current validation loss for comparison.
|
| 193 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
checkpoint = {
|
| 195 |
"step": self.step,
|
| 196 |
"val_loss": val_loss,
|
| 197 |
-
"model_state_dict":
|
| 198 |
"optimizer_state_dict": self.optimizer.state_dict(),
|
| 199 |
"grad_scaler_state_dict": self.grad_scaler.state_dict(),
|
| 200 |
"config": asdict(self.config),
|
|
@@ -222,7 +267,9 @@ class GhostTrainer:
|
|
| 222 |
"""
|
| 223 |
checkpoint = torch.load(path, map_location=self.device, weights_only=False)
|
| 224 |
|
| 225 |
-
|
|
|
|
|
|
|
| 226 |
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
| 227 |
if "grad_scaler_state_dict" in checkpoint:
|
| 228 |
self.grad_scaler.load_state_dict(checkpoint["grad_scaler_state_dict"])
|
|
|
|
| 48 |
else:
|
| 49 |
self.device = config.device
|
| 50 |
|
| 51 |
+
# Distributed training support (issue #8). Detect whether we are
|
| 52 |
+
# running inside torchrun / torch.distributed.launch by reading the
|
| 53 |
+
# standard env vars; if so, set the local-rank device and wrap the
|
| 54 |
+
# model in DistributedDataParallel after moving to device.
|
| 55 |
+
# Single-GPU / CPU training is the default and unchanged.
|
| 56 |
+
self.is_distributed = (
|
| 57 |
+
"RANK" in os.environ
|
| 58 |
+
and "WORLD_SIZE" in os.environ
|
| 59 |
+
and int(os.environ.get("WORLD_SIZE", "1")) > 1
|
| 60 |
+
)
|
| 61 |
+
self.local_rank = int(os.environ.get("LOCAL_RANK", "0"))
|
| 62 |
+
self.world_size = int(os.environ.get("WORLD_SIZE", "1"))
|
| 63 |
+
self.global_rank = int(os.environ.get("RANK", "0"))
|
| 64 |
+
self.is_main_process = self.global_rank == 0
|
| 65 |
+
|
| 66 |
+
if self.is_distributed:
|
| 67 |
+
import torch.distributed as dist
|
| 68 |
+
backend = "nccl" if torch.cuda.is_available() else "gloo"
|
| 69 |
+
if not dist.is_initialized():
|
| 70 |
+
dist.init_process_group(backend=backend)
|
| 71 |
+
if torch.cuda.is_available():
|
| 72 |
+
torch.cuda.set_device(self.local_rank)
|
| 73 |
+
self.device = f"cuda:{self.local_rank}"
|
| 74 |
+
|
| 75 |
self.model = self.model.to(self.device)
|
| 76 |
|
| 77 |
+
# Mixed precision (AMP), only effective on CUDA
|
| 78 |
if use_amp is None:
|
| 79 |
+
self.use_amp = self.device.startswith("cuda")
|
| 80 |
else:
|
| 81 |
+
self.use_amp = use_amp and self.device.startswith("cuda")
|
| 82 |
|
| 83 |
self.grad_scaler = torch.amp.GradScaler("cuda", enabled=self.use_amp)
|
| 84 |
|
| 85 |
+
# Optimizer (built BEFORE wrapping in DDP so param groups see raw modules)
|
| 86 |
self.optimizer = self.model.configure_optimizers(config)
|
| 87 |
|
| 88 |
+
# DDP wrap. Each rank now sees a self.model that does the all-reduce
|
| 89 |
+
# transparently in backward(). NOTE: DDP does not forward custom
|
| 90 |
+
# attributes or methods of the wrapped model; use self.model.module for those.
|
| 91 |
+
if self.is_distributed:
|
| 92 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 93 |
+
ddp_kwargs = {}
|
| 94 |
+
if torch.cuda.is_available():
|
| 95 |
+
ddp_kwargs["device_ids"] = [self.local_rank]
|
| 96 |
+
ddp_kwargs["output_device"] = self.local_rank
|
| 97 |
+
self.model = DDP(self.model, **ddp_kwargs)
|
| 98 |
+
|
| 99 |
# Create directories
|
| 100 |
self.checkpoint_dir = Path(config.checkpoint_dir)
|
| 101 |
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
| 223 |
state dict, and config. Also saves as "best_model.pt" if the current
|
| 224 |
validation loss is the best seen so far.
|
| 225 |
|
| 226 |
+
Under distributed training, only rank 0 writes; the saved state_dict
|
| 227 |
+
unwraps DDP so checkpoints remain compatible with single-GPU loading.
|
| 228 |
+
|
| 229 |
Args:
|
| 230 |
val_loss: Current validation loss for comparison.
|
| 231 |
"""
|
| 232 |
+
# Only rank 0 writes checkpoints in DDP runs
|
| 233 |
+
if getattr(self, "is_distributed", False) and not self.is_main_process:
|
| 234 |
+
return
|
| 235 |
+
|
| 236 |
+
# Unwrap DDP to keep checkpoints loadable on a single GPU
|
| 237 |
+
raw_model = self.model.module if hasattr(self.model, "module") else self.model
|
| 238 |
+
|
| 239 |
checkpoint = {
|
| 240 |
"step": self.step,
|
| 241 |
"val_loss": val_loss,
|
| 242 |
+
"model_state_dict": raw_model.state_dict(),
|
| 243 |
"optimizer_state_dict": self.optimizer.state_dict(),
|
| 244 |
"grad_scaler_state_dict": self.grad_scaler.state_dict(),
|
| 245 |
"config": asdict(self.config),
|
|
|
|
| 267 |
"""
|
| 268 |
checkpoint = torch.load(path, map_location=self.device, weights_only=False)
|
| 269 |
|
| 270 |
+
# Load into the raw model (works for DDP-wrapped or single-GPU)
|
| 271 |
+
raw_model = self.model.module if hasattr(self.model, "module") else self.model
|
| 272 |
+
raw_model.load_state_dict(checkpoint["model_state_dict"])
|
| 273 |
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
| 274 |
if "grad_scaler_state_dict" in checkpoint:
|
| 275 |
self.grad_scaler.load_state_dict(checkpoint["grad_scaler_state_dict"])
|