#!/usr/bin/env python3 """Chimera 5.2 — Fast CPU training with pre-tokenized dataset cache.""" from __future__ import annotations import argparse import json import math import os # CPU threading must be configured *before* importing torch. ncpus = int(os.environ.get("OMP_NUM_THREADS", os.cpu_count() or 4)) os.environ["OMP_NUM_THREADS"] = str(ncpus) os.environ["MKL_NUM_THREADS"] = str(ncpus) import torch from torch.utils.data import DataLoader from chimera import Chimera51ForCausalLM from chimera.paths import DEFAULT_CONFIG_PATH from chimera.training import ( PreTokenizedDataset, apply_standard_config_tweaks, train_fast_loop, ) torch.set_num_threads(ncpus) try: torch.set_num_interop_threads(1) except RuntimeError: pass def build_or_load_dataset(seq_len: int, max_samples: int, cache_dir: str = "./cache"): cache_path = os.path.join(cache_dir, f"tiny_stories_{seq_len}_{max_samples}.pt") os.makedirs(cache_dir, exist_ok=True) if os.path.exists(cache_path): print(f"[CACHE] Loading pre-tokenized dataset from {cache_path}") chunks = torch.load(cache_path, weights_only=False) return PreTokenizedDataset(chunks, seq_len) from datasets import load_dataset from chimera import ChimeraTokenizer print(f"[DATA] Downloading TinyStories...") ds = load_dataset("roneneldan/TinyStories", split="train", streaming=True) tok = ChimeraTokenizer(pretrained="o200k_base") target = max_samples * (seq_len + 1) buffer = torch.empty(target, dtype=torch.long) buf_idx = 0 processed = 0 for ex in ds: text = ex.get("text", "") if not text: continue ids = tok.encode(text, add_special_tokens=False) ids.append(tok.eos_token_id) n = len(ids) if buf_idx + n > target: n = target - buf_idx if n <= 0: break ids = ids[:n] if n > 0: buffer[buf_idx:buf_idx + n] = torch.tensor(ids, dtype=torch.long) buf_idx += n processed += 1 if (processed % 1000) == 0: print(f" {processed:,} stories, {buf_idx:,}/{target} tokens...") if buf_idx >= target: break all_ids = buffer[:buf_idx] n = all_ids.numel() // (seq_len + 1) chunks = all_ids[:n * (seq_len + 1)] torch.save(chunks, cache_path) print(f"[CACHE] Saved {chunks.numel():,} tokens to {cache_path}") return PreTokenizedDataset(chunks, seq_len) def train(args) -> None: with open(args.config) as f: config = json.load(f) config = apply_standard_config_tweaks(config, scale=args.scale, seq_len=args.seq_len) print("=" * 60) print(f"CHIMERA 5.2 FAST TRAIN — scale={args.scale}, seq_len={args.seq_len}, steps={args.max_steps}") print(f"Layers={config['num_hidden_layers']} hidden={config['hidden_size']} vocab={config['vocab_size']}") print(f"Threads: {torch.get_num_threads()} bf16={args.bf16} compile={args.compile}") print("=" * 60) model = Chimera51ForCausalLM(config) counts = model.count_parameters() print(f"Params: total={counts['total']:,} ternary={counts['ternary']:,}") if args.compile: print("[OPT] Compiling model...") model = torch.compile(model, backend="inductor", mode="default", dynamic=True) dataset = build_or_load_dataset(args.seq_len, args.max_samples, args.cache_dir) loader = DataLoader( dataset, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True, ) def compute_loss(batch) -> torch.Tensor: ids = batch["input_ids"] labels = batch["labels"] if args.bf16: with torch.autocast(device_type="cpu", dtype=torch.bfloat16): out = model(ids, labels=labels) else: out = model(ids, labels=labels) return out.loss train_fast_loop(args, model, config, loader, compute_loss) if __name__ == "__main__": p = argparse.ArgumentParser(description="Chimera 5.2 Fast CPU training") p.add_argument("--config", default=str(DEFAULT_CONFIG_PATH)) p.add_argument("--scale", default="tiny", choices=["tiny", "small", "medium", "full"]) p.add_argument("--seq_len", type=int, default=32) p.add_argument("--batch_size", type=int, default=4) p.add_argument("--lr", type=float, default=1e-3) p.add_argument("--warmup", type=int, default=100) p.add_argument("--max_steps", type=int, default=1000) p.add_argument("--max_samples", type=int, default=5000) p.add_argument("--bf16", action="store_true", default=False) p.add_argument("--compile", action="store_true", default=False) p.add_argument("--cache_dir", default="./cache") p.add_argument("--log_every", type=int, default=10) p.add_argument("--save_every", type=int, default=500) p.add_argument("--output_dir", default="./chimera_output") args = p.parse_args() train(args)