ch1mera / train_fast.py
Lgr54HFi's picture
Upload train_fast.py
6639e7f verified
#!/usr/bin/env python3
"""Chimera 5.2 — Fast CPU training with pre-tokenized dataset cache."""
from __future__ import annotations
import argparse
import json
import math
import os
import sys
import time
# CPU threading must be configured *before* importing torch, so the thread
# count is resolved and exported here. An existing OMP_NUM_THREADS wins;
# otherwise every logical CPU reported by os.cpu_count() is used (4 if the
# count is unknown). MKL is given the same value.
ncpus = int(os.environ.get("OMP_NUM_THREADS", os.cpu_count() or 4))
os.environ.update(dict.fromkeys(("OMP_NUM_THREADS", "MKL_NUM_THREADS"), str(ncpus)))
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from chimera import Chimera51ForCausalLM
# Intra-op work uses all `ncpus` threads; inter-op parallelism is pinned to a
# single thread. PyTorch signals with RuntimeError when the inter-op setting
# can no longer be changed, in which case the existing value is kept.
torch.set_num_threads(ncpus)
try:
    torch.set_num_interop_threads(1)
except RuntimeError:
    pass  # best effort: keep whatever inter-op setting is already active
# ---------------------------------------------------------------------------
# Pre-tokenized dataset cache
# ---------------------------------------------------------------------------
class PreTokenizedDataset(Dataset):
    """Next-token-prediction samples cut from a flat token tensor.

    ``ids`` (expected 1-D) is split into non-overlapping windows of
    ``seq_len + 1`` tokens; trailing tokens that do not fill a whole window
    are discarded. Item ``i`` is window ``i`` as ``input_ids`` (all but the
    last token) and ``labels`` (all but the first token).
    """

    def __init__(self, ids: torch.Tensor, seq_len: int):
        window = seq_len + 1
        num_windows = ids.numel() // window
        usable = num_windows * window
        # Row-major view: row k holds tokens [k * window, (k + 1) * window).
        self.chunks = ids[:usable].view(num_windows, window)
        self.seq_len = seq_len

    def __len__(self) -> int:
        return self.chunks.shape[0]

    def __getitem__(self, idx: int):
        window = self.chunks[idx]
        inputs, targets = window[:-1], window[1:]
        return {"input_ids": inputs, "labels": targets}
def build_or_load_dataset(seq_len: int, max_samples: int, cache_dir: str = "./cache"):
    """Return a PreTokenizedDataset of TinyStories tokens, tokenizing once and caching.

    If ``<cache_dir>/tiny_stories_<seq_len>_<max_samples>.pt`` exists, the flat
    token tensor stored there is loaded. Otherwise the ``roneneldan/TinyStories``
    train split is streamed. Every non-empty story is encoded with the
    ``o200k_base`` ChimeraTokenizer and followed by the EOS token. Tokens are
    concatenated until ``max_samples * (seq_len + 1)`` have been collected (the
    last story is truncated to fit) or the stream ends. The result is trimmed
    to a whole number of ``seq_len + 1`` windows and written to the cache.

    Args:
        seq_len: number of input tokens per sample. Each stored window holds
            ``seq_len + 1`` tokens so inputs and shifted labels can be cut from it.
        max_samples: upper bound on the number of windows collected.
        cache_dir: directory for the cache file; created if missing.

    Returns:
        A ``PreTokenizedDataset`` over the cached or freshly collected tokens.
    """
    cache_path = os.path.join(cache_dir, f"tiny_stories_{seq_len}_{max_samples}.pt")
    os.makedirs(cache_dir, exist_ok=True)
    if os.path.exists(cache_path):
        print(f"[CACHE] Loading pre-tokenized dataset from {cache_path}")
        # The cache only ever holds a plain tensor, so the restricted
        # weights_only loader suffices and refuses arbitrary pickled objects.
        chunks = torch.load(cache_path, weights_only=True)
        return PreTokenizedDataset(chunks, seq_len)

    # These dependencies are only needed when the cache must be built.
    from datasets import load_dataset
    from chimera import ChimeraTokenizer

    print("[DATA] Downloading TinyStories...")
    ds = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
    tok = ChimeraTokenizer(pretrained="o200k_base")
    window = seq_len + 1
    target = max_samples * window
    buffer = torch.empty(target, dtype=torch.long)
    buf_idx = 0
    processed = 0
    for ex in ds:
        text = ex.get("text", "")
        if not text:
            continue
        ids = tok.encode(text, add_special_tokens=False)
        ids.append(tok.eos_token_id)
        remaining = target - buf_idx
        if remaining <= 0:
            break
        ids = ids[:remaining]  # truncate the final story so the buffer fills exactly
        n = len(ids)
        buffer[buf_idx:buf_idx + n] = torch.tensor(ids, dtype=torch.long)
        buf_idx += n
        processed += 1
        if processed % 1000 == 0:
            print(f" {processed:,} stories, {buf_idx:,}/{target} tokens...")
        if buf_idx >= target:
            break

    all_ids = buffer[:buf_idx]
    n = all_ids.numel() // window
    # clone(): torch.save serializes a view's entire underlying storage, so
    # without it the whole `target`-sized buffer (including any never-filled
    # tail) would be written to disk.
    chunks = all_ids[:n * window].clone()
    # Write to a temp file, then rename. An interrupted save therefore never
    # leaves a truncated cache for the fast path above to load later.
    tmp_path = cache_path + ".tmp"
    torch.save(chunks, tmp_path)
    os.replace(tmp_path, cache_path)
    print(f"[CACHE] Saved {chunks.numel():,} tokens to {cache_path}")
    return PreTokenizedDataset(chunks, seq_len)
# ---------------------------------------------------------------------------
# Fast training loop
# ---------------------------------------------------------------------------
def cosine_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float:
    """Learning rate for ``step`` under linear warmup followed by cosine decay.

    During the first ``warmup`` steps the rate ramps linearly up to ``max_lr``
    (reaching it at step ``warmup - 1``). From ``warmup`` to ``total`` it
    follows a half cosine from ``max_lr`` down to ``min_lr``. At and after
    ``total`` it stays at ``min_lr``.
    """
    in_warmup = warmup > 0 and step < warmup
    if in_warmup:
        return max_lr * (step + 1) / warmup
    if step >= total:
        return min_lr
    decay_span = max(1, total - warmup)  # guards total <= warmup
    progress = (step - warmup) / decay_span
    amplitude = 0.5 * (max_lr - min_lr)
    return min_lr + amplitude * (1.0 + math.cos(math.pi * progress))
_SCALE_PRESETS = {
"tiny": dict(hidden_size=256, intermediate_size=512, num_heads=4, head_dim=48),
"small": dict(hidden_size=512, intermediate_size=1024, num_heads=8, head_dim=48),
"medium": dict(hidden_size=1024, intermediate_size=2048, num_heads=8, head_dim=96),
}
def _build_config(args) -> dict:
    """Load the JSON config at ``args.config`` and apply the fast-train overrides.

    The ``--scale`` preset (when one exists) replaces the size keys. The
    sub-module sections are then forced or defaulted to the values this
    script trains with. Returns the resulting config dict.
    """
    with open(args.config, encoding="utf-8") as f:
        config = json.load(f)
    if args.scale in _SCALE_PRESETS:
        config.update(_SCALE_PRESETS[args.scale])
    config["num_hidden_layers"] = int(config.get("num_hidden_layers", 28))
    config["vocab_size"] = config.get("vocab_size", 200073)
    config.setdefault("gated_deltanet", {})["chunk_size"] = min(args.seq_len, 64)
    config.setdefault("xlstm", {})["memory_size_per_head"] = [config["head_dim"], config["head_dim"]]
    config.setdefault("titans", {}).update({
        "memory_depth": 2, "persistent_memory_slots": 16,
        "local_window_size": min(args.seq_len, 256),
    })
    # NOTE(review): the default MoE layer indices and the looping
    # prelude/loop/coda ranges below only go up to layer 27 (matching the
    # 28-layer default). Confirm they are valid when num_hidden_layers differs.
    moe_cfg = config.setdefault("backbone", {}).setdefault("moe", {})
    moe_cfg.setdefault("layers", [3, 7, 11, 15, 19, 23, 27])
    moe_cfg.setdefault("moe_intermediate_size", config["intermediate_size"] // 4)
    moe_cfg.setdefault("n_routed_experts", 8)
    moe_cfg.setdefault("n_shared_experts", 1)
    moe_cfg.setdefault("num_experts_per_tok", 2)
    config.setdefault("looping", {}).update({
        "enabled": True, "prelude": [0, 3], "loop": [4, 23], "coda": [24, 27],
        "loop_range": [1, 3], "loop_default": 2,
    })
    for section in ("span_inference", "grammar", "entropy_valve", "debt_ledger"):
        config.setdefault(section, {})["enabled"] = True
    config.setdefault("multimodal", {})["enabled"] = False
    return config


def _save_checkpoint(model: nn.Module, path: str, config: dict, step: int, **extra) -> None:
    """Save the model's state_dict with ``config`` and ``step`` (plus any ``extra`` keys) to ``path``."""
    # When torch.compile wrapped the model, save the underlying module instead.
    raw = getattr(model, "_orig_mod", model)
    torch.save({"model": raw.state_dict(), "config": config, "step": step, **extra}, path)


def train(args) -> None:
    """Train Chimera51ForCausalLM on pre-tokenized TinyStories on the CPU.

    Builds the config from ``args.config`` plus the fast-train overrides.
    Creates (and optionally ``torch.compile``s) the model, then loads or builds
    the token cache. It then runs ``args.max_steps`` AdamW steps with gradient
    clipping at norm 1.0 and a linear-warmup + cosine learning-rate schedule
    that decays to 10% of ``args.lr``.

    Outputs, all under ``args.output_dir``:
      * every ``log_every`` steps, a JSON line is written to ``log.jsonl``;
      * every ``save_every`` steps, a checkpoint goes to ``ckpt-<step>/ckpt.pt``;
      * at the end, the final weights and config go to ``final/``.

    Raises:
        ValueError: if ``log_every``/``save_every`` are not positive, or the
            dataset cannot fill a single batch.
    """
    if args.log_every <= 0 or args.save_every <= 0:
        raise ValueError("--log_every and --save_every must be positive integers")
    config = _build_config(args)

    print("=" * 60)
    print(f"CHIMERA 5.2 FAST TRAIN — scale={args.scale}, seq_len={args.seq_len}, steps={args.max_steps}")
    print(f"Layers={config['num_hidden_layers']} hidden={config['hidden_size']} vocab={config['vocab_size']}")
    print(f"Threads: {torch.get_num_threads()} bf16={args.bf16} compile={args.compile}")
    print("=" * 60)
    model = Chimera51ForCausalLM(config)
    counts = model.count_parameters()
    print(f"Params: total={counts['total']:,} ternary={counts['ternary']:,}")
    if args.compile:
        print("[OPT] Compiling model...")
        model = torch.compile(model, backend="inductor", mode="default", dynamic=True)

    dataset = build_or_load_dataset(args.seq_len, args.max_samples, args.cache_dir)
    if len(dataset) < args.batch_size:
        # With drop_last=True such a loader yields no batches at all, and the
        # epoch restart in the loop would escape as a bare StopIteration.
        raise ValueError(
            f"dataset has {len(dataset)} samples, fewer than batch_size={args.batch_size}; "
            "increase --max_samples or lower --batch_size"
        )
    loader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=0, drop_last=True,
    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95))

    def compute_loss(batch) -> torch.Tensor:
        """Run ``batch`` through the model (bf16 autocast if requested) and return its loss."""
        ids = batch["input_ids"]
        labels = batch["labels"]
        if args.bf16:
            with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                out = model(ids, labels=labels)
        else:
            out = model(ids, labels=labels)
        return out.loss

    os.makedirs(args.output_dir, exist_ok=True)
    log_path = os.path.join(args.output_dir, "log.jsonl")
    model.train()
    step = 0
    window_loss = 0.0  # summed loss since the last log line
    window_steps = 0   # number of steps since the last log line
    best_loss = float("inf")
    toks = 0
    t0 = time.time()
    data_iter = iter(loader)
    warmup = min(args.warmup, max(1, args.max_steps // 10))
    print(f"\n{'=' * 60}\nTraining starts\n{'=' * 60}\n")
    with open(log_path, "w", encoding="utf-8") as log_f:
        while step < args.max_steps:
            try:
                batch = next(data_iter)
            except StopIteration:
                # Epoch exhausted: start a new (reshuffled) pass over the data.
                data_iter = iter(loader)
                batch = next(data_iter)
            loss = compute_loss(batch)
            loss.backward()
            window_loss += float(loss.item())
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            cur_lr = cosine_lr(step, warmup, args.max_steps, args.lr, args.lr * 0.1)
            for pg in optimizer.param_groups:
                pg["lr"] = cur_lr
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            toks += batch["input_ids"].numel()
            step += 1
            window_steps += 1

            if step % args.log_every == 0:
                dt = time.time() - t0
                avg = window_loss / window_steps
                ppl = math.exp(min(avg, 20))
                tps = toks / dt if dt > 0 else 0
                # dt only spans the steps since the previous log, so the step
                # rate must use window_steps, not the cumulative step count.
                eta_h = (args.max_steps - step) / (window_steps / dt) / 3600 if dt > 0 else 0.0
                log_f.write(json.dumps({
                    "step": step, "loss": round(avg, 4), "ppl": round(ppl, 2),
                    "lr": cur_lr, "tok/s": round(tps),
                }) + "\n")
                log_f.flush()
                print(f" step {step:>6}/{args.max_steps} | loss {avg:.4f} | "
                      f"ppl {ppl:>8.2f} | lr {cur_lr:.2e} | "
                      f"{tps:.0f} tok/s | ETA {eta_h:.1f}h")
                best_loss = min(best_loss, avg)
                window_loss = 0.0
                window_steps = 0
                toks = 0
                t0 = time.time()

            if step % args.save_every == 0:
                ckpt_dir = os.path.join(args.output_dir, f"ckpt-{step}")
                os.makedirs(ckpt_dir, exist_ok=True)
                _save_checkpoint(model, os.path.join(ckpt_dir, "ckpt.pt"), config, step)
                print(f" [SAVE] {ckpt_dir}")

    if best_loss == float("inf") and window_steps > 0:
        # Nothing was logged (e.g. max_steps < log_every); use the partial
        # window so best_loss is still meaningful.
        best_loss = window_loss / window_steps

    final_dir = os.path.join(args.output_dir, "final")
    os.makedirs(final_dir, exist_ok=True)
    _save_checkpoint(model, os.path.join(final_dir, "model.pt"), config, step, best_loss=best_loss)
    with open(os.path.join(final_dir, "config.json"), "w", encoding="utf-8") as fh:
        json.dump(config, fh, indent=2)
    print(f"\n{'=' * 60}")
    print(f"DONE — best loss {best_loss:.4f}, ppl {math.exp(min(best_loss, 20)):.2f}")
    print(f"Saved to {final_dir}")
if __name__ == "__main__":
    # Command-line entry point: parse options (registered in --help order)
    # and hand them to train().
    parser = argparse.ArgumentParser(description="Chimera 5.2 Fast CPU training")
    parser.add_argument("--config", default="config.json")
    # "full" has no entry in _SCALE_PRESETS, so the config file's sizes are kept.
    parser.add_argument("--scale", default="tiny", choices=["tiny", "small", "medium", "full"])
    for flag, cast, default in (
        ("--seq_len", int, 32),
        ("--batch_size", int, 4),
        ("--lr", float, 1e-3),
        ("--warmup", int, 100),
        ("--max_steps", int, 1000),
        ("--max_samples", int, 5000),
    ):
        parser.add_argument(flag, type=cast, default=default)
    for flag in ("--bf16", "--compile"):
        parser.add_argument(flag, action="store_true", default=False)
    parser.add_argument("--cache_dir", default="./cache")
    for flag, default in (("--log_every", 10), ("--save_every", 500)):
        parser.add_argument(flag, type=int, default=default)
    parser.add_argument("--output_dir", default="./chimera_output")
    train(parser.parse_args())