ch1mera / train.py
Lgr54HFi's picture
Upload folder using huggingface_hub
6e408ce verified
#!/usr/bin/env python3
"""
Chimera 5.2 — CPU-first training script.
Highlights vs the previous version:
* MeZO optimiser uses a single deterministic seed per step, samples each
parameter's perturbation direction *on demand* via per-parameter seeds and
drops the heavy direction cache. Without momentum this brings the memory
cost of MeZO back down to "1× model"; a non-zero momentum adds one
model-sized buffer.
* AdamW path uses fused parameter groups and shares the same loss closure as
MeZO so accumulation and logging are identical between modes.
* Logging never references an undefined ``lr`` (the previous draft printed it
before the AdamW step ran on the first accumulator boundary).
* Gradient checkpointing falls back to ``use_reentrant=False`` (the modern,
faster path).
* Tokeniser/dataset loading is unchanged. A positive ``max_tokens`` (or
``max_samples``) fills a pre-allocated token buffer; ``max_tokens=0`` or no
budget at all falls back to list-based collection.
Recommended commands::
# MeZO smoke test on TinyStories
python train.py --scale tiny --seq_len 64 --max_steps 20 --optimizer mezo
# AdamW with grad checkpointing + bf16
python train.py --scale small --seq_len 256 --max_steps 1000 \\
--optimizer adamw --grad_checkpoint --bf16
"""
from __future__ import annotations
import argparse
import json
import math
import os
import sys
import time
# CPU threading must be configured *before* importing torch.
def _available_cpus() -> int:
    """Return how many CPUs this process may actually run on.

    Prefers the scheduler affinity mask, which honours ``taskset`` and
    container cpusets, and falls back to ``os.cpu_count()`` and then to 4.
    """
    try:
        return len(os.sched_getaffinity(0)) or 4
    except (AttributeError, OSError):
        # ``sched_getaffinity`` is not available on every platform.
        return os.cpu_count() or 4


def _setup_cpu_runtime() -> None:
    """Seed OpenMP/MKL threading and allocator environment defaults.

    Uses ``os.environ.setdefault``, so values already exported by the user
    always win. Must run before ``torch`` is imported.
    """
    n_cpus = str(_available_cpus())
    os.environ.setdefault("OMP_NUM_THREADS", n_cpus)
    os.environ.setdefault("MKL_NUM_THREADS", n_cpus)
    os.environ.setdefault("KMP_AFFINITY", "granularity=fine,compact,1,0")
    os.environ.setdefault("KMP_BLOCKTIME", "1")
    os.environ.setdefault("MALLOC_CONF", "background_thread:true,metadata_thp:auto")


_setup_cpu_runtime()
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from chimera import Chimera51ForCausalLM
from chimera.quantization import BitLinear
# Size torch's intra-op pool from OMP_NUM_THREADS and the inter-op pool from
# CHIMERA_INTEROP_THREADS (default 1).
torch.set_num_threads(
    int(os.environ.get("OMP_NUM_THREADS", os.cpu_count() or 4))
)
try:
    torch.set_num_interop_threads(
        int(os.environ.get("CHIMERA_INTEROP_THREADS", "1"))
    )
except RuntimeError:
    # NOTE(review): presumably raised when the inter-op pool can no longer
    # be resized; the current size is kept.
    pass

# Optional Intel Extension for PyTorch: any import failure simply disables it.
try:  # pragma: no cover - optional dependency.
    import intel_extension_for_pytorch as ipex  # noqa: F401
except Exception:
    HAS_IPEX = False
else:
    HAS_IPEX = True
# ---------------------------------------------------------------------------
# MeZO optimiser
# ---------------------------------------------------------------------------
class MeZOOptimizer:
    """Memory-Efficient Zeroth-Order optimiser (Princeton MeZO).

    Each step runs *two* forward passes around ``θ`` and uses the resulting
    loss difference to estimate a projected gradient. No backward pass is
    run, and directions are regenerated from per-tensor seeds rather than
    stored. Without momentum the memory cost is therefore ``1× model``; with
    ``momentum > 0`` one extra buffer per trainable tensor is allocated.

    For BitLinear layers the direction applied to the weight is multiplied by
    the layer's ``ternary_nonzero_mask()``. The mask is snapshotted once at
    the start of each step, so perturbation, restore and update all use the
    same mask.

    Args:
        model: module whose parameters are updated in place.
        lr: learning rate.
        eps: perturbation scale ``ε``.
        weight_decay: multiplicative decay ``1 - lr * wd`` applied after each update.
        momentum: heavy-ball coefficient; ``0`` disables the buffers.
        direction: ``"rademacher"`` (±1 entries) or ``"gaussian"``.

    Raises:
        ValueError: if ``direction`` is not a supported name.
    """

    def __init__(self, model: nn.Module, lr: float = 1e-4, eps: float = 1e-3,
                 weight_decay: float = 0.0, momentum: float = 0.0,
                 direction: str = "rademacher"):
        self.model = model
        self.lr = float(lr)
        self.eps = float(eps)
        self.wd = float(weight_decay)
        self.momentum = float(momentum)
        if direction not in ("rademacher", "gaussian"):
            raise ValueError(f"unknown direction: {direction!r}")
        self.direction = direction

        # Collect trainable tensors once. ``seen`` holds ids of Parameter
        # objects so tied weights are perturbed and updated only once.
        self._bitlinear_modules: list[tuple[str, BitLinear]] = []
        self._dense_params: list[tuple[str, torch.Tensor]] = []
        seen: set[int] = set()
        for name, module in model.named_modules():
            if not isinstance(module, BitLinear):
                continue
            if id(module.weight) in seen:
                # Weight tied to an already-registered BitLinear.
                continue
            self._bitlinear_modules.append((name, module))
            seen.add(id(module.weight))
            if module.bias is not None:
                seen.add(id(module.bias))
        for name, p in model.named_parameters():
            if p.requires_grad and id(p) not in seen:
                self._dense_params.append((name, p))
                seen.add(id(p))

        # Per-step snapshot of BitLinear non-zero masks (filled in ``step``).
        self._step_masks: dict[int, torch.Tensor] = {}

        # Optional heavy-ball buffers, keyed by each tensor's stable offset in
        # ``_walk_params``. Keying by ``id(p.data)`` can never match a stored
        # key, because ``.data`` returns a new tensor object on every access.
        self._momentum: dict[int, torch.Tensor] = {}
        if self.momentum > 0:
            for off, p, _ in self._walk_params():
                self._momentum[off] = torch.zeros_like(p)

    # ------------------------------------------------------------------
    # Direction sampling — deterministic per (step seed, parameter index).
    # ------------------------------------------------------------------
    def _direction(self, p: torch.Tensor, seed: int) -> torch.Tensor:
        """Sample a direction shaped like *p*; the same seed gives the same tensor."""
        gen = torch.Generator(device="cpu")
        gen.manual_seed(int(seed) & 0x7FFF_FFFF_FFFF_FFFF)
        if self.direction == "gaussian":
            return torch.randn(p.shape, dtype=p.dtype, device="cpu",
                               generator=gen).to(p.device)
        # Rademacher: Bernoulli(0.5) in {0, 1} mapped to {-1, +1}.
        z = torch.empty(p.shape, dtype=p.dtype, device="cpu")
        z.bernoulli_(0.5, generator=gen).mul_(2).sub_(1)
        return z.to(p.device)

    def _walk_params(self):
        """Yield ``(seed_offset, tensor, mask_or_None)`` for every trainable tensor.

        The offsets depend only on the module/parameter lists built in
        ``__init__``, so they are stable across calls. They serve both as
        per-tensor seed offsets and as momentum-buffer keys.
        """
        offset = 0
        for _, module in self._bitlinear_modules:
            yield offset, module.weight.data, self._step_masks.get(id(module.weight))
            offset += 1
            if module.bias is not None:
                yield offset, module.bias.data, None
                offset += 1
        for _, p in self._dense_params:
            yield offset, p.data, None
            offset += 1

    def _masked_direction(self, offset: int, p: torch.Tensor,
                          mask, base_seed: int) -> torch.Tensor:
        """Regenerate the direction for one tensor, zeroed where *mask* is 0."""
        z = self._direction(p, base_seed + offset * 1_000_003)
        if mask is not None:
            z.mul_(mask.to(dtype=z.dtype, device=z.device))
        return z

    def _invalidate_bitlinear_caches(self) -> None:
        """Mark every BitLinear's packed cache stale after weights changed."""
        for _, m in self._bitlinear_modules:
            m.invalidate_packed()

    def _perturb(self, base_seed: int, scale: float) -> None:
        """Shift every trainable tensor by ``scale * z``."""
        for off, p, mask in self._walk_params():
            p.add_(self._masked_direction(off, p, mask, base_seed), alpha=scale)
        self._invalidate_bitlinear_caches()

    def _update(self, base_seed: int, projected_grad: float) -> None:
        """Apply the SGD(-momentum) update along ``z`` plus optional weight decay."""
        for off, p, mask in self._walk_params():
            z = self._masked_direction(off, p, mask, base_seed)
            buf = self._momentum.get(off)
            if buf is not None:
                buf.mul_(self.momentum).add_(z, alpha=projected_grad)
                p.add_(buf, alpha=-self.lr)
            else:
                p.add_(z, alpha=-self.lr * projected_grad)
            if self.wd > 0:
                p.mul_(1 - self.lr * self.wd)
        self._invalidate_bitlinear_caches()

    @torch.no_grad()
    def step(self, loss_fn, batch) -> float:
        """Run one MeZO step (two forward passes) and return the mean loss.

        Args:
            loss_fn: callable mapping *batch* to a scalar loss tensor.
            batch: forwarded unchanged to *loss_fn*.

        Returns:
            ``(loss(θ + εz) + loss(θ - εz)) / 2`` as a Python float.
        """
        seed = int(torch.randint(0, 2**31, (1,)).item())
        # Snapshot ternary non-zero masks once so every pass of this step
        # (perturb, restore, update) uses the same mask.
        self._step_masks = {
            id(m.weight): m.ternary_nonzero_mask().detach()
            for _, m in self._bitlinear_modules
        }
        displacement = 0.0  # current offset of θ along z, for cleanup on error
        try:
            self._perturb(seed, +self.eps)                  # θ + εz
            displacement = self.eps
            loss_pos = float(loss_fn(batch).item())
            self._perturb(seed, -2.0 * self.eps)            # θ - εz
            displacement = -self.eps
            loss_neg = float(loss_fn(batch).item())
            self._perturb(seed, +self.eps)                  # back to θ
            displacement = 0.0
            projected_grad = (loss_pos - loss_neg) / (2.0 * self.eps)
            self._update(seed, projected_grad)
        finally:
            if displacement:
                # loss_fn raised: undo the perturbation so θ is not left shifted.
                self._perturb(seed, -displacement)
            self._step_masks = {}
        return 0.5 * (loss_pos + loss_neg)
# ---------------------------------------------------------------------------
# Dataset & tokenisation helpers.
# ---------------------------------------------------------------------------
class TokenDataset(Dataset):
    """Map-style dataset over pre-chunked token ids.

    ``chunks`` is indexed along its first dimension (``build_dataset`` passes
    an ``(n, seq_len + 1)`` tensor). Each row is returned as both
    ``input_ids`` and ``labels``; the next-token shift happens in the loss
    computation, not here.
    """

    def __init__(self, chunks: torch.Tensor):
        self.chunks = chunks

    def __len__(self) -> int:
        return int(self.chunks.shape[0])

    def __getitem__(self, idx: int) -> dict:
        row = self.chunks[idx]
        return dict(input_ids=row, labels=row)
def _matches_category_filter(ex: dict, filters: list) -> bool:
    """Return True when the record's ``category`` contains any of *filters*.

    The test is a case-insensitive substring match. A record whose category
    is missing, ``None`` or empty never matches.
    """
    category = (ex.get("category", "") or "").lower()
    return bool(category) and any(f.lower() in category for f in filters)
def _format_example(ex: dict, tok, text_column: str = "auto",
                    include_reasoning: bool = False) -> str:
    """Render one dataset record as training text.

    Args:
        ex: a single dataset record.
        tok: tokenizer exposing ``apply_chat_template(messages)``.
        text_column: column to read, or ``"auto"`` to use the first of
            ``messages``/``text``/``content``/``conversation`` present in *ex*.
        include_reasoning: for ``messages`` input, prefix each assistant turn
            that has a non-empty ``reasoning`` with that reasoning wrapped in
            ``<|thinking|>`` tags.

    Returns:
        The text to tokenise. Falls back to ``str(ex)`` when no usable column
        is found.
    """
    if text_column == "auto":
        text_column = next(
            (c for c in ("messages", "text", "content", "conversation") if c in ex),
            "",
        )

    if text_column == "messages" and "messages" in ex:
        msgs = ex["messages"]
        if include_reasoning and isinstance(msgs, list):
            merged = []
            for m in msgs:
                reasoning = (m.get("reasoning")
                             if isinstance(m, dict) and m.get("role") == "assistant"
                             else None)
                if reasoning:
                    # ``or ""`` so a None content never becomes the text "None".
                    content = m.get("content") or ""
                    merged.append({
                        "role": "assistant",
                        "content": (f"<|thinking|>\n{reasoning}\n<|/thinking|>\n"
                                    f"{content}"),
                    })
                else:
                    # Missing/None/empty reasoning: keep the turn unchanged
                    # instead of injecting "<|thinking|>\nNone...".
                    merged.append(m)
            msgs = merged
        return tok.apply_chat_template(msgs)

    if text_column and text_column in ex:
        val = ex[text_column]
        if isinstance(val, str):
            return val
        if isinstance(val, list) and val and isinstance(val[0], dict):
            return tok.apply_chat_template(val)
        return str(val)
    return str(ex)
def build_dataset(seq_len: int, max_samples=None, max_tokens=None,
                  split: str = "train",
                  dataset_name: str = "roneneldan/TinyStories",
                  dataset_config: str = None, text_column: str = "auto",
                  category_filter: str = None,
                  include_reasoning: bool = False):
    """Stream a dataset, tokenise it and cut it into fixed-length chunks.

    Each kept example is rendered with ``_format_example``, encoded with the
    ``o200k_base`` ``ChimeraTokenizer`` and terminated with the EOS id. The
    concatenated token stream is then split into non-overlapping chunks of
    ``seq_len + 1`` tokens.

    Args:
        seq_len: context length; chunks hold ``seq_len + 1`` tokens.
        max_samples: optional cap on the number of chunks. When
            ``max_tokens`` is ``None`` it also sets the token budget.
        max_tokens: optional token budget. A positive budget fills a
            pre-allocated buffer and stops once it is full. ``0`` or no
            budget uses list-based collection, which stops after
            ``max_samples`` chunks' worth of tokens if ``max_samples`` is
            truthy, else at the end of the split.
        split: dataset split, passed to ``datasets.load_dataset`` (streaming).
        dataset_name: dataset id passed to ``datasets.load_dataset``.
        dataset_config: optional config, passed as ``name``.
        text_column: forwarded to ``_format_example``.
        category_filter: comma-separated substrings; when given, only
            examples whose ``category`` matches one of them are kept.
        include_reasoning: forwarded to ``_format_example``.

    Returns:
        ``(TokenDataset, tokenizer)``.

    Raises:
        ValueError: if no tokens were collected, or too few for one chunk.
    """
    from datasets import load_dataset
    from chimera import ChimeraTokenizer

    print(f"[DATA] Loading {dataset_name} ({split})...")
    load_kwargs = {"split": split, "streaming": True}
    if dataset_config:
        load_kwargs["name"] = dataset_config
    ds = load_dataset(dataset_name, **load_kwargs)
    tok = ChimeraTokenizer(pretrained="o200k_base")

    cat_filters = ([c.strip() for c in category_filter.split(",") if c.strip()]
                   if category_filter else None)
    if cat_filters:
        print(f"[DATA] Filtering categories: {cat_filters}")

    chunk_len = seq_len + 1
    if max_tokens is not None:
        token_budget = int(max_tokens)
    elif max_samples is not None:
        token_budget = int(max_samples) * chunk_len
    else:
        token_budget = None

    # Counters shared by both collection paths. Previously only the buffered
    # path defined them, although the summary below reads them.
    processed = skipped = 0

    def _encoded_examples():
        """Yield EOS-terminated token-id lists for examples passing the filters."""
        nonlocal processed, skipped
        for ex in ds:
            if cat_filters and not _matches_category_filter(ex, cat_filters):
                skipped += 1
                continue
            text = _format_example(ex, tok, text_column, include_reasoning)
            if not text or not text.strip():
                skipped += 1
                continue
            ids = tok.encode(text, add_special_tokens=False)
            ids.append(tok.eos_token_id)
            # Count before yielding: the consumer may stop right after this item.
            processed += 1
            yield ids

    if token_budget is None or token_budget <= 0:
        # Unbudgeted path: grow a Python list until the optional target.
        target = max_samples * chunk_len if max_samples else float("inf")
        collected: list[int] = []
        for ids in _encoded_examples():
            collected.extend(ids)
            if len(collected) >= target:
                break
        all_ids = torch.tensor(collected, dtype=torch.long)
    else:
        # Budgeted path: fill a pre-allocated buffer and truncate the last
        # example so the buffer is never overrun.
        buffer = torch.empty(token_budget, dtype=torch.long)
        filled = 0
        for ids in _encoded_examples():
            take = min(len(ids), token_budget - filled)
            buffer[filled:filled + take] = torch.tensor(ids[:take], dtype=torch.long)
            filled += take
            if filled >= token_budget:
                break
            if processed % 10_000 == 0:
                print(f" {processed:,} examples, {filled:,} tokens...")
        all_ids = buffer[:filled]

    print(f"[DATA] Processed {processed:,} examples, skipped {skipped:,}.")
    if all_ids.numel() == 0:
        raise ValueError("No data matched filters.")
    n = all_ids.numel() // chunk_len
    if max_samples:
        n = min(n, max_samples)
    if n <= 0:
        # Without this guard an empty/invalid chunk tensor reaches the
        # DataLoader and training dies later with an opaque StopIteration.
        raise ValueError(
            f"Collected {all_ids.numel():,} tokens; need at least {chunk_len} "
            f"(seq_len + 1) to form one training chunk.")
    chunks = all_ids[:n * chunk_len].view(n, chunk_len)
    print(f"[DATA] {n:,} chunks × {seq_len} tokens = {n * seq_len:,} total")
    return TokenDataset(chunks), tok
# ---------------------------------------------------------------------------
# Learning-rate schedule.
# ---------------------------------------------------------------------------
def cosine_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float
              ) -> float:
    """Linear warm-up followed by cosine decay from ``max_lr`` to ``min_lr``.

    For ``step < warmup`` (when ``warmup > 0``) the rate ramps linearly and
    reaches ``max_lr`` at ``step == warmup - 1``. Otherwise any
    ``step >= total`` returns ``min_lr``. Steps in between follow half a
    cosine from ``max_lr`` down to ``min_lr``.
    """
    in_warmup = warmup > 0 and step < warmup
    if in_warmup:
        return max_lr * (step + 1) / warmup
    if step >= total:
        return min_lr
    progress = (step - warmup) / max(1, total - warmup)
    cosine = 1.0 + math.cos(math.pi * progress)
    return min_lr + 0.5 * (max_lr - min_lr) * cosine
# ---------------------------------------------------------------------------
# Main loop.
# ---------------------------------------------------------------------------
# Model-size presets selected with ``--scale``. Each entry overrides the
# matching keys of the JSON config; a scale without an entry ("full") keeps
# the config's own sizes.
_SCALE_PRESETS = {
    "tiny": {
        "hidden_size": 256, "intermediate_size": 512,
        "num_heads": 4, "head_dim": 48,
    },
    "small": {
        "hidden_size": 512, "intermediate_size": 1024,
        "num_heads": 8, "head_dim": 48,
    },
    "medium": {
        "hidden_size": 1024, "intermediate_size": 2048,
        "num_heads": 8, "head_dim": 96,
    },
}
def _apply_training_overrides(config: dict, args) -> dict:
    """Apply the ``--scale`` preset and this script's fixed overrides to *config* in place."""
    if args.scale in _SCALE_PRESETS:
        config.update(_SCALE_PRESETS[args.scale])
    config["num_hidden_layers"] = int(config.get("num_hidden_layers", 28))
    config["vocab_size"] = config.get("vocab_size", 200073)
    config.setdefault("gated_deltanet", {})["chunk_size"] = min(args.seq_len, 64)
    config.setdefault("xlstm", {})["memory_size_per_head"] = [config["head_dim"], config["head_dim"]]
    config.setdefault("titans", {}).update({
        "memory_depth": 2, "persistent_memory_slots": 16,
        "local_window_size": min(args.seq_len, 256),
    })
    # NOTE(review): the MoE layer indices and prelude/loop/coda ranges below
    # look sized for 28 layers; confirm they stay valid when the config sets
    # a different num_hidden_layers.
    moe_cfg = config.setdefault("backbone", {}).setdefault("moe", {})
    moe_cfg.setdefault("layers", [3, 7, 11, 15, 19, 23, 27])
    moe_cfg.setdefault("moe_intermediate_size", config["intermediate_size"] // 4)
    moe_cfg.setdefault("n_routed_experts", 8)
    moe_cfg.setdefault("n_shared_experts", 1)
    moe_cfg.setdefault("num_experts_per_tok", 2)
    config.setdefault("looping", {}).update({
        "enabled": True, "prelude": [0, 3], "loop": [4, 23], "coda": [24, 27],
        "loop_range": [1, 3], "loop_default": 2,
    })
    for feature in ("span_inference", "grammar", "entropy_valve", "debt_ledger"):
        config.setdefault(feature, {})["enabled"] = True
    config.setdefault("multimodal", {})["enabled"] = False
    return config


def _build_adamw(model: nn.Module, lr: float) -> torch.optim.AdamW:
    """Build AdamW with wd=0.1, except wd=0.0 for params whose name has a no-decay tag."""
    no_decay = ("A_log", "dt_bias", "norm", "bias", "embed", "energy_weights")
    decay_params, no_decay_params = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if any(tag in name for tag in no_decay):
            no_decay_params.append(p)
        else:
            decay_params.append(p)
    groups = [{"params": decay_params, "weight_decay": 0.1},
              {"params": no_decay_params, "weight_decay": 0.0}]
    return torch.optim.AdamW([g for g in groups if g["params"]],
                             lr=lr, betas=(0.9, 0.95))


def _save_checkpoint(path: str, model: nn.Module, **payload) -> None:
    """Save the uncompiled model's state_dict plus *payload* to *path*."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    raw = getattr(model, "_orig_mod", model)  # unwrap torch.compile
    torch.save({"model": raw.state_dict(), **payload}, path)


def train(args) -> None:
    """Train a Chimera model on CPU with MeZO or AdamW, as selected by *args*.

    Reads ``args.config`` (JSON), applies the scale preset and fixed
    overrides, loads the dataset and runs ``args.max_steps`` batches. It
    appends one JSON line per log interval to ``<output_dir>/log.jsonl``,
    writes ``ckpt-<step>/ckpt.pt`` every ``save_every`` steps, and finishes
    with ``final/model.pt`` and ``final/config.json``.

    Raises:
        ValueError: if the dataset has fewer chunks than one batch.
    """
    with open(args.config) as f:
        config = json.load(f)
    _apply_training_overrides(config, args)

    use_mezo = (args.optimizer == "mezo")
    use_bf16 = bool(args.bf16)
    use_compile = bool(args.compile)

    print("=" * 60)
    print(f"CHIMERA 5.2 TRAINING — scale={args.scale}, "
          f"optimizer={'MeZO' if use_mezo else 'AdamW'}, bf16={use_bf16}")
    print(f"Layers={config['num_hidden_layers']} hidden={config['hidden_size']} "
          f"vocab={config['vocab_size']} seq_len={args.seq_len} steps={args.max_steps}")
    print(f"Threads: {torch.get_num_threads()} IPEX={HAS_IPEX}")
    print("=" * 60)

    model = Chimera51ForCausalLM(config)
    counts = model.count_parameters()
    print(f"Params: total={counts['total']:,} ternary={counts['ternary']:,}")

    if args.grad_checkpoint and not use_mezo:
        model.enable_gradient_checkpointing()
        print("[OPT] Gradient checkpointing ON")

    optimizer = None
    if not use_mezo:
        # Build the grouped AdamW up front so the IPEX and plain paths share
        # identical hyper-parameters. The IPEX path used to get an ungrouped
        # AdamW with default betas and weight decay.
        optimizer = _build_adamw(model, args.lr)
        if HAS_IPEX:
            model, optimizer = ipex.optimize(
                model, optimizer=optimizer,
                dtype=torch.bfloat16 if use_bf16 else torch.float32, level="O1")
            print("[OPT] IPEX optimisation applied (level O1)")

    if use_compile:
        print("[OPT] Compiling model with torch.compile (inductor)...")
        model = torch.compile(model, backend="inductor", mode="default", dynamic=True)

    dataset, _tok = build_dataset(
        args.seq_len, max_samples=args.max_samples, max_tokens=args.max_tokens,
        split=args.dataset_split, dataset_name=args.dataset_name,
        dataset_config=args.dataset_config, text_column=args.text_column,
        category_filter=args.category_filter,
        include_reasoning=args.include_reasoning,
    )
    loader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, drop_last=True,
        persistent_workers=args.num_workers > 0,
        prefetch_factor=2 if args.num_workers > 0 else None,
    )
    if len(loader) == 0:
        # With drop_last=True an undersized dataset yields no batches, and the
        # loop below would die with a bare StopIteration.
        raise ValueError(
            f"Dataset has {len(dataset):,} chunks, fewer than "
            f"batch_size={args.batch_size}; nothing to train on.")

    if use_mezo:
        optimizer = MeZOOptimizer(
            model, lr=args.lr * 0.01, eps=1e-3,
            weight_decay=0.1, momentum=0.9, direction=args.mezo_direction,
        )

    def compute_loss(batch) -> torch.Tensor:
        """Next-token loss: inputs drop the last token, labels drop the first."""
        ids = batch["input_ids"][:, :-1]
        labels = batch["labels"][:, 1:]
        if use_bf16:
            with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                out = model(ids, labels=labels)
        else:
            out = model(ids, labels=labels)
        return out.loss

    os.makedirs(args.output_dir, exist_ok=True)
    log_path = os.path.join(args.output_dir, "log.jsonl")
    opt_name = "mezo" if use_mezo else "adamw"
    warmup = min(args.warmup, max(1, args.max_steps // 10))

    model.train()
    step = 0
    cur_lr = args.lr
    total_loss = 0.0
    best_loss = float("inf")
    toks = 0
    data_iter = iter(loader)
    if not use_mezo:
        optimizer.zero_grad(set_to_none=True)

    print(f"\n{'=' * 60}\nTraining starts\n{'=' * 60}\n")
    with open(log_path, "w", encoding="utf-8") as log_f:
        t0 = time.perf_counter()
        while step < args.max_steps:
            try:
                batch = next(data_iter)
            except StopIteration:
                # Epoch boundary: restart the (reshuffled) loader.
                data_iter = iter(loader)
                batch = next(data_iter)

            if use_mezo:
                cur_lr = cosine_lr(step, warmup, args.max_steps,
                                   args.lr * 0.01, args.lr * 0.001)
                optimizer.lr = cur_lr
                total_loss += optimizer.step(compute_loss, batch)
            else:
                loss = compute_loss(batch)
                (loss / args.grad_accum).backward()
                total_loss += float(loss.item())
                if (step + 1) % args.grad_accum == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    cur_lr = cosine_lr(step, warmup, args.max_steps,
                                       args.lr, args.lr * 0.1)
                    for pg in optimizer.param_groups:
                        pg["lr"] = cur_lr
                    optimizer.step()
                    optimizer.zero_grad(set_to_none=True)

            toks += batch["input_ids"][:, :-1].numel()
            step += 1

            if step % args.log_every == 0:
                dt = time.perf_counter() - t0
                avg = total_loss / args.log_every
                ppl = math.exp(min(avg, 20))
                tps = toks / dt if dt > 0 else 0
                # t0 is reset at every log, so the step rate is log_every / dt
                # (the old step / dt overstated it by a factor of step / log_every).
                eta_h = (args.max_steps - step) * dt / args.log_every / 3600
                log_f.write(json.dumps({
                    "step": step, "loss": round(avg, 4), "ppl": round(ppl, 2),
                    "lr": cur_lr, "tok/s": round(tps),
                    "optimizer": opt_name,
                }) + "\n")
                log_f.flush()
                print(f" step {step:>6}/{args.max_steps} | loss {avg:.4f} | "
                      f"ppl {ppl:>8.2f} | lr {cur_lr:.2e} | "
                      f"{tps:.0f} tok/s | ETA {eta_h:.1f}h")
                best_loss = min(best_loss, avg)
                total_loss = 0.0
                toks = 0
                t0 = time.perf_counter()

            if step % args.save_every == 0:
                ckpt_dir = os.path.join(args.output_dir, f"ckpt-{step}")
                _save_checkpoint(os.path.join(ckpt_dir, "ckpt.pt"), model,
                                 config=config, step=step, optimizer=args.optimizer)
                print(f" [SAVE] {ckpt_dir}")
    # NOTE(review): with AdamW, if max_steps is not a multiple of grad_accum
    # the final partial accumulation is never applied.

    final_dir = os.path.join(args.output_dir, "final")
    os.makedirs(final_dir, exist_ok=True)
    _save_checkpoint(os.path.join(final_dir, "model.pt"), model,
                     config=config, step=step, best_loss=best_loss)
    with open(os.path.join(final_dir, "config.json"), "w", encoding="utf-8") as fh:
        json.dump(config, fh, indent=2)
    print(f"\n{'=' * 60}")
    print(f"DONE — best loss {best_loss:.4f}, ppl {math.exp(min(best_loss, 20)):.2f}")
    print(f"Saved to {final_dir}")
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def _positive_int(text: str) -> int:
    """``argparse`` type converter that accepts only integers >= 1."""
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {text!r}")
    return value


def _build_argparser() -> argparse.ArgumentParser:
    """Build the command-line parser for ``train``.

    Counts used as divisors or chunk sizes (``--seq_len``, ``--batch_size``,
    ``--grad_accum``, ``--log_every``, ``--save_every``) must be positive.
    Bad values are rejected at parse time instead of raising a
    ``ZeroDivisionError`` mid-run.
    """
    p = argparse.ArgumentParser(description="Chimera 5.2 CPU-first training")
    p.add_argument("--config", default="config.json",
                   help="base model config (JSON)")
    p.add_argument("--scale", default="tiny", choices=["tiny", "small", "medium", "full"],
                   help="size preset; 'full' keeps the config's own sizes")
    p.add_argument("--seq_len", type=_positive_int, default=256,
                   help="training context length in tokens")
    p.add_argument("--optimizer", default="mezo", choices=["mezo", "adamw"])
    p.add_argument("--batch_size", type=_positive_int, default=2)
    p.add_argument("--grad_accum", type=_positive_int, default=8,
                   help="micro-batches per AdamW step")
    p.add_argument("--lr", type=float, default=1e-3,
                   help="peak learning rate (MeZO runs at 0.01x this value)")
    p.add_argument("--warmup", type=int, default=200, help="warm-up steps")
    p.add_argument("--max_steps", type=int, default=5000)
    p.add_argument("--max_samples", type=int, default=None,
                   help="cap on the number of training chunks")
    p.add_argument("--max_tokens", type=int, default=None,
                   help="token budget for data loading")
    p.add_argument("--bf16", action="store_true", default=True,
                   help="bf16 autocast (on by default; disable with --no-bf16)")
    p.add_argument("--no-bf16", dest="bf16", action="store_false")
    p.add_argument("--compile", action="store_true", default=False,
                   help="wrap the model in torch.compile")
    p.add_argument("--grad_checkpoint", action="store_true", default=True,
                   help="gradient checkpointing for AdamW (on by default)")
    p.add_argument("--no-grad-checkpoint", dest="grad_checkpoint", action="store_false")
    p.add_argument("--mezo_direction", choices=["rademacher", "gaussian"],
                   default="rademacher")
    p.add_argument("--dataset_name", default="roneneldan/TinyStories")
    p.add_argument("--dataset_config", default=None)
    p.add_argument("--dataset_split", default="train")
    p.add_argument("--text_column", default="auto")
    p.add_argument("--category_filter", default=None,
                   help="comma-separated category substrings to keep")
    p.add_argument("--include_reasoning", action="store_true", default=False)
    p.add_argument("--num_workers", type=int, default=2)
    p.add_argument("--log_every", type=_positive_int, default=10)
    p.add_argument("--save_every", type=_positive_int, default=1000)
    p.add_argument("--output_dir", default="./chimera_output")
    return p
if __name__ == "__main__":
    # Script entry point: parse CLI flags and run training.
    train(_build_argparser().parse_args())