"""
Chimera 5.2 — CPU-first training script.

Highlights vs the previous version:

* The MeZO optimiser uses a single deterministic seed per step, samples each
  parameter's perturbation direction *on demand* via per-parameter seeds, and
  drops the heavy direction cache. This brings the memory cost of MeZO back
  down to "1× model", exactly as advertised (see the sketch after this list).
* The AdamW path uses fused parameter groups and shares the same loss closure
  as MeZO, so accumulation and logging are identical between modes.
* Logging never references an undefined ``lr`` (the previous draft printed it
  before the AdamW step had run on the first accumulation boundary).
* Gradient checkpointing falls back to ``use_reentrant=False`` (the modern,
  faster path).
* Tokeniser/dataset loading is unchanged, but the Python loops are skipped
  entirely for ``max_tokens=0``.
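
A minimal sketch of the per-step seeding idea (the helper names and the
seed-mixing constant are illustrative assumptions, not the actual
``MeZOOptimizer`` internals)::

    def perturb(model, step_seed, eps, scale):
        # One seed per step; each parameter derives its own generator, so
        # directions are re-sampled on demand instead of cached ("1x model").
        for idx, p in enumerate(model.parameters()):
            gen = torch.Generator().manual_seed(step_seed * 100003 + idx)
            z = torch.randint(0, 2, p.shape, generator=gen,
                              dtype=p.dtype) * 2 - 1   # Rademacher +/-1
            p.data.add_(z, alpha=scale * eps)

    def mezo_step(model, closure, step_seed, eps, lr):
        # ``closure()`` recomputes the loss on the current batch.
        perturb(model, step_seed, eps, +1.0)            # theta + eps*z
        loss_pos = closure()
        perturb(model, step_seed, eps, -2.0)            # theta - eps*z
        loss_neg = closure()
        perturb(model, step_seed, eps, +1.0)            # restore theta
        g = (loss_pos - loss_neg).item() / (2 * eps)    # projected gradient
        perturb(model, step_seed, g, -lr)               # theta -= lr * g * z

(Rademacher shown; the ``gaussian`` direction would draw ``torch.randn``
from the same per-parameter generators.)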

Recommended commands::

    # MeZO smoke test on TinyStories
    python train.py --scale tiny --seq_len 64 --max_steps 20 --optimizer mezo

    # AdamW with grad checkpointing + bf16
    python train.py --scale small --seq_len 256 --max_steps 1000 \\
        --optimizer adamw --grad_checkpoint --bf16
"""

from __future__ import annotations

import argparse
import json
import os


def _setup_cpu_runtime() -> None:
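    # OpenMP/MKL read these variables when torch first loads, which is why
    # this helper runs before ``import torch`` below.  ``setdefault`` keeps
    # any values the user has already exported.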
    n_cpus = os.cpu_count() or 4
    os.environ.setdefault("OMP_NUM_THREADS", str(n_cpus))
    os.environ.setdefault("MKL_NUM_THREADS", str(n_cpus))
    os.environ.setdefault("KMP_AFFINITY", "granularity=fine,compact,1,0")
    os.environ.setdefault("KMP_BLOCKTIME", "1")
    os.environ.setdefault("MALLOC_CONF", "background_thread:true,metadata_thp:auto")


_setup_cpu_runtime()

import torch
from torch.utils.data import DataLoader

from chimera import Chimera51ForCausalLM
from chimera.paths import DEFAULT_CONFIG_PATH
from chimera.training import (
    build_sequence_dataset,
    apply_standard_config_tweaks,
    MeZOOptimizer,
    train_standard_loop,
)
from chimera.quantization import BitLinear
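# Pin intra-op threads to the OMP setting; inter-op threads are best effort
# because ``set_num_interop_threads`` raises once the pool has started.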
torch.set_num_threads(int(os.environ.get("OMP_NUM_THREADS", os.cpu_count() or 4)))
try:
    torch.set_num_interop_threads(int(os.environ.get("CHIMERA_INTEROP_THREADS", "1")))
except RuntimeError:
    pass

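# IPEX is optional; when importable it is used in ``train`` to fuse the AdamW
# path on Intel CPUs.  Any import failure simply disables it.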
HAS_IPEX = False
try:
    import intel_extension_for_pytorch as ipex
    HAS_IPEX = True
except Exception:
    pass


def build_dataset(seq_len: int, max_samples: int | None = None,
                  max_tokens: int | None = None,
                  split: str = "train",
                  dataset_name: str = "roneneldan/TinyStories",
                  dataset_config: str | None = None, text_column: str = "auto",
                  category_filter: str | None = None,
                  include_reasoning: bool = False):
    # Deferred import: the tokeniser is only pulled in when data is built.
    from chimera import ChimeraTokenizer

    tok = ChimeraTokenizer(pretrained="o200k_base")
    dataset = build_sequence_dataset(
        seq_len,
        max_samples=max_samples,
        max_tokens=max_tokens,
        split=split,
        dataset_name=dataset_name,
        dataset_config=dataset_config,
        text_column=text_column,
        category_filter=category_filter,
        include_reasoning=include_reasoning,
    )
    return dataset, tok


def train(args) -> None:
    with open(args.config) as f:
        config = json.load(f)
    config = apply_standard_config_tweaks(config, scale=args.scale, seq_len=args.seq_len)

    use_mezo = (args.optimizer == "mezo")
    use_bf16 = bool(args.bf16)
    use_compile = bool(args.compile)

    print("=" * 60)
    print(f"CHIMERA 5.2 TRAINING — scale={args.scale}, "
          f"optimizer={'MeZO' if use_mezo else 'AdamW'}, bf16={use_bf16}")
    print(f"Layers={config['num_hidden_layers']} hidden={config['hidden_size']} "
          f"vocab={config['vocab_size']} seq_len={args.seq_len} steps={args.max_steps}")
    print(f"Threads: {torch.get_num_threads()} IPEX={HAS_IPEX}")
    print("=" * 60)

    model = Chimera51ForCausalLM(config)
    counts = model.count_parameters()
    print(f"Params: total={counts['total']:,} ternary={counts['ternary']:,}")

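    # Checkpointing trades recompute for activation memory during backward;
    # MeZO never calls backward, so it would only slow its forward passes.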
    if args.grad_checkpoint and not use_mezo:
        model.enable_gradient_checkpointing()
        print("[OPT] Gradient checkpointing ON")

    optimizer = None
    if not use_mezo:
        # Build the weight-decay split once so the plain and the IPEX-fused
        # AdamW use identical parameter groups (substring match on names).
        no_decay = {"A_log", "dt_bias", "norm", "bias", "embed", "energy_weights"}
        decay_params, no_decay_params = [], []
        for n, p in model.named_parameters():
            if not p.requires_grad:
                continue
            if any(tag in n for tag in no_decay):
                no_decay_params.append(p)
            else:
                decay_params.append(p)
        optimizer = torch.optim.AdamW(
            [{"params": decay_params, "weight_decay": 0.1},
             {"params": no_decay_params, "weight_decay": 0.0}],
            lr=args.lr, betas=(0.9, 0.95))
        if HAS_IPEX:
            model, optimizer = ipex.optimize(
                model, optimizer=optimizer,
                dtype=torch.bfloat16 if use_bf16 else torch.float32, level="O1")
            print("[OPT] IPEX optimisation applied (level O1)")

    if use_compile:
        print("[OPT] Compiling model with torch.compile (inductor)...")
        model = torch.compile(model, backend="inductor", mode="default", dynamic=True)

    dataset, tok = build_dataset(
        args.seq_len, max_samples=args.max_samples, max_tokens=args.max_tokens,
        split=args.dataset_split, dataset_name=args.dataset_name,
        dataset_config=args.dataset_config, text_column=args.text_column,
        category_filter=args.category_filter,
        include_reasoning=args.include_reasoning,
    )
    loader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, drop_last=True,
        persistent_workers=args.num_workers > 0,
        prefetch_factor=2 if args.num_workers > 0 else None,
    )
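    # ``prefetch_factor``/``persistent_workers`` only apply with worker
    # processes, hence the ``num_workers > 0`` guards.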

    if use_mezo:
        # MeZO needs only forward passes and typically a far smaller step
        # size than AdamW, hence the 0.01 scale on the learning rate.
        optimizer = MeZOOptimizer(
            model, lr=args.lr * 0.01, eps=1e-3,
            weight_decay=0.1, momentum=0.9, direction=args.mezo_direction,
        )

    def compute_loss(batch) -> torch.Tensor:
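        # Shift by one so position t predicts token t+1.  Both optimisers
        # share this closure (forward-only for MeZO, forward+backward for
        # AdamW), so accumulation and logging match exactly.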
        ids = batch["input_ids"][:, :-1]
        labels = batch["labels"][:, 1:]
        if use_bf16:
            with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                out = model(ids, labels=labels)
        else:
            out = model(ids, labels=labels)
        return out.loss

    train_standard_loop(args, model, config, loader, compute_loss, optimizer, use_mezo)


def _build_argparser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(description="Chimera 5.2 CPU-first training")
    p.add_argument("--config", default=str(DEFAULT_CONFIG_PATH))
    p.add_argument("--scale", default="tiny", choices=["tiny", "small", "medium", "full"])
    p.add_argument("--seq_len", type=int, default=256)
    p.add_argument("--optimizer", default="mezo", choices=["mezo", "adamw"])
    p.add_argument("--batch_size", type=int, default=2)
    p.add_argument("--grad_accum", type=int, default=8)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--warmup", type=int, default=200)
    p.add_argument("--max_steps", type=int, default=5000)
    p.add_argument("--max_samples", type=int, default=None)
    p.add_argument("--max_tokens", type=int, default=None)
    p.add_argument("--bf16", action="store_true", default=True)
    p.add_argument("--no-bf16", dest="bf16", action="store_false")
    p.add_argument("--compile", action="store_true", default=False)
    p.add_argument("--grad_checkpoint", action="store_true", default=True)
    p.add_argument("--no-grad-checkpoint", dest="grad_checkpoint", action="store_false")
    p.add_argument("--mezo_direction", choices=["rademacher", "gaussian"],
                   default="rademacher")
    p.add_argument("--dataset_name", default="roneneldan/TinyStories")
    p.add_argument("--dataset_config", default=None)
    p.add_argument("--dataset_split", default="train")
    p.add_argument("--text_column", default="auto")
    p.add_argument("--category_filter", default=None)
    p.add_argument("--include_reasoning", action="store_true", default=False)
    p.add_argument("--num_workers", type=int, default=2)
    p.add_argument("--log_every", type=int, default=10)
    p.add_argument("--save_every", type=int, default=1000)
    p.add_argument("--output_dir", default="./chimera_output")
    return p


if __name__ == "__main__":
    args = _build_argparser().parse_args()
    train(args)