#!/usr/bin/env python3
"""
Chimera 5.2 — CPU-first training script.
Highlights vs the previous version:
* MeZO optimiser uses a single deterministic seed per step, samples each
  parameter's perturbation direction *on demand* via per-parameter seeds, and
  drops the heavy direction cache. This brings the memory cost of MeZO back
  down to "1× model" exactly as advertised (an illustrative sketch of the
  seeding trick follows the imports below).
* AdamW path uses fused parameter groups and shares the same loss closure as
MeZO so accumulation and logging are identical between modes.
* Logging never references an undefined ``lr`` (the previous draft printed it
  before the AdamW step ran on the first gradient-accumulation boundary).
* Gradient checkpointing falls back to ``use_reentrant=False`` (the modern,
faster path).
* Tokeniser/dataset loading is unchanged, but the Python loops are skipped
  entirely for ``max_tokens=0``.
Recommended commands::
# MeZO smoke test on TinyStories
python train.py --scale tiny --seq_len 64 --max_steps 20 --optimizer mezo
# AdamW with grad checkpointing + bf16
python train.py --scale small --seq_len 256 --max_steps 1000 \\
--optimizer adamw --grad_checkpoint --bf16
"""
from __future__ import annotations
import argparse
import json
import math
import os
import time
# CPU threading must be configured *before* importing torch.
def _setup_cpu_runtime() -> None:
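    # setdefault() keeps any values the user has already exported.  OMP/MKL
    # control the intra-op thread counts, KMP_AFFINITY/KMP_BLOCKTIME tune the
    # Intel OpenMP runtime, and MALLOC_CONF is honoured when jemalloc is the
    # active allocator.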
n_cpus = os.cpu_count() or 4
os.environ.setdefault("OMP_NUM_THREADS", str(n_cpus))
os.environ.setdefault("MKL_NUM_THREADS", str(n_cpus))
os.environ.setdefault("KMP_AFFINITY", "granularity=fine,compact,1,0")
os.environ.setdefault("KMP_BLOCKTIME", "1")
os.environ.setdefault("MALLOC_CONF", "background_thread:true,metadata_thp:auto")
_setup_cpu_runtime()
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from chimera import Chimera51ForCausalLM
from chimera.paths import DEFAULT_CONFIG_PATH
from chimera.training import (
build_sequence_dataset,
apply_standard_config_tweaks,
MeZOOptimizer,
train_standard_loop,
)
from chimera.quantization import BitLinear  # noqa: F401 - not referenced directly in this script
torch.set_num_threads(int(os.environ.get("OMP_NUM_THREADS", os.cpu_count() or 4)))
try:
torch.set_num_interop_threads(int(os.environ.get("CHIMERA_INTEROP_THREADS", "1")))
except RuntimeError:
pass
# Optional Intel Extension for PyTorch.
HAS_IPEX = False
try: # pragma: no cover - optional dependency.
import intel_extension_for_pytorch as ipex # noqa: F401
HAS_IPEX = True
except Exception:
pass
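
# ---------------------------------------------------------------------------
# Illustrative sketch (not called anywhere): how the seed trick described in
# the module docstring keeps MeZO's memory at roughly 1× the model.  Each
# parameter's Rademacher direction is re-derived from (step seed, parameter
# index) whenever it is needed, so no direction tensor has to be cached.  The
# real implementation lives in ``chimera.training.MeZOOptimizer`` and may
# differ in detail; everything below is a simplified, assumption-laden example.
# ---------------------------------------------------------------------------
def _mezo_step_sketch(model: nn.Module, loss_fn, lr: float = 1e-6, eps: float = 1e-3,
                      step_seed: int = 0) -> float:
    """One zeroth-order (SPSA-style) step with seed-regenerated directions."""
    params = [p for p in model.parameters() if p.requires_grad]

    def direction(idx: int, p: torch.Tensor) -> torch.Tensor:
        # The same ±1 tensor is reproduced on every call from a per-parameter
        # seed, which is what makes caching it unnecessary.
        g = torch.Generator().manual_seed(step_seed * 100_003 + idx)
        return (torch.randint(0, 2, tuple(p.shape), generator=g) * 2 - 1).to(p.dtype)

    def perturb(scale: float) -> None:
        with torch.no_grad():
            for i, p in enumerate(params):
                p.add_(direction(i, p), alpha=scale)

    perturb(+eps)                     # theta + eps * z
    loss_plus = float(loss_fn())
    perturb(-2.0 * eps)               # theta - eps * z
    loss_minus = float(loss_fn())
    perturb(+eps)                     # restore theta
    projected_grad = (loss_plus - loss_minus) / (2.0 * eps)
    with torch.no_grad():
        for i, p in enumerate(params):
            p.add_(direction(i, p), alpha=-lr * projected_grad)
    return 0.5 * (loss_plus + loss_minus)
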
# ---------------------------------------------------------------------------
# Dataset & tokenisation helpers.
# ---------------------------------------------------------------------------
def build_dataset(seq_len: int, max_samples: int | None = None, max_tokens: int | None = None,
                  split: str = "train",
                  dataset_name: str = "roneneldan/TinyStories",
                  dataset_config: str | None = None, text_column: str = "auto",
                  category_filter: str | None = None,
                  include_reasoning: bool = False):
from chimera import ChimeraTokenizer
tok = ChimeraTokenizer(pretrained="o200k_base")
dataset = build_sequence_dataset(
seq_len,
max_samples=max_samples,
max_tokens=max_tokens,
split=split,
dataset_name=dataset_name,
dataset_config=dataset_config,
text_column=text_column,
category_filter=category_filter,
include_reasoning=include_reasoning,
)
return dataset, tok
# ---------------------------------------------------------------------------
# Main loop.
# ---------------------------------------------------------------------------
def train(args) -> None:
with open(args.config) as f:
config = json.load(f)
config = apply_standard_config_tweaks(config, scale=args.scale, seq_len=args.seq_len)
use_mezo = (args.optimizer == "mezo")
use_bf16 = bool(args.bf16)
use_compile = bool(args.compile)
print("=" * 60)
print(f"CHIMERA 5.2 TRAINING — scale={args.scale}, "
f"optimizer={'MeZO' if use_mezo else 'AdamW'}, bf16={use_bf16}")
print(f"Layers={config['num_hidden_layers']} hidden={config['hidden_size']} "
f"vocab={config['vocab_size']} seq_len={args.seq_len} steps={args.max_steps}")
print(f"Threads: {torch.get_num_threads()} IPEX={HAS_IPEX}")
print("=" * 60)
model = Chimera51ForCausalLM(config)
counts = model.count_parameters()
print(f"Params: total={counts['total']:,} ternary={counts['ternary']:,}")
if args.grad_checkpoint and not use_mezo:
model.enable_gradient_checkpointing()
print("[OPT] Gradient checkpointing ON")
if HAS_IPEX and not use_mezo:
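        # ipex.optimize() wants the optimizer up front so it can apply its
        # training-time fusions, so a plain single-group AdamW is built here;
        # the grouped weight-decay setup further down is only used on the
        # non-IPEX path.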
adamw = torch.optim.AdamW(model.parameters(), lr=args.lr)
model, adamw = ipex.optimize(
model, optimizer=adamw,
dtype=torch.bfloat16 if use_bf16 else torch.float32, level="O1")
print("[OPT] IPEX optimisation applied (level O1)")
else:
adamw = None
if use_compile:
print("[OPT] Compiling model with torch.compile (inductor)...")
model = torch.compile(model, backend="inductor", mode="default", dynamic=True)
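        # dynamic=True lets inductor trace with symbolic shapes, avoiding a
        # recompile if the effective batch/sequence shape ever changes.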
dataset, tok = build_dataset(
args.seq_len, max_samples=args.max_samples, max_tokens=args.max_tokens,
split=args.dataset_split, dataset_name=args.dataset_name,
dataset_config=args.dataset_config, text_column=args.text_column,
category_filter=args.category_filter,
include_reasoning=args.include_reasoning,
)
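    # persistent_workers/prefetch_factor only apply when worker processes are
    # used; with --num_workers 0 the loader iterates in the main process.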
loader = DataLoader(
dataset, batch_size=args.batch_size, shuffle=True,
num_workers=args.num_workers, drop_last=True,
persistent_workers=args.num_workers > 0,
prefetch_factor=2 if args.num_workers > 0 else None,
)
if use_mezo:
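        # Zeroth-order gradient estimates are noisy, so the MeZO path runs at a
        # much smaller learning rate than first-order AdamW (1/100 of --lr here)
        # and probes with a small eps.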
optimizer = MeZOOptimizer(
model, lr=args.lr * 0.01, eps=1e-3,
weight_decay=0.1, momentum=0.9, direction=args.mezo_direction,
)
else:
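        # Substring-matched parameter groups: names containing any of these tags
        # (biases, norms, embeddings, A_log/dt_bias, energy_weights) skip weight
        # decay; every other parameter decays at 0.1.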
no_decay = {"A_log", "dt_bias", "norm", "bias", "embed", "energy_weights"}
decay_params, no_decay_params = [], []
for n, p in model.named_parameters():
if not p.requires_grad:
continue
if any(tag in n for tag in no_decay):
no_decay_params.append(p)
else:
decay_params.append(p)
if adamw is None:
optimizer = torch.optim.AdamW(
[{"params": decay_params, "weight_decay": 0.1},
{"params": no_decay_params, "weight_decay": 0.0}],
lr=args.lr, betas=(0.9, 0.95))
else:
optimizer = adamw
def compute_loss(batch) -> torch.Tensor:
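        # Next-token objective: inputs drop the last token and labels drop the
        # first, so position i of ``ids`` is trained to predict token i+1.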
ids = batch["input_ids"][:, :-1]
labels = batch["labels"][:, 1:]
if use_bf16:
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
out = model(ids, labels=labels)
else:
out = model(ids, labels=labels)
return out.loss
train_standard_loop(args, model, config, loader, compute_loss, optimizer, use_mezo)
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def _build_argparser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Chimera 5.2 CPU-first training")
p.add_argument("--config", default=str(DEFAULT_CONFIG_PATH))
p.add_argument("--scale", default="tiny", choices=["tiny", "small", "medium", "full"])
p.add_argument("--seq_len", type=int, default=256)
p.add_argument("--optimizer", default="mezo", choices=["mezo", "adamw"])
p.add_argument("--batch_size", type=int, default=2)
p.add_argument("--grad_accum", type=int, default=8)
p.add_argument("--lr", type=float, default=1e-3)
p.add_argument("--warmup", type=int, default=200)
p.add_argument("--max_steps", type=int, default=5000)
p.add_argument("--max_samples", type=int, default=None)
p.add_argument("--max_tokens", type=int, default=None)
p.add_argument("--bf16", action="store_true", default=True)
p.add_argument("--no-bf16", dest="bf16", action="store_false")
p.add_argument("--compile", action="store_true", default=False)
p.add_argument("--grad_checkpoint", action="store_true", default=True)
p.add_argument("--no-grad-checkpoint", dest="grad_checkpoint", action="store_false")
p.add_argument("--mezo_direction", choices=["rademacher", "gaussian"],
default="rademacher")
p.add_argument("--dataset_name", default="roneneldan/TinyStories")
p.add_argument("--dataset_config", default=None)
p.add_argument("--dataset_split", default="train")
p.add_argument("--text_column", default="auto")
p.add_argument("--category_filter", default=None)
p.add_argument("--include_reasoning", action="store_true", default=False)
p.add_argument("--num_workers", type=int, default=2)
p.add_argument("--log_every", type=int, default=10)
p.add_argument("--save_every", type=int, default=1000)
p.add_argument("--output_dir", default="./chimera_output")
return p
if __name__ == "__main__":
args = _build_argparser().parse_args()
train(args)