#!/usr/bin/env python3
"""
Chimera 5.2 — CPU-first training script.
Highlights vs the previous version:
* MeZO optimiser uses a single deterministic seed per step, samples each
  parameter's perturbation direction *on demand* via per-parameter seeds,
  and drops the heavy direction cache (see the sketch below). This brings
  the memory cost of MeZO back down to "1× model", exactly as advertised.
* AdamW path uses fused parameter groups and shares the same loss closure
  as MeZO, so accumulation and logging are identical between modes.
* Logging never references an undefined ``lr`` (the previous draft printed
  it before the AdamW step had run at the first accumulation boundary).
* Gradient checkpointing falls back to ``use_reentrant=False`` (the modern,
faster path).
* Tokeniser/dataset loading is unchanged, but the Python loops are skipped
  entirely when ``max_tokens=0``.
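
A minimal sketch of the on-demand direction sampling (hypothetical helper
name and seed mixing; the real logic lives in
``chimera.training.MeZOOptimizer``)::

    def _direction(param, step_seed, index, eps):
        # One base seed per step; each parameter mixes in its own index,
        # so z can be regenerated later without caching direction tensors.
        gen = torch.Generator().manual_seed(step_seed * 100_003 + index)
        z = torch.randint(0, 2, param.shape, generator=gen).to(param.dtype)
        return (z * 2 - 1) * eps  # Rademacher direction scaled to ±eps
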
Recommended commands::
# MeZO smoke test on TinyStories
python train.py --scale tiny --seq_len 64 --max_steps 20 --optimizer mezo
# AdamW with grad checkpointing + bf16
python train.py --scale small --seq_len 256 --max_steps 1000 \\
--optimizer adamw --grad_checkpoint --bf16
"""
from __future__ import annotations
import argparse
import json
import math
import os
import time
# CPU threading must be configured *before* importing torch.
def _setup_cpu_runtime() -> None:
    n_cpus = os.cpu_count() or 4
    os.environ.setdefault("OMP_NUM_THREADS", str(n_cpus))
    os.environ.setdefault("MKL_NUM_THREADS", str(n_cpus))
    # Pin OpenMP workers to cores; release idle threads after ~1 ms.
    os.environ.setdefault("KMP_AFFINITY", "granularity=fine,compact,1,0")
    os.environ.setdefault("KMP_BLOCKTIME", "1")
    # jemalloc tuning; harmlessly ignored when jemalloc is not the allocator.
    os.environ.setdefault("MALLOC_CONF", "background_thread:true,metadata_thp:auto")
_setup_cpu_runtime()
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from chimera import Chimera51ForCausalLM
from chimera.paths import DEFAULT_CONFIG_PATH
from chimera.training import (
build_sequence_dataset,
apply_standard_config_tweaks,
MeZOOptimizer,
train_standard_loop,
)
from chimera.quantization import BitLinear  # noqa: F401 (unused here; kept in case the import has registration side effects)
torch.set_num_threads(int(os.environ.get("OMP_NUM_THREADS", os.cpu_count() or 4)))
try:
torch.set_num_interop_threads(int(os.environ.get("CHIMERA_INTEROP_THREADS", "1")))
except RuntimeError:
pass
# Optional Intel Extension for PyTorch.
HAS_IPEX = False
try: # pragma: no cover - optional dependency.
import intel_extension_for_pytorch as ipex # noqa: F401
HAS_IPEX = True
except Exception:
pass
# ---------------------------------------------------------------------------
# Dataset & tokenisation helpers.
# ---------------------------------------------------------------------------
def build_dataset(seq_len: int, max_samples: int | None = None,
                  max_tokens: int | None = None, split: str = "train",
                  dataset_name: str = "roneneldan/TinyStories",
                  dataset_config: str | None = None, text_column: str = "auto",
                  category_filter: str | None = None,
                  include_reasoning: bool = False):
from chimera import ChimeraTokenizer
tok = ChimeraTokenizer(pretrained="o200k_base")
dataset = build_sequence_dataset(
seq_len,
max_samples=max_samples,
max_tokens=max_tokens,
split=split,
dataset_name=dataset_name,
dataset_config=dataset_config,
text_column=text_column,
category_filter=category_filter,
include_reasoning=include_reasoning,
)
return dataset, tok
# ---------------------------------------------------------------------------
# Main loop.
# ---------------------------------------------------------------------------
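# Shared AdamW parameter grouping, used by both the IPEX and plain paths so
# the weight-decay rules stay identical between them. The tag list mirrors
# the script's no-decay set.
def _adamw_param_groups(model: nn.Module) -> list[dict]:
    no_decay_tags = ("A_log", "dt_bias", "norm", "bias", "embed", "energy_weights")
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        (no_decay if any(tag in name for tag in no_decay_tags) else decay).append(param)
    return [{"params": decay, "weight_decay": 0.1},
            {"params": no_decay, "weight_decay": 0.0}]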
def train(args) -> None:
with open(args.config) as f:
config = json.load(f)
config = apply_standard_config_tweaks(config, scale=args.scale, seq_len=args.seq_len)
use_mezo = (args.optimizer == "mezo")
use_bf16 = bool(args.bf16)
use_compile = bool(args.compile)
print("=" * 60)
print(f"CHIMERA 5.2 TRAINING — scale={args.scale}, "
f"optimizer={'MeZO' if use_mezo else 'AdamW'}, bf16={use_bf16}")
print(f"Layers={config['num_hidden_layers']} hidden={config['hidden_size']} "
f"vocab={config['vocab_size']} seq_len={args.seq_len} steps={args.max_steps}")
print(f"Threads: {torch.get_num_threads()} IPEX={HAS_IPEX}")
print("=" * 60)
model = Chimera51ForCausalLM(config)
counts = model.count_parameters()
print(f"Params: total={counts['total']:,} ternary={counts['ternary']:,}")
if args.grad_checkpoint and not use_mezo:
model.enable_gradient_checkpointing()
print("[OPT] Gradient checkpointing ON")
    if HAS_IPEX and not use_mezo:
        # Build AdamW over the shared parameter groups so the IPEX path
        # applies the same weight-decay rules as the plain path below.
        adamw = torch.optim.AdamW(_adamw_param_groups(model), lr=args.lr,
                                  betas=(0.9, 0.95))
        model, adamw = ipex.optimize(
            model, optimizer=adamw,
            dtype=torch.bfloat16 if use_bf16 else torch.float32, level="O1")
        print("[OPT] IPEX optimisation applied (level O1)")
    else:
        adamw = None
if use_compile:
print("[OPT] Compiling model with torch.compile (inductor)...")
model = torch.compile(model, backend="inductor", mode="default", dynamic=True)
dataset, tok = build_dataset(
args.seq_len, max_samples=args.max_samples, max_tokens=args.max_tokens,
split=args.dataset_split, dataset_name=args.dataset_name,
dataset_config=args.dataset_config, text_column=args.text_column,
category_filter=args.category_filter,
include_reasoning=args.include_reasoning,
)
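    # drop_last keeps batch shapes uniform across steps; persistent workers
    # avoid re-spawning loader processes between epochs.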
loader = DataLoader(
dataset, batch_size=args.batch_size, shuffle=True,
num_workers=args.num_workers, drop_last=True,
persistent_workers=args.num_workers > 0,
prefetch_factor=2 if args.num_workers > 0 else None,
)
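    # Optimiser selection. MeZO keeps memory at ~1x model (no autograd
    # state); AdamW adds exp_avg/exp_avg_sq, roughly two extra
    # parameter-sized copies.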
    if use_mezo:
        # Zeroth-order estimates are noisy; run MeZO at 1/100th of the
        # nominal learning rate with a small perturbation eps.
        optimizer = MeZOOptimizer(
            model, lr=args.lr * 0.01, eps=1e-3,
            weight_decay=0.1, momentum=0.9, direction=args.mezo_direction,
        )
    else:
        # Reuse the IPEX-fused optimiser when available; otherwise build a
        # plain AdamW over the same shared parameter groups.
        if adamw is not None:
            optimizer = adamw
        else:
            optimizer = torch.optim.AdamW(
                _adamw_param_groups(model), lr=args.lr, betas=(0.9, 0.95))
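    # Shared loss closure: both optimiser modes call this, keeping gradient
    # accumulation and logging identical between MeZO and AdamW.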
    def compute_loss(batch) -> torch.Tensor:
        # Pre-shifted next-token setup; assumes the model applies no
        # internal label shift of its own.
        ids = batch["input_ids"][:, :-1]
        labels = batch["labels"][:, 1:]
if use_bf16:
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
out = model(ids, labels=labels)
else:
out = model(ids, labels=labels)
return out.loss
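    # MeZO evaluates the closure twice per step to form the two-point
    # estimate g ≈ (L(θ+εz) − L(θ−εz)) / (2ε) · z; AdamW calls it once and
    # then backpropagates. train_standard_loop handles both paths.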
train_standard_loop(args, model, config, loader, compute_loss, optimizer, use_mezo)
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def _build_argparser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Chimera 5.2 CPU-first training")
p.add_argument("--config", default=str(DEFAULT_CONFIG_PATH))
p.add_argument("--scale", default="tiny", choices=["tiny", "small", "medium", "full"])
p.add_argument("--seq_len", type=int, default=256)
p.add_argument("--optimizer", default="mezo", choices=["mezo", "adamw"])
p.add_argument("--batch_size", type=int, default=2)
p.add_argument("--grad_accum", type=int, default=8)
p.add_argument("--lr", type=float, default=1e-3)
p.add_argument("--warmup", type=int, default=200)
p.add_argument("--max_steps", type=int, default=5000)
p.add_argument("--max_samples", type=int, default=None)
p.add_argument("--max_tokens", type=int, default=None)
p.add_argument("--bf16", action="store_true", default=True)
p.add_argument("--no-bf16", dest="bf16", action="store_false")
p.add_argument("--compile", action="store_true", default=False)
p.add_argument("--grad_checkpoint", action="store_true", default=True)
p.add_argument("--no-grad-checkpoint", dest="grad_checkpoint", action="store_false")
p.add_argument("--mezo_direction", choices=["rademacher", "gaussian"],
default="rademacher")
p.add_argument("--dataset_name", default="roneneldan/TinyStories")
p.add_argument("--dataset_config", default=None)
p.add_argument("--dataset_split", default="train")
p.add_argument("--text_column", default="auto")
p.add_argument("--category_filter", default=None)
p.add_argument("--include_reasoning", action="store_true", default=False)
p.add_argument("--num_workers", type=int, default=2)
p.add_argument("--log_every", type=int, default=10)
p.add_argument("--save_every", type=int, default=1000)
p.add_argument("--output_dir", default="./chimera_output")
return p
if __name__ == "__main__":
args = _build_argparser().parse_args()
train(args)