"""Chimera 5.2 — Fast CPU training with pre-tokenized dataset cache."""
from __future__ import annotations

import argparse
import json
import math
import os
import sys
import time

ncpus = int(os.environ.get("OMP_NUM_THREADS", os.cpu_count() or 4))
os.environ["OMP_NUM_THREADS"] = str(ncpus)
os.environ["MKL_NUM_THREADS"] = str(ncpus)
|
|
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from chimera import Chimera51ForCausalLM

torch.set_num_threads(ncpus)
try:
    # PyTorch raises RuntimeError if interop threads are set after parallel
    # work has already started; ignore that case and keep the default.
    torch.set_num_interop_threads(1)
except RuntimeError:
    pass
|
|
class PreTokenizedDataset(Dataset):
    """Flat token-id tensor viewed as fixed-length next-token-prediction chunks."""

    def __init__(self, ids: torch.Tensor, seq_len: int):
        # Each training example needs seq_len inputs plus one shifted label.
        n = ids.numel() // (seq_len + 1)
        self.chunks = ids[:n * (seq_len + 1)].view(n, seq_len + 1)
        self.seq_len = seq_len

    def __len__(self) -> int:
        return self.chunks.size(0)

    def __getitem__(self, idx: int):
        c = self.chunks[idx]
        return {"input_ids": c[:-1], "labels": c[1:]}
|
|
|
|
def build_or_load_dataset(seq_len: int, max_samples: int, cache_dir: str = "./cache"):
    """Build (or load from cache) a flat TinyStories token tensor and wrap it as a dataset."""
    cache_path = os.path.join(cache_dir, f"tiny_stories_{seq_len}_{max_samples}.pt")
    os.makedirs(cache_dir, exist_ok=True)

    if os.path.exists(cache_path):
        print(f"[CACHE] Loading pre-tokenized dataset from {cache_path}")
        chunks = torch.load(cache_path, weights_only=False)
        return PreTokenizedDataset(chunks, seq_len)

    # Heavy imports are deferred so cached runs never touch datasets/tokenizer.
    from datasets import load_dataset
    from chimera import ChimeraTokenizer

    print("[DATA] Downloading TinyStories...")
    ds = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
    tok = ChimeraTokenizer(pretrained="o200k_base")

    # Fill a fixed-size buffer with EOS-terminated stories until it holds
    # enough tokens for max_samples chunks of (seq_len + 1) ids each.
    target = max_samples * (seq_len + 1)
    buffer = torch.empty(target, dtype=torch.long)
    buf_idx = 0
    processed = 0

    for ex in ds:
        text = ex.get("text", "")
        if not text:
            continue
        ids = tok.encode(text, add_special_tokens=False)
        ids.append(tok.eos_token_id)
        n = len(ids)
        if buf_idx + n > target:
            n = target - buf_idx
            if n <= 0:
                break
            ids = ids[:n]
        if n > 0:
            buffer[buf_idx:buf_idx + n] = torch.tensor(ids, dtype=torch.long)
            buf_idx += n
        processed += 1
        if (processed % 1000) == 0:
            print(f" {processed:,} stories, {buf_idx:,}/{target} tokens...")
        if buf_idx >= target:
            break

    # Trim to a whole number of chunks before caching.
    all_ids = buffer[:buf_idx]
    n = all_ids.numel() // (seq_len + 1)
    chunks = all_ids[:n * (seq_len + 1)]

    torch.save(chunks, cache_path)
    print(f"[CACHE] Saved {chunks.numel():,} tokens to {cache_path}")
    return PreTokenizedDataset(chunks, seq_len)
|
|
def cosine_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float:
    """Linear warmup to max_lr, then cosine decay to min_lr over the remaining steps."""
    if warmup > 0 and step < warmup:
        return max_lr * (step + 1) / warmup
    if step >= total:
        return min_lr
    p = (step - warmup) / max(1, total - warmup)
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * p))
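
# Example: warmup=100, total=1000, max_lr=1e-3, min_lr=1e-4 gives
# lr(0)=1e-5, lr(99)=1e-3, lr(550)=5.5e-4, lr(999)≈1.0e-4.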
|
|
|
|
_SCALE_PRESETS = {
    "tiny": dict(hidden_size=256, intermediate_size=512, num_heads=4, head_dim=48),
    "small": dict(hidden_size=512, intermediate_size=1024, num_heads=8, head_dim=48),
    "medium": dict(hidden_size=1024, intermediate_size=2048, num_heads=8, head_dim=96),
}
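
# Note: "--scale full" (accepted below) has no preset entry, so it leaves the
# dimensions from config.json untouched.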
|
|
|
|
def train(args) -> None:
    with open(args.config) as f:
        config = json.load(f)

    if args.scale in _SCALE_PRESETS:
        config.update(_SCALE_PRESETS[args.scale])

    # Sequence-length-dependent and per-module defaults layered on top of config.json.
    config["num_hidden_layers"] = int(config.get("num_hidden_layers", 28))
    config["vocab_size"] = config.get("vocab_size", 200073)
    config.setdefault("gated_deltanet", {})["chunk_size"] = min(args.seq_len, 64)
    config.setdefault("xlstm", {})["memory_size_per_head"] = [config["head_dim"], config["head_dim"]]
    config.setdefault("titans", {}).update({
        "memory_depth": 2, "persistent_memory_slots": 16,
        "local_window_size": min(args.seq_len, 256),
    })
    moe_cfg = config.setdefault("backbone", {}).setdefault("moe", {})
    moe_cfg.setdefault("layers", [3, 7, 11, 15, 19, 23, 27])
    moe_cfg.setdefault("moe_intermediate_size", config["intermediate_size"] // 4)
    moe_cfg.setdefault("n_routed_experts", 8)
    moe_cfg.setdefault("n_shared_experts", 1)
    moe_cfg.setdefault("num_experts_per_tok", 2)
    config.setdefault("looping", {}).update({
        "enabled": True, "prelude": [0, 3], "loop": [4, 23], "coda": [24, 27],
        "loop_range": [1, 3], "loop_default": 2,
    })
    config.setdefault("span_inference", {})["enabled"] = True
    config.setdefault("grammar", {})["enabled"] = True
    config.setdefault("entropy_valve", {})["enabled"] = True
    config.setdefault("debt_ledger", {})["enabled"] = True
    config.setdefault("multimodal", {})["enabled"] = False

    print("=" * 60)
    print(f"CHIMERA 5.2 FAST TRAIN — scale={args.scale}, seq_len={args.seq_len}, steps={args.max_steps}")
    print(f"Layers={config['num_hidden_layers']} hidden={config['hidden_size']} vocab={config['vocab_size']}")
    print(f"Threads: {torch.get_num_threads()} bf16={args.bf16} compile={args.compile}")
    print("=" * 60)

    model = Chimera51ForCausalLM(config)
    counts = model.count_parameters()
    print(f"Params: total={counts['total']:,} ternary={counts['ternary']:,}")

    if args.compile:
        print("[OPT] Compiling model...")
        model = torch.compile(model, backend="inductor", mode="default", dynamic=True)

    dataset = build_or_load_dataset(args.seq_len, args.max_samples, args.cache_dir)
    loader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=0, drop_last=True,
    )
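    # Design note: the dataset is pre-tokenized and memory-resident, so
    # num_workers=0 avoids worker-process overhead on a CPU-only host.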
|
|
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95))

    def compute_loss(batch) -> torch.Tensor:
        ids = batch["input_ids"]
        labels = batch["labels"]
        if args.bf16:
            # bfloat16 autocast on CPU; full fp32 when --bf16 is not set.
            with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                out = model(ids, labels=labels)
        else:
            out = model(ids, labels=labels)
        return out.loss
|
|
    os.makedirs(args.output_dir, exist_ok=True)
    log_path = os.path.join(args.output_dir, "log.jsonl")
    log_f = open(log_path, "w", encoding="utf-8")

    model.train()
    step = 0
    total_loss = 0.0
    best_loss = float("inf")
    toks = 0
    t0 = time.time()
    data_iter = iter(loader)
    warmup = min(args.warmup, max(1, args.max_steps // 10))

    print(f"\n{'=' * 60}\nTraining starts\n{'=' * 60}\n")
|
|
    while step < args.max_steps:
        try:
            batch = next(data_iter)
        except StopIteration:
            # Restart the (shuffled) loader when the small dataset is exhausted.
            data_iter = iter(loader)
            batch = next(data_iter)

        loss = compute_loss(batch)
        loss.backward()
        total_loss += float(loss.item())

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        cur_lr = cosine_lr(step, warmup, args.max_steps, args.lr, args.lr * 0.1)
        for pg in optimizer.param_groups:
            pg["lr"] = cur_lr
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        toks += batch["input_ids"].numel()
        step += 1

        if step % args.log_every == 0:
            dt = time.time() - t0
            avg = total_loss / args.log_every
            ppl = math.exp(min(avg, 20))
            tps = toks / dt if dt > 0 else 0
            # ETA from the throughput of this logging window (dt covers
            # log_every steps), not from the cumulative step count.
            eta_h = (args.max_steps - step) * dt / args.log_every / 3600 if dt > 0 else 0.0
            log_f.write(json.dumps({
                "step": step, "loss": round(avg, 4), "ppl": round(ppl, 2),
                "lr": cur_lr, "tok/s": round(tps),
            }) + "\n")
            log_f.flush()
            print(f" step {step:>6}/{args.max_steps} | loss {avg:.4f} | "
                  f"ppl {ppl:>8.2f} | lr {cur_lr:.2e} | "
                  f"{tps:.0f} tok/s | ETA {eta_h:.1f}h")
            best_loss = min(best_loss, avg)
            total_loss = 0.0
            toks = 0
            t0 = time.time()

        if step % args.save_every == 0:
            ckpt_dir = os.path.join(args.output_dir, f"ckpt-{step}")
            os.makedirs(ckpt_dir, exist_ok=True)
            # torch.compile wraps the module; save the underlying weights.
            raw = getattr(model, "_orig_mod", model)
            torch.save({
                "model": raw.state_dict(), "config": config,
                "step": step,
            }, os.path.join(ckpt_dir, "ckpt.pt"))
            print(f" [SAVE] {ckpt_dir}")
|
|
    final_dir = os.path.join(args.output_dir, "final")
    os.makedirs(final_dir, exist_ok=True)
    raw = getattr(model, "_orig_mod", model)
    torch.save({
        "model": raw.state_dict(), "config": config,
        "step": step, "best_loss": best_loss,
    }, os.path.join(final_dir, "model.pt"))
    with open(os.path.join(final_dir, "config.json"), "w", encoding="utf-8") as fh:
        json.dump(config, fh, indent=2)
    log_f.close()

    print(f"\n{'=' * 60}")
    print(f"DONE — best loss {best_loss:.4f}, ppl {math.exp(min(best_loss, 20)):.2f}")
    print(f"Saved to {final_dir}")
|
|
|
|
| if __name__ == "__main__": |
| p = argparse.ArgumentParser(description="Chimera 5.2 Fast CPU training") |
| p.add_argument("--config", default="config.json") |
| p.add_argument("--scale", default="tiny", choices=["tiny", "small", "medium", "full"]) |
| p.add_argument("--seq_len", type=int, default=32) |
| p.add_argument("--batch_size", type=int, default=4) |
| p.add_argument("--lr", type=float, default=1e-3) |
| p.add_argument("--warmup", type=int, default=100) |
| p.add_argument("--max_steps", type=int, default=1000) |
| p.add_argument("--max_samples", type=int, default=5000) |
| p.add_argument("--bf16", action="store_true", default=False) |
| p.add_argument("--compile", action="store_true", default=False) |
| p.add_argument("--cache_dir", default="./cache") |
| p.add_argument("--log_every", type=int, default=10) |
| p.add_argument("--save_every", type=int, default=500) |
| p.add_argument("--output_dir", default="./chimera_output") |
| args = p.parse_args() |
| train(args) |
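
# Example invocation (the script filename here is hypothetical):
#   python train_fast.py --scale tiny --seq_len 64 --batch_size 8 --max_steps 2000 --bf16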
|
|