#!/usr/bin/env python3 """ Chimera 5.2 — CPU-first training script. Highlights vs the previous version: * MeZO optimiser uses a single deterministic seed per step, samples each parameter's perturbation direction *on demand* via per-parameter seeds and drops the heavy direction cache. This brings the memory cost of MeZO back down to "1× model" exactly as advertised. * AdamW path uses fused parameter groups and shares the same loss closure as MeZO so accumulation and logging are identical between modes. * Logging never references an undefined ``lr`` (the previous draft printed it before the AdamW step ran on the first accumulator boundary). * Gradient checkpointing falls back to ``use_reentrant=False`` (the modern, faster path). * Tokeniser/dataset loading is unchanged but the Python loops are skipped entirely for ``max_tokens=0``. Recommended commands:: # MeZO smoke test on TinyStories python train.py --scale tiny --seq_len 64 --max_steps 20 --optimizer mezo # AdamW with grad checkpointing + bf16 python train.py --scale small --seq_len 256 --max_steps 1000 \\ --optimizer adamw --grad_checkpoint --bf16 """ from __future__ import annotations import argparse import json import math import os import sys import time # CPU threading must be configured *before* importing torch. def _setup_cpu_runtime() -> None: n_cpus = os.cpu_count() or 4 os.environ.setdefault("OMP_NUM_THREADS", str(n_cpus)) os.environ.setdefault("MKL_NUM_THREADS", str(n_cpus)) os.environ.setdefault("KMP_AFFINITY", "granularity=fine,compact,1,0") os.environ.setdefault("KMP_BLOCKTIME", "1") os.environ.setdefault("MALLOC_CONF", "background_thread:true,metadata_thp:auto") _setup_cpu_runtime() import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader, Dataset sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) from chimera import Chimera51ForCausalLM from chimera.quantization import BitLinear torch.set_num_threads(int(os.environ.get("OMP_NUM_THREADS", os.cpu_count() or 4))) try: torch.set_num_interop_threads(int(os.environ.get("CHIMERA_INTEROP_THREADS", "1"))) except RuntimeError: pass # Optional Intel Extension for PyTorch. HAS_IPEX = False try: # pragma: no cover - optional dependency. import intel_extension_for_pytorch as ipex # noqa: F401 HAS_IPEX = True except Exception: pass # --------------------------------------------------------------------------- # MeZO optimiser # --------------------------------------------------------------------------- class MeZOOptimizer: """Memory-Efficient Zeroth-Order optimiser (Princeton MeZO). Each step runs *two* forward passes around ``θ`` and uses the resulting loss difference to estimate a projected gradient. No backward pass and no per-parameter optimiser state — memory cost is exactly ``1× model``. For BitLinear layers we mask perturbations to currently non-zero ternary positions, so ``~1/3`` of the weights skip both perturbation and update. """ def __init__(self, model: nn.Module, lr: float = 1e-4, eps: float = 1e-3, weight_decay: float = 0.0, momentum: float = 0.0, direction: str = "rademacher"): self.model = model self.lr = float(lr) self.eps = float(eps) self.wd = float(weight_decay) self.momentum = float(momentum) if direction not in ("rademacher", "gaussian"): raise ValueError(f"unknown direction: {direction!r}") self.direction = direction # Collect trainable parameters once and deduplicate tied weights. self._bitlinear_modules: list[tuple[str, BitLinear]] = [] self._dense_params: list[tuple[str, torch.Tensor]] = [] seen: set[int] = set() for name, module in model.named_modules(): if isinstance(module, BitLinear): self._bitlinear_modules.append((name, module)) seen.add(id(module.weight)) if module.bias is not None: seen.add(id(module.bias)) for name, p in model.named_parameters(): if p.requires_grad and id(p) not in seen: self._dense_params.append((name, p)) seen.add(id(p)) # Optional momentum buffer — only allocated when momentum > 0. self._momentum: dict[int, torch.Tensor] = {} if self.momentum > 0: for _, p in self._dense_params: self._momentum[id(p)] = torch.zeros_like(p.data) for _, m in self._bitlinear_modules: self._momentum[id(m.weight)] = torch.zeros_like(m.weight.data) # Snapshot ternary non-zero masks once per step. self._step_masks: dict[int, torch.Tensor] = {} # ------------------------------------------------------------------ # Direction sampling — deterministic per (step seed, parameter index). # ------------------------------------------------------------------ def _direction(self, p: torch.Tensor, seed: int) -> torch.Tensor: gen = torch.Generator(device="cpu") gen.manual_seed(int(seed) & 0x7FFF_FFFF_FFFF_FFFF) if self.direction == "gaussian": return torch.randn(p.shape, dtype=p.dtype, device="cpu", generator=gen).to(p.device) z = torch.empty(p.shape, dtype=p.dtype, device="cpu") z.bernoulli_(0.5, generator=gen).mul_(2).sub_(1) return z.to(p.device) def _walk_params(self): """Yield ``(seed_offset, param, mask_or_None)`` for every trainable tensor.""" offset = 0 for _, module in self._bitlinear_modules: yield offset, module.weight.data, self._step_masks.get(id(module.weight)) offset += 1 if module.bias is not None: yield offset, module.bias.data, None offset += 1 for _, p in self._dense_params: yield offset, p.data, None offset += 1 def _perturb(self, base_seed: int, scale: float) -> None: for off, p, mask in self._walk_params(): z = self._direction(p, base_seed + off * 1_000_003) if mask is not None: z = z * mask.to(dtype=z.dtype, device=z.device) p.add_(z, alpha=scale) # Mark BitLinear caches stale. for _, m in self._bitlinear_modules: m.invalidate_packed() def _update(self, base_seed: int, projected_grad: float) -> None: for off, p, mask in self._walk_params(): z = self._direction(p, base_seed + off * 1_000_003) if mask is not None: z = z * mask.to(dtype=z.dtype, device=z.device) buf = self._momentum.get(id(p)) if buf is not None: buf.mul_(self.momentum).add_(z, alpha=projected_grad) p.add_(buf, alpha=-self.lr) else: p.add_(z, alpha=-self.lr * projected_grad) if self.wd > 0: p.mul_(1 - self.lr * self.wd) for _, m in self._bitlinear_modules: m.invalidate_packed() @torch.no_grad() def step(self, loss_fn, batch) -> float: """Run one MeZO step (two forward passes) and return the mean loss.""" seed = int(torch.randint(0, 2**31, (1,)).item()) # Snapshot ternary non-zero masks once for this step. self._step_masks = { id(m.weight): m.ternary_nonzero_mask().detach() for _, m in self._bitlinear_modules } # Forward at θ + εz. self._perturb(seed, +self.eps) loss_pos = float(loss_fn(batch).item()) # Net displacement: θ + εz - 2εz = θ - εz. self._perturb(seed, -2.0 * self.eps) loss_neg = float(loss_fn(batch).item()) # Restore θ. self._perturb(seed, +self.eps) projected_grad = (loss_pos - loss_neg) / (2.0 * self.eps) self._update(seed, projected_grad) self._step_masks = {} return 0.5 * (loss_pos + loss_neg) # --------------------------------------------------------------------------- # Dataset & tokenisation helpers. # --------------------------------------------------------------------------- class TokenDataset(Dataset): def __init__(self, chunks: torch.Tensor): self.chunks = chunks def __len__(self) -> int: return self.chunks.size(0) def __getitem__(self, idx: int) -> dict: c = self.chunks[idx] return {"input_ids": c, "labels": c} def _matches_category_filter(ex: dict, filters: list) -> bool: cat = ex.get("category", "") or "" if not cat: return False cat_lower = cat.lower() return any(f.lower() in cat_lower for f in filters) def _format_example(ex: dict, tok, text_column: str = "auto", include_reasoning: bool = False) -> str: if text_column == "auto": for cand in ("messages", "text", "content", "conversation"): if cand in ex: text_column = cand break else: text_column = "" if text_column == "messages" and "messages" in ex: msgs = ex["messages"] if include_reasoning and isinstance(msgs, list): new_msgs = [] for m in msgs: if isinstance(m, dict) and m.get("role") == "assistant" and "reasoning" in m: new_msgs.append({ "role": "assistant", "content": (f"<|thinking|>\n{m['reasoning']}\n<|/thinking|>\n" f"{m.get('content', '')}"), }) else: new_msgs.append(m) msgs = new_msgs return tok.apply_chat_template(msgs) if text_column and text_column in ex: val = ex[text_column] if isinstance(val, str): return val if isinstance(val, list) and val and isinstance(val[0], dict): return tok.apply_chat_template(val) return str(val) return str(ex) def build_dataset(seq_len: int, max_samples=None, max_tokens=None, split: str = "train", dataset_name: str = "roneneldan/TinyStories", dataset_config: str = None, text_column: str = "auto", category_filter: str = None, include_reasoning: bool = False): from datasets import load_dataset from chimera import ChimeraTokenizer print(f"[DATA] Loading {dataset_name} ({split})...") load_kwargs = {"split": split, "streaming": True} if dataset_config: load_kwargs["name"] = dataset_config ds = load_dataset(dataset_name, **load_kwargs) tok = ChimeraTokenizer(pretrained="o200k_base") cat_filters = ([c.strip() for c in category_filter.split(",") if c.strip()] if category_filter else None) if cat_filters: print(f"[DATA] Filtering categories: {cat_filters}") if max_tokens is not None: token_budget = int(max_tokens) elif max_samples is not None: token_budget = int(max_samples) * (seq_len + 1) else: token_budget = None if token_budget is None or token_budget <= 0: # Fallback: list-based collection. all_ids: list[int] = [] target = (max_samples * (seq_len + 1)) if max_samples else float("inf") for ex in ds: if cat_filters and not _matches_category_filter(ex, cat_filters): continue text = _format_example(ex, tok, text_column, include_reasoning) if not text or not text.strip(): continue ids = tok.encode(text, add_special_tokens=False) ids.append(tok.eos_token_id) all_ids.extend(ids) if len(all_ids) >= target: break all_ids = torch.tensor(all_ids, dtype=torch.long) else: # Pre-allocated token buffer. buffer = torch.empty(token_budget, dtype=torch.long) buf_idx = 0 processed = skipped = 0 for ex in ds: if cat_filters and not _matches_category_filter(ex, cat_filters): skipped += 1 continue text = _format_example(ex, tok, text_column, include_reasoning) if not text or not text.strip(): skipped += 1 continue ids = tok.encode(text, add_special_tokens=False) ids.append(tok.eos_token_id) n = len(ids) if buf_idx + n > token_budget: n = token_budget - buf_idx if n <= 0: break ids = ids[:n] if n > 0: buffer[buf_idx:buf_idx + n] = torch.tensor(ids, dtype=torch.long) buf_idx += n processed += 1 if buf_idx >= token_budget: break if (processed % 10_000) == 0: print(f" {processed:,} examples, {buf_idx:,} tokens...") all_ids = buffer[:buf_idx] print(f"[DATA] Processed {processed:,} examples, skipped {skipped:,}.") if all_ids.numel() == 0: raise ValueError("No data matched filters.") n = all_ids.numel() // (seq_len + 1) if max_samples: n = min(n, max_samples) chunks = all_ids[:n * (seq_len + 1)].view(n, seq_len + 1) print(f"[DATA] {n:,} chunks × {seq_len} tokens = {n * seq_len:,} total") return TokenDataset(chunks), tok # --------------------------------------------------------------------------- # Learning-rate schedule. # --------------------------------------------------------------------------- def cosine_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float ) -> float: if warmup > 0 and step < warmup: return max_lr * (step + 1) / warmup if step >= total: return min_lr p = (step - warmup) / max(1, total - warmup) return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * p)) # --------------------------------------------------------------------------- # Main loop. # --------------------------------------------------------------------------- _SCALE_PRESETS = { "tiny": dict(hidden_size=256, intermediate_size=512, num_heads=4, head_dim=48), "small": dict(hidden_size=512, intermediate_size=1024, num_heads=8, head_dim=48), "medium": dict(hidden_size=1024, intermediate_size=2048, num_heads=8, head_dim=96), } def train(args) -> None: with open(args.config) as f: config = json.load(f) if args.scale in _SCALE_PRESETS: config.update(_SCALE_PRESETS[args.scale]) config["num_hidden_layers"] = int(config.get("num_hidden_layers", 28)) config["vocab_size"] = config.get("vocab_size", 200073) config.setdefault("gated_deltanet", {})["chunk_size"] = min(args.seq_len, 64) config.setdefault("xlstm", {})["memory_size_per_head"] = [config["head_dim"], config["head_dim"]] config.setdefault("titans", {}).update({ "memory_depth": 2, "persistent_memory_slots": 16, "local_window_size": min(args.seq_len, 256), }) moe_cfg = config.setdefault("backbone", {}).setdefault("moe", {}) moe_cfg.setdefault("layers", [3, 7, 11, 15, 19, 23, 27]) moe_cfg.setdefault("moe_intermediate_size", config["intermediate_size"] // 4) moe_cfg.setdefault("n_routed_experts", 8) moe_cfg.setdefault("n_shared_experts", 1) moe_cfg.setdefault("num_experts_per_tok", 2) config.setdefault("looping", {}).update({ "enabled": True, "prelude": [0, 3], "loop": [4, 23], "coda": [24, 27], "loop_range": [1, 3], "loop_default": 2, }) config.setdefault("span_inference", {})["enabled"] = True config.setdefault("grammar", {})["enabled"] = True config.setdefault("entropy_valve", {})["enabled"] = True config.setdefault("debt_ledger", {})["enabled"] = True config.setdefault("multimodal", {})["enabled"] = False use_mezo = (args.optimizer == "mezo") use_bf16 = bool(args.bf16) use_compile = bool(args.compile) print("=" * 60) print(f"CHIMERA 5.2 TRAINING — scale={args.scale}, " f"optimizer={'MeZO' if use_mezo else 'AdamW'}, bf16={use_bf16}") print(f"Layers={config['num_hidden_layers']} hidden={config['hidden_size']} " f"vocab={config['vocab_size']} seq_len={args.seq_len} steps={args.max_steps}") print(f"Threads: {torch.get_num_threads()} IPEX={HAS_IPEX}") print("=" * 60) model = Chimera51ForCausalLM(config) counts = model.count_parameters() print(f"Params: total={counts['total']:,} ternary={counts['ternary']:,}") if args.grad_checkpoint and not use_mezo: model.enable_gradient_checkpointing() print("[OPT] Gradient checkpointing ON") if HAS_IPEX and not use_mezo: adamw = torch.optim.AdamW(model.parameters(), lr=args.lr) model, adamw = ipex.optimize( model, optimizer=adamw, dtype=torch.bfloat16 if use_bf16 else torch.float32, level="O1") print("[OPT] IPEX optimisation applied (level O1)") else: adamw = None if use_compile: print("[OPT] Compiling model with torch.compile (inductor)...") model = torch.compile(model, backend="inductor", mode="default", dynamic=True) dataset, tok = build_dataset( args.seq_len, max_samples=args.max_samples, max_tokens=args.max_tokens, split=args.dataset_split, dataset_name=args.dataset_name, dataset_config=args.dataset_config, text_column=args.text_column, category_filter=args.category_filter, include_reasoning=args.include_reasoning, ) loader = DataLoader( dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, drop_last=True, persistent_workers=args.num_workers > 0, prefetch_factor=2 if args.num_workers > 0 else None, ) if use_mezo: optimizer = MeZOOptimizer( model, lr=args.lr * 0.01, eps=1e-3, weight_decay=0.1, momentum=0.9, direction=args.mezo_direction, ) else: no_decay = {"A_log", "dt_bias", "norm", "bias", "embed", "energy_weights"} decay_params, no_decay_params = [], [] for n, p in model.named_parameters(): if not p.requires_grad: continue if any(tag in n for tag in no_decay): no_decay_params.append(p) else: decay_params.append(p) if adamw is None: optimizer = torch.optim.AdamW( [{"params": decay_params, "weight_decay": 0.1}, {"params": no_decay_params, "weight_decay": 0.0}], lr=args.lr, betas=(0.9, 0.95)) else: optimizer = adamw def compute_loss(batch) -> torch.Tensor: ids = batch["input_ids"][:, :-1] labels = batch["labels"][:, 1:] if use_bf16: with torch.autocast(device_type="cpu", dtype=torch.bfloat16): out = model(ids, labels=labels) else: out = model(ids, labels=labels) return out.loss os.makedirs(args.output_dir, exist_ok=True) log_path = os.path.join(args.output_dir, "log.jsonl") log_f = open(log_path, "w", encoding="utf-8") model.train() step = 0 cur_lr = args.lr total_loss = 0.0 best_loss = float("inf") toks = 0 t0 = time.time() data_iter = iter(loader) warmup = min(args.warmup, max(1, args.max_steps // 10)) if not use_mezo: optimizer.zero_grad(set_to_none=True) print(f"\n{'=' * 60}\nTraining starts\n{'=' * 60}\n") while step < args.max_steps: try: batch = next(data_iter) except StopIteration: data_iter = iter(loader) batch = next(data_iter) if use_mezo: cur_lr = cosine_lr(step, warmup, args.max_steps, args.lr * 0.01, args.lr * 0.001) optimizer.lr = cur_lr loss_val = optimizer.step(compute_loss, batch) total_loss += loss_val else: loss = compute_loss(batch) (loss / args.grad_accum).backward() total_loss += float(loss.item()) if (step + 1) % args.grad_accum == 0: torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) cur_lr = cosine_lr(step, warmup, args.max_steps, args.lr, args.lr * 0.1) for pg in optimizer.param_groups: pg["lr"] = cur_lr optimizer.step() optimizer.zero_grad(set_to_none=True) toks += batch["input_ids"][:, :-1].numel() step += 1 if step % args.log_every == 0: dt = time.time() - t0 avg = total_loss / args.log_every ppl = math.exp(min(avg, 20)) tps = toks / dt if dt > 0 else 0 eta_h = (args.max_steps - step) / (step / dt) / 3600 if dt > 0 else 0.0 log_f.write(json.dumps({ "step": step, "loss": round(avg, 4), "ppl": round(ppl, 2), "lr": cur_lr, "tok/s": round(tps), "optimizer": "mezo" if use_mezo else "adamw", }) + "\n") log_f.flush() print(f" step {step:>6}/{args.max_steps} | loss {avg:.4f} | " f"ppl {ppl:>8.2f} | lr {cur_lr:.2e} | " f"{tps:.0f} tok/s | ETA {eta_h:.1f}h") best_loss = min(best_loss, avg) total_loss = 0.0 toks = 0 t0 = time.time() if step % args.save_every == 0: ckpt_dir = os.path.join(args.output_dir, f"ckpt-{step}") os.makedirs(ckpt_dir, exist_ok=True) raw = getattr(model, "_orig_mod", model) torch.save({ "model": raw.state_dict(), "config": config, "step": step, "optimizer": args.optimizer, }, os.path.join(ckpt_dir, "ckpt.pt")) print(f" [SAVE] {ckpt_dir}") final_dir = os.path.join(args.output_dir, "final") os.makedirs(final_dir, exist_ok=True) raw = getattr(model, "_orig_mod", model) torch.save({ "model": raw.state_dict(), "config": config, "step": step, "best_loss": best_loss, }, os.path.join(final_dir, "model.pt")) with open(os.path.join(final_dir, "config.json"), "w", encoding="utf-8") as fh: json.dump(config, fh, indent=2) log_f.close() print(f"\n{'=' * 60}") print(f"DONE — best loss {best_loss:.4f}, ppl {math.exp(min(best_loss, 20)):.2f}") print(f"Saved to {final_dir}") # --------------------------------------------------------------------------- # CLI # --------------------------------------------------------------------------- def _build_argparser() -> argparse.ArgumentParser: p = argparse.ArgumentParser(description="Chimera 5.2 CPU-first training") p.add_argument("--config", default="config.json") p.add_argument("--scale", default="tiny", choices=["tiny", "small", "medium", "full"]) p.add_argument("--seq_len", type=int, default=256) p.add_argument("--optimizer", default="mezo", choices=["mezo", "adamw"]) p.add_argument("--batch_size", type=int, default=2) p.add_argument("--grad_accum", type=int, default=8) p.add_argument("--lr", type=float, default=1e-3) p.add_argument("--warmup", type=int, default=200) p.add_argument("--max_steps", type=int, default=5000) p.add_argument("--max_samples", type=int, default=None) p.add_argument("--max_tokens", type=int, default=None) p.add_argument("--bf16", action="store_true", default=True) p.add_argument("--no-bf16", dest="bf16", action="store_false") p.add_argument("--compile", action="store_true", default=False) p.add_argument("--grad_checkpoint", action="store_true", default=True) p.add_argument("--no-grad-checkpoint", dest="grad_checkpoint", action="store_false") p.add_argument("--mezo_direction", choices=["rademacher", "gaussian"], default="rademacher") p.add_argument("--dataset_name", default="roneneldan/TinyStories") p.add_argument("--dataset_config", default=None) p.add_argument("--dataset_split", default="train") p.add_argument("--text_column", default="auto") p.add_argument("--category_filter", default=None) p.add_argument("--include_reasoning", action="store_true", default=False) p.add_argument("--num_workers", type=int, default=2) p.add_argument("--log_every", type=int, default=10) p.add_argument("--save_every", type=int, default=1000) p.add_argument("--output_dir", default="./chimera_output") return p if __name__ == "__main__": args = _build_argparser().parse_args() train(args)