#!/usr/bin/env python3
"""
Chimera 5.2 — CPU-first training script.

Highlights vs the previous version:

* MeZO optimiser uses a single deterministic seed per step, samples each
  parameter's perturbation direction *on demand* via per-parameter seeds and
  drops the heavy direction cache. This brings the memory cost of MeZO back
  down to "1× model" exactly as advertised.
* AdamW path uses fused parameter groups and shares the same loss closure as
  MeZO, so accumulation and logging are identical between modes.
* Logging never references an undefined ``lr`` (the previous draft printed it
  before the AdamW step ran on the first accumulation boundary).
* Gradient checkpointing falls back to ``use_reentrant=False`` (the modern,
  faster path).
* Tokeniser/dataset loading is unchanged, but the Python loops are skipped
  entirely for ``max_tokens=0``.

Recommended commands::

    # MeZO smoke test on TinyStories
    python train.py --scale tiny --seq_len 64 --max_steps 20 --optimizer mezo

    # AdamW with grad checkpointing + bf16
    python train.py --scale small --seq_len 256 --max_steps 1000 \\
        --optimizer adamw --grad_checkpoint --bf16
"""
from __future__ import annotations

import argparse
import json
import math
import os
import time


# CPU threading must be configured *before* importing torch.
def _setup_cpu_runtime() -> None:
    n_cpus = os.cpu_count() or 4
    os.environ.setdefault("OMP_NUM_THREADS", str(n_cpus))
    os.environ.setdefault("MKL_NUM_THREADS", str(n_cpus))
    os.environ.setdefault("KMP_AFFINITY", "granularity=fine,compact,1,0")
    os.environ.setdefault("KMP_BLOCKTIME", "1")
    os.environ.setdefault("MALLOC_CONF", "background_thread:true,metadata_thp:auto")


_setup_cpu_runtime()

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from chimera import Chimera51ForCausalLM
from chimera.paths import DEFAULT_CONFIG_PATH
from chimera.training import (
    build_sequence_dataset,
    apply_standard_config_tweaks,
    MeZOOptimizer,
    train_standard_loop,
)
from chimera.quantization import BitLinear

torch.set_num_threads(int(os.environ.get("OMP_NUM_THREADS", os.cpu_count() or 4)))
try:
    torch.set_num_interop_threads(int(os.environ.get("CHIMERA_INTEROP_THREADS", "1")))
except RuntimeError:
    pass

# Optional Intel Extension for PyTorch.
HAS_IPEX = False
try:  # pragma: no cover - optional dependency.
    import intel_extension_for_pytorch as ipex  # noqa: F401
    HAS_IPEX = True
except Exception:
    pass


# ---------------------------------------------------------------------------
# Dataset & tokenisation helpers.
# ---------------------------------------------------------------------------
def build_dataset(seq_len: int, max_samples=None, max_tokens=None, split: str = "train",
                  dataset_name: str = "roneneldan/TinyStories", dataset_config: str | None = None,
                  text_column: str = "auto", category_filter: str | None = None,
                  include_reasoning: bool = False):
    from chimera import ChimeraTokenizer

    tok = ChimeraTokenizer(pretrained="o200k_base")
    dataset = build_sequence_dataset(
        seq_len,
        max_samples=max_samples,
        max_tokens=max_tokens,
        split=split,
        dataset_name=dataset_name,
        dataset_config=dataset_config,
        text_column=text_column,
        category_filter=category_filter,
        include_reasoning=include_reasoning,
    )
    return dataset, tok
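
# Illustrative sketch (not used by the training path): the module docstring
# says MeZO re-derives each parameter's perturbation from a deterministic
# per-parameter seed instead of caching direction tensors. The helper below
# shows one way that can work; its name and exact seeding scheme are
# assumptions for illustration, not the actual ``MeZOOptimizer`` internals.
def _mezo_perturb_sketch(model: nn.Module, step_seed: int, eps: float, sign: float) -> None:
    """Add ``sign * eps * z_i`` to every parameter, re-sampling ``z_i`` from a seed.

    Re-seeding the generator identically for the +eps, -eps and update passes
    reproduces the same directions without storing them, which is what keeps
    MeZO's memory cost at roughly one copy of the model.
    """
    with torch.no_grad():
        for idx, param in enumerate(model.parameters()):
            # Deterministic per-parameter seed derived from the step seed.
            gen = torch.Generator(device="cpu").manual_seed(step_seed * 100_003 + idx)
            z = torch.randn(param.shape, generator=gen, dtype=param.dtype)
            param.add_(z, alpha=sign * eps)
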
# ---------------------------------------------------------------------------
# Main loop.
# ---------------------------------------------------------------------------
def train(args) -> None:
    with open(args.config) as f:
        config = json.load(f)
    config = apply_standard_config_tweaks(config, scale=args.scale, seq_len=args.seq_len)

    use_mezo = (args.optimizer == "mezo")
    use_bf16 = bool(args.bf16)
    use_compile = bool(args.compile)

    print("=" * 60)
    print(f"CHIMERA 5.2 TRAINING — scale={args.scale}, "
          f"optimizer={'MeZO' if use_mezo else 'AdamW'}, bf16={use_bf16}")
    print(f"Layers={config['num_hidden_layers']} hidden={config['hidden_size']} "
          f"vocab={config['vocab_size']} seq_len={args.seq_len} steps={args.max_steps}")
    print(f"Threads: {torch.get_num_threads()} IPEX={HAS_IPEX}")
    print("=" * 60)

    model = Chimera51ForCausalLM(config)
    counts = model.count_parameters()
    print(f"Params: total={counts['total']:,} ternary={counts['ternary']:,}")

    # Gradient checkpointing only matters on the backprop (AdamW) path;
    # MeZO never calls backward.
    if args.grad_checkpoint and not use_mezo:
        model.enable_gradient_checkpointing()
        print("[OPT] Gradient checkpointing ON")

    # IPEX fuses model and optimiser together, so the AdamW instance has to be
    # created up front when IPEX is available.
    if HAS_IPEX and not use_mezo:
        adamw = torch.optim.AdamW(model.parameters(), lr=args.lr)
        model, adamw = ipex.optimize(
            model,
            optimizer=adamw,
            dtype=torch.bfloat16 if use_bf16 else torch.float32,
            level="O1",
        )
        print("[OPT] IPEX optimisation applied (level O1)")
    else:
        adamw = None

    if use_compile:
        print("[OPT] Compiling model with torch.compile (inductor)...")
        model = torch.compile(model, backend="inductor", mode="default", dynamic=True)

    dataset, tok = build_dataset(
        args.seq_len,
        max_samples=args.max_samples,
        max_tokens=args.max_tokens,
        split=args.dataset_split,
        dataset_name=args.dataset_name,
        dataset_config=args.dataset_config,
        text_column=args.text_column,
        category_filter=args.category_filter,
        include_reasoning=args.include_reasoning,
    )
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        drop_last=True,
        persistent_workers=args.num_workers > 0,
        prefetch_factor=2 if args.num_workers > 0 else None,
    )

    if use_mezo:
        optimizer = MeZOOptimizer(
            model,
            lr=args.lr * 0.01,
            eps=1e-3,
            weight_decay=0.1,
            momentum=0.9,
            direction=args.mezo_direction,
        )
    else:
        # Split parameters into decay / no-decay groups for AdamW.
        no_decay = {"A_log", "dt_bias", "norm", "bias", "embed", "energy_weights"}
        decay_params, no_decay_params = [], []
        for n, p in model.named_parameters():
            if not p.requires_grad:
                continue
            if any(tag in n for tag in no_decay):
                no_decay_params.append(p)
            else:
                decay_params.append(p)
        if adamw is None:
            optimizer = torch.optim.AdamW(
                [{"params": decay_params, "weight_decay": 0.1},
                 {"params": no_decay_params, "weight_decay": 0.0}],
                lr=args.lr,
                betas=(0.9, 0.95),
            )
        else:
            # The IPEX-fused optimiser created above is reused as-is.
            optimizer = adamw

    # Shared loss closure: MeZO and AdamW see identical batching, label
    # shifting and autocast behaviour.
    def compute_loss(batch) -> torch.Tensor:
        ids = batch["input_ids"][:, :-1]
        labels = batch["labels"][:, 1:]
        if use_bf16:
            with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                out = model(ids, labels=labels)
        else:
            out = model(ids, labels=labels)
        return out.loss

    train_standard_loop(args, model, config, loader, compute_loss, optimizer, use_mezo)


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def _build_argparser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(description="Chimera 5.2 CPU-first training")
    p.add_argument("--config", default=str(DEFAULT_CONFIG_PATH))
    p.add_argument("--scale", default="tiny", choices=["tiny", "small", "medium", "full"])
    p.add_argument("--seq_len", type=int, default=256)
    p.add_argument("--optimizer", default="mezo", choices=["mezo", "adamw"])
    p.add_argument("--batch_size", type=int, default=2)
    p.add_argument("--grad_accum", type=int, default=8)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--warmup", type=int, default=200)
    p.add_argument("--max_steps", type=int, default=5000)
    p.add_argument("--max_samples", type=int, default=None)
    p.add_argument("--max_tokens", type=int, default=None)
    p.add_argument("--bf16", action="store_true", default=True)
    p.add_argument("--no-bf16", dest="bf16", action="store_false")
    p.add_argument("--compile", action="store_true", default=False)
    p.add_argument("--grad_checkpoint", action="store_true", default=True)
    p.add_argument("--no-grad-checkpoint", dest="grad_checkpoint", action="store_false")
    p.add_argument("--mezo_direction", choices=["rademacher", "gaussian"], default="rademacher")
    p.add_argument("--dataset_name", default="roneneldan/TinyStories")
    p.add_argument("--dataset_config", default=None)
    p.add_argument("--dataset_split", default="train")
    p.add_argument("--text_column", default="auto")
    p.add_argument("--category_filter", default=None)
    p.add_argument("--include_reasoning", action="store_true", default=False)
    p.add_argument("--num_workers", type=int, default=2)
    p.add_argument("--log_every", type=int, default=10)
    p.add_argument("--save_every", type=int, default=1000)
    p.add_argument("--output_dir", default="./chimera_output")
    return p


if __name__ == "__main__":
    args = _build_argparser().parse_args()
    train(args)