| """ |
| Chimera 5.1 — Training Script (CPU-Optimized) |
| ================================================== |
| Optimizations implemented: |
| 1. MeZO (Memory-Efficient Zeroth-Order) optimizer — eliminates backward pass entirely |
| - 2× forward only, no activation storage, no gradient computation |
| - arxiv:2305.17333 |
| 2. BFloat16 autocast on CPU — 2-4× faster matmuls on AVX-512/AMX hardware |
| 3. torch.compile with Inductor backend — fused ops, reduced Python overhead |
| 4. Gradient checkpointing (for AdamW mode) — trades compute for memory |
| 5. Optimal CPU threading — KMP_AFFINITY, OMP tuning, NUMA-aware |
| 6. Persistent DataLoader workers — no worker restart overhead |
| 7. Intel IPEX integration (optional) — auto-detected |
| 8. Cosine LR with warmup |
| 9. Standard AdamW with backprop as fallback mode |
| 10. Generic dataset loading — supports any HF dataset, messages/text columns, category filtering |
| |
| Usage: |
| # MeZO mode (recommended for CPU — no backward pass): |
| python train.py --optimizer mezo --scale tiny --seq_len 64 --max_steps 100 |
| |
| # AdamW mode (standard backprop with gradient checkpointing + bf16): |
| python train.py --optimizer adamw --scale tiny --seq_len 64 --max_steps 100 |
| |
| # Full run with custom dataset and category filter: |
| python train.py --optimizer mezo --scale tiny --seq_len 64 --max_steps 10000 \ |
| --dataset_name Roman1111111/claude-sonnet-4.6-120000x \ |
| --dataset_split train --text_column messages \ |
| --category_filter "C++,organic chemistry" |
| """ |
|
|
| import os |
| import sys |
| import json |
| import time |
| import math |
| import argparse |
|
|
| |
| def _setup_cpu_threading(): |
| """Configure optimal CPU threading for training.""" |
| n_cpus = os.cpu_count() or 4 |
| |
| os.environ.setdefault('OMP_NUM_THREADS', str(n_cpus)) |
| os.environ.setdefault('MKL_NUM_THREADS', str(n_cpus)) |
| |
| os.environ.setdefault('KMP_AFFINITY', 'granularity=fine,compact,1,0') |
| |
| os.environ.setdefault('KMP_BLOCKTIME', '1') |
| |
| os.environ.setdefault('MALLOC_CONF', 'background_thread:true,metadata_thp:auto') |
|
|
| _setup_cpu_threading() |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader, Dataset |
|
|
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| from chimera import Chimera51ForCausalLM |
| from chimera.quantization import BitLinear |
|
|
| |
| torch.set_num_threads(int(os.environ.get('OMP_NUM_THREADS', os.cpu_count() or 4))) |
| try: |
| torch.set_num_interop_threads(int(os.environ.get('CHIMERA_INTEROP_THREADS', '1'))) |
| except RuntimeError: |
| pass |
|
|
| |
# Optional dependency: Intel Extension for PyTorch.  HAS_IPEX tells the rest
# of the script whether `ipex` can be used; `ipex` is None when it cannot.
try:
    import intel_extension_for_pytorch as ipex
except (ImportError, OSError):
    # ImportError: package not installed.  OSError: the package is present
    # but one of its shared libraries failed to load.  Either way the
    # extension is optional, so training continues without it.
    ipex = None
    HAS_IPEX = False
else:
    HAS_IPEX = True
    print("[IPEX] Intel Extension for PyTorch detected — will use optimized kernels")
|
|
|
|
| |
| |
| |
class MeZOOptimizer:
    """Ternary-aware memory-efficient zeroth-order (MeZO) optimizer.

    A step runs two forward passes, at θ + εz and θ − εz, turns the loss
    difference into a projected gradient and moves the parameters along the
    random direction z.  No backward pass is performed.

    ``BitLinear`` layers are handled separately: the direction for
    ``module.weight`` is multiplied by ``module.ternary_nonzero_mask()``, so
    masked-out positions receive no perturbation, and
    ``module.invalidate_packed()`` is called after each in-place change.
    NOTE(review): only ``weight`` of a BitLinear layer is optimized; other
    parameters registered directly on such a layer are skipped, and the
    weight is optimized regardless of ``requires_grad`` — confirm intended.

    Extra memory beyond the model: one direction tensor per parameter while a
    step is in flight (when ``cache_directions`` is on) and one momentum
    buffer per parameter (when ``momentum > 0``).

    Args:
        model: module whose parameters are optimized in place.
        lr: step size; callers may reassign ``self.lr`` between steps.
        eps: finite-difference perturbation scale ε.
        weight_decay: multiplicative decay ``1 - lr * weight_decay`` applied
            to every optimized tensor after its update.
        momentum: momentum coefficient; ``0`` disables the buffers.
        direction: ``"gaussian"`` for normal directions, anything else for
            Rademacher (±1) directions.
        cache_directions: keep each direction tensor for the duration of a
            step instead of regenerating it from the seed.
    """

    def __init__(self, model, lr=1e-4, eps=1e-3, weight_decay=0.0,
                 momentum=0.0, direction="rademacher", cache_directions=True):
        self.model = model
        self.lr = lr
        self.eps = eps
        self.wd = weight_decay
        self.momentum = momentum
        self.direction = direction
        self.cache_directions = cache_directions

        # (module name, BitLinear module) pairs and (name, parameter) pairs
        # for everything else that is trainable.
        self._bitlinear_params = []
        self._other_params = []
        found_params = set()  # id() of parameters that are already assigned

        def add_other(name, param):
            if param.requires_grad and id(param) not in found_params:
                self._other_params.append((name, param))
                found_params.add(id(param))

        for name, module in model.named_modules():
            if isinstance(module, BitLinear):
                self._bitlinear_params.append((name, module))
                for p in module.parameters(recurse=False):
                    found_params.add(id(p))
            elif isinstance(module, (nn.Linear, nn.Embedding)):
                for pn, p in module.named_parameters(recurse=False):
                    add_other(f"{name}.{pn}", p)

        # Everything not reached through the module walk above.
        for name, p in model.named_parameters():
            add_other(name, p)

        self._mezo_masks = {}
        self._direction_cache = {}

        # Momentum buffers are keyed by exactly the names used in
        # _update_params, so every optimized tensor is guaranteed to find its
        # buffer.  The dict always exists, even when momentum is disabled.
        self._momentum_buffer = {}
        if momentum > 0:
            for name, module in self._bitlinear_params:
                self._momentum_buffer[f"{name}.weight"] = torch.zeros_like(module.weight.data)
            for name, p in self._other_params:
                self._momentum_buffer[name] = torch.zeros_like(p.data)

    def _sample_direction(self, p: torch.Tensor, seed: int) -> torch.Tensor:
        """Draw a direction shaped like *p*, reproducible from *seed*."""
        gen = torch.Generator(device=p.device)
        gen.manual_seed(int(seed) & 0x7FFFFFFFFFFFFFFF)
        if self.direction == "gaussian":
            return torch.randn(p.shape, dtype=p.dtype, device=p.device, generator=gen)

        # Rademacher: Bernoulli(0.5) in {0, 1} mapped to {-1, +1}.
        z = torch.empty(p.shape, dtype=p.dtype, device=p.device)
        z.bernoulli_(0.5, generator=gen).mul_(2).sub_(1)
        return z

    def _direction_for(self, name: str, p: torch.Tensor, seed: int, mask=None) -> torch.Tensor:
        """Return the (optionally masked) direction for tensor *name*.

        With caching enabled the tensor stored under *name* is reused;
        otherwise it is regenerated from *seed*, which yields the same values.
        """
        if self.cache_directions and name in self._direction_cache:
            return self._direction_cache[name]
        z = self._sample_direction(p, seed)
        if mask is not None:
            z.mul_(mask.to(device=p.device, dtype=z.dtype))
        if self.cache_directions:
            self._direction_cache[name] = z
        return z

    def _perturb_params(self, seed: int, scale: float):
        """Add ``scale * z`` in place to every optimized tensor."""
        sub_seed = seed
        for name, module in self._bitlinear_params:
            mask = self._mezo_masks.get(name)
            if mask is None:
                mask = module.ternary_nonzero_mask()
            z = self._direction_for(f"{name}.weight", module.weight.data, sub_seed, mask=mask)
            module.weight.data.add_(z, alpha=scale)
            module.invalidate_packed()
            sub_seed += 1000003

        for i, (name, p) in enumerate(self._other_params):
            z = self._direction_for(name, p.data, seed + 500000007 + i * 1000003)
            p.data.add_(z, alpha=scale)

    def _apply_update(self, key: str, tensor: torch.Tensor, z: torch.Tensor,
                      projected_grad: float):
        """Apply the (momentum) step and weight decay to one tensor in place."""
        buf = self._momentum_buffer.get(key) if self.momentum > 0 else None
        if buf is not None:
            buf.mul_(self.momentum).add_(z, alpha=projected_grad)
            tensor.add_(buf, alpha=-self.lr)
        else:
            tensor.add_(z, alpha=-self.lr * projected_grad)
        if self.wd > 0:
            tensor.mul_(1 - self.lr * self.wd)

    def _update_params(self, seed: int, projected_grad: float):
        """Step every tensor along the same directions used for perturbation."""
        sub_seed = seed
        for name, module in self._bitlinear_params:
            key = f"{name}.weight"
            z = self._direction_for(key, module.weight.data, sub_seed,
                                    mask=self._mezo_masks.get(name))
            self._apply_update(key, module.weight.data, z, projected_grad)
            module.invalidate_packed()
            sub_seed += 1000003

        for i, (name, p) in enumerate(self._other_params):
            z = self._direction_for(name, p.data, seed + 500000007 + i * 1000003)
            self._apply_update(name, p.data, z, projected_grad)

    @torch.no_grad()
    def step(self, loss_fn, batch) -> float:
        """Run one MeZO step: two forward passes, no backward pass.

        Args:
            loss_fn: callable taking *batch* and returning a scalar tensor.
            batch: passed through to *loss_fn* unchanged.

        Returns:
            The mean of the two perturbed losses.  When the projected
            gradient is not finite the parameter update is skipped and the
            (non-finite) mean is still returned so the caller can see it.

        If *loss_fn* raises, the parameters are moved back to their
        unperturbed values before the exception propagates.
        """
        seed = torch.randint(0, 2**31, (1,)).item()

        # Masks are taken from the unperturbed weights and reused all step.
        self._mezo_masks = {name: module.ternary_nonzero_mask().detach()
                            for name, module in self._bitlinear_params}
        self._direction_cache = {}

        displacement = 0.0  # current offset of the weights, in multiples of z
        try:
            self._perturb_params(seed, self.eps)
            displacement = self.eps
            loss_pos = loss_fn(batch).item()

            self._perturb_params(seed, -2 * self.eps)
            displacement = -self.eps
            loss_neg = loss_fn(batch).item()

            self._perturb_params(seed, self.eps)
            displacement = 0.0

            projected_grad = (loss_pos - loss_neg) / (2 * self.eps)
            # A NaN/Inf estimate would be written into every weight; skip it.
            if math.isfinite(projected_grad):
                self._update_params(seed, projected_grad)
        finally:
            if displacement:
                self._perturb_params(seed, -displacement)
            for _, module in self._bitlinear_params:
                module.invalidate_packed()
            self._mezo_masks = {}
            self._direction_cache = {}

        return (loss_pos + loss_neg) / 2
|
|
|
|
| |
| |
| |
class TokenDataset(Dataset):
    """Map-style dataset over pre-chunked token rows.

    Item ``idx`` is a dict whose ``"input_ids"`` and ``"labels"`` entries
    both come from ``chunks[idx]``; any shifting between inputs and labels
    is left to the consumer.
    """

    def __init__(self, chunks: torch.Tensor):
        self.chunks = chunks

    def __len__(self) -> int:
        return len(self.chunks)

    def __getitem__(self, idx: int) -> dict:
        return {field: self.chunks[idx] for field in ("input_ids", "labels")}
|
|
|
|
| def _matches_category_filter(ex: dict, filters: list) -> bool: |
| """Check if example matches any of the requested category substrings.""" |
| cat = ex.get("category", "") |
| if not cat: |
| return False |
| cat_lower = cat.lower() |
| return any(f.lower() in cat_lower for f in filters) |
|
|
|
|
| def _format_example(ex: dict, tok, text_column: str = "auto", include_reasoning: bool = False) -> str: |
| """Convert an example dict to a single text string for tokenization.""" |
| |
| if text_column == "auto": |
| if "messages" in ex: |
| text_column = "messages" |
| elif "text" in ex: |
| text_column = "text" |
| elif "content" in ex: |
| text_column = "content" |
| elif "conversation" in ex: |
| text_column = "conversation" |
| else: |
| text_column = None |
|
|
| if text_column == "messages" and "messages" in ex: |
| msgs = ex["messages"] |
| |
| if include_reasoning and isinstance(msgs, list): |
| msgs = [] |
| for m in ex["messages"]: |
| if isinstance(m, dict) and m.get("role") == "assistant" and "reasoning" in m: |
| content = f"<|thinking|>\n{m['reasoning']}\n<|/thinking|>\n{m.get('content', '')}" |
| msgs.append({"role": "assistant", "content": content}) |
| else: |
| msgs.append(m) |
| return tok.apply_chat_template(msgs) |
|
|
| if text_column and text_column in ex: |
| val = ex[text_column] |
| if isinstance(val, str): |
| return val |
| |
| if isinstance(val, list) and len(val) > 0 and isinstance(val[0], dict): |
| return tok.apply_chat_template(val) |
| return str(val) |
|
|
| |
| return str(ex) |
|
|
|
|
def _iter_example_tokens(ds, tok, cat_filters, text_column, include_reasoning, stats):
    """Yield an EOS-terminated token-id list for every usable example in *ds*.

    Examples rejected by the category filter, or whose formatted text is
    blank, are counted in ``stats["skipped"]``; every yielded example is
    counted in ``stats["processed"]``.
    """
    for ex in ds:
        if cat_filters and not _matches_category_filter(ex, cat_filters):
            stats["skipped"] += 1
            continue

        text = _format_example(ex, tok, text_column, include_reasoning)
        if not text or not text.strip():
            stats["skipped"] += 1
            continue

        ids = list(tok.encode(text, add_special_tokens=False))
        ids.append(tok.eos_token_id)
        stats["processed"] += 1
        yield ids


def build_dataset(seq_len: int, max_samples=None, max_tokens=None, split: str = "train",
                  dataset_name: str = "roneneldan/TinyStories",
                  dataset_config: str = None,
                  text_column: str = "auto",
                  category_filter: str = None,
                  include_reasoning: bool = False):
    """Stream a HuggingFace dataset, tokenize it and cut it into chunks.

    The dataset is opened with ``streaming=True``; each kept example is
    formatted with ``_format_example``, encoded, and terminated with the
    tokenizer's EOS id.  The concatenated ids are reshaped into rows of
    ``seq_len + 1`` tokens.

    Args:
        seq_len: tokens per training sequence (rows hold one extra token).
        max_samples: cap on the number of rows; falsy means no cap.
        max_tokens: cap on collected tokens.  When positive, a buffer of
            exactly this size is pre-allocated and filled; it takes
            precedence over *max_samples* for sizing the buffer.
        split, dataset_name, dataset_config: forwarded to ``load_dataset``.
        text_column: see ``_format_example``.
        category_filter: comma-separated substrings; only examples whose
            category matches one of them are kept.
        include_reasoning: see ``_format_example``.

    Returns:
        ``(TokenDataset, tokenizer)``.

    Raises:
        ValueError: if no tokens were collected, or fewer than one full
            row of ``seq_len + 1`` tokens.
    """
    from datasets import load_dataset
    from chimera import ChimeraTokenizer

    print(f"[DATA] Loading {dataset_name} ({split})...")
    load_kwargs = {"split": split, "streaming": True}
    if dataset_config:
        load_kwargs["name"] = dataset_config
    ds = load_dataset(dataset_name, **load_kwargs)

    print("[DATA] Loading tokenizer (splintr o200k_base)...")
    tok = ChimeraTokenizer(pretrained="o200k_base")

    cat_filters = None
    if category_filter:
        cat_filters = [c.strip() for c in category_filter.split(",") if c.strip()]
        print(f"[DATA] Filtering categories: {cat_filters}")

    chunk_len = seq_len + 1
    if max_tokens is not None:
        token_budget = max_tokens
    elif max_samples is not None:
        token_budget = max_samples * chunk_len
    else:
        token_budget = None

    stats = {"processed": 0, "skipped": 0}
    stream = _iter_example_tokens(ds, tok, cat_filters, text_column,
                                  include_reasoning, stats)

    def report(n_tokens):
        # Called right after an example is counted, so the message appears
        # at exact multiples of 10,000 processed examples.
        if stats["processed"] % 10000 == 0:
            print(f" {stats['processed']:,} examples, {n_tokens:,} tokens...")

    if token_budget is not None and token_budget > 0:
        # Bounded run: fill a fixed-size buffer, truncating the last example.
        buffer = torch.empty(token_budget, dtype=torch.long)
        filled = 0
        for ids in stream:
            take = min(len(ids), token_budget - filled)
            buffer[filled:filled + take] = torch.tensor(ids[:take], dtype=torch.long)
            filled += take
            if filled >= token_budget:
                break
            report(filled)
        all_ids = buffer[:filled]
    else:
        # Unbounded run (or a non-positive budget): grow a Python list.
        collected = []
        target = max_samples * chunk_len if max_samples else float('inf')
        for ids in stream:
            collected.extend(ids)
            if len(collected) >= target:
                break
            report(len(collected))
        all_ids = torch.tensor(collected, dtype=torch.long)

    processed, skipped = stats["processed"], stats["skipped"]
    print(f"[DATA] Processed {processed:,} examples, skipped {skipped:,} (category/text mismatch)")

    if len(all_ids) == 0:
        raise ValueError(
            f"No data matched filters. dataset={dataset_name}, "
            f"category_filter={category_filter}, text_column={text_column}"
        )

    n = len(all_ids) // chunk_len
    if n == 0:
        # An empty dataset would only fail later, inside the DataLoader.
        raise ValueError(
            f"Only {len(all_ids):,} tokens were collected; at least {chunk_len:,} "
            f"are needed for one chunk (seq_len={seq_len}). dataset={dataset_name}"
        )
    if max_samples:
        n = min(n, max_samples)
    chunks = all_ids[:n * chunk_len].view(n, chunk_len)
    print(f"[DATA] {n:,} chunks × {seq_len} tokens = {n * seq_len:,} total")
    return TokenDataset(chunks), tok
|
|
|
|
| |
| |
| |
def cosine_lr(step: int, warmup: int, total: int,
              max_lr: float, min_lr: float) -> float:
    """Learning rate at *step*: linear warmup, then cosine decay.

    Steps before *warmup* ramp linearly up to *max_lr*; steps at or beyond
    *total* return *min_lr*; steps in between follow a half cosine from
    *max_lr* down to *min_lr*.
    """
    if step >= warmup:
        if step >= total:
            return min_lr
        progress = (step - warmup) / (total - warmup)
        span = max_lr - min_lr
        wave = 1 + math.cos(math.pi * progress)
        return min_lr + 0.5 * span * wave
    return max_lr * (step + 1) / warmup
|
|
|
|
| |
| |
| |
def _unwrap_compiled(model):
    """Return the module wrapped by ``torch.compile`` (``_orig_mod``), else *model*."""
    return getattr(model, '_orig_mod', model)


def train(args):
    """Train a Chimera 5.1 model on CPU with MeZO or AdamW.

    Loads the JSON config named by ``args.config``, applies the size preset
    selected by ``args.scale`` and a fixed set of feature overrides, builds
    the model, dataset and optimizer, then runs ``args.max_steps`` steps.
    A JSON-lines log is written to ``<output_dir>/log.jsonl``, checkpoints to
    ``<output_dir>/ckpt-<step>/ckpt.pt`` every ``args.save_every`` steps, and
    the final weights plus config to ``<output_dir>/final``.

    Args:
        args: the ``argparse.Namespace`` produced by this script's parser.

    Raises:
        ValueError: if the dataset yields no complete batch.
    """
    with open(args.config) as f:
        config = json.load(f)

    # Size presets; "full" keeps whatever the config file specifies.
    scale_presets = {
        "tiny": {'hidden_size': 256, 'intermediate_size': 512,
                 'num_hidden_layers': 28, 'num_heads': 4, 'head_dim': 48},
        "small": {'hidden_size': 512, 'intermediate_size': 1024,
                  'num_hidden_layers': 28, 'num_heads': 8, 'head_dim': 48},
        "medium": {'hidden_size': 1024, 'intermediate_size': 2048,
                   'num_hidden_layers': 28, 'num_heads': 8, 'head_dim': 96},
    }
    config.update(scale_presets.get(args.scale, {}))

    config['vocab_size'] = 200073
    config.setdefault('gated_deltanet', {})['chunk_size'] = min(args.seq_len, 64)
    config.setdefault('xlstm', {})['memory_size_per_head'] = [config['head_dim'], config['head_dim']]
    config.setdefault('titans', {}).update({
        'memory_depth': 2, 'persistent_memory_slots': 16,
        'local_window_size': min(args.seq_len, 256)
    })
    moe_cfg = config.setdefault('backbone', {}).setdefault('moe', {})
    moe_cfg.update({
        'layers': [3, 7, 11, 15, 19, 23, 27],
        'moe_intermediate_size': config['intermediate_size'] // 4,
        'n_routed_experts': 8, 'n_shared_experts': 1, 'num_experts_per_tok': 2
    })
    config.setdefault('looping', {}).update({
        'enabled': True, 'prelude': [0, 3], 'loop': [4, 23], 'coda': [24, 27],
        'loop_range': [1, 3], 'loop_default': 2, 'adaptive_exit_threshold': 0.01
    })
    config.setdefault('span_inference', {})['enabled'] = True
    config.setdefault('grammar', {})['enabled'] = True
    config.setdefault('entropy_valve', {})['enabled'] = True
    config.setdefault('debt_ledger', {}).update({
        'enabled': True, 'obligations': ['close_bracket', 'close_string'],
        'max_outstanding': 32, 'pressure_weight': 0.3
    })
    config.setdefault('self_evolution', {}).update({
        'tier1': {
            'ttt': {'enabled': True, 'target_layers': [13, 23], 'inner_lr': 0.0003,
                    'momentum': 0.9, 'chunk_size': 256, 'reset_decay': 0.95},
            'memory_growth': {'enabled': True, 'pool_size_fixed': True}
        },
        'tier2': {
            'meta_guidelines': {'enabled': True, 'max': 64},
            'episodic_cases': {'enabled': True, 'max_cases': 256, 'case_bytes': 512},
            'self_feedback': {'enabled': True, 'confidence_threshold': 0.6,
                              'max_refinement_rounds': 1}
        },
        'tier3': {'loop_depth_learning': {'enabled': True}},
        'safety': {'freeze_threshold': 0.05},
    })
    config.setdefault('semantic_memory', {}).update({
        'vector_bits': 1024, 'capacity': 1000, 'pool_size_fixed': True
    })
    config.setdefault('multimodal', {})['enabled'] = False

    use_mezo = args.optimizer == 'mezo'
    use_bf16 = args.bf16 and torch.cpu.is_available()
    use_compile = args.compile

    print("=" * 60)
    print("CHIMERA 5.1 TRAINING — CPU-OPTIMIZED")
    print("=" * 60)
    print(f"Scale: {args.scale} (h={config['hidden_size']})")
    print(f"Layers: {config['num_hidden_layers']}")
    print(f"Seq len: {args.seq_len}")
    print(f"Steps: {args.max_steps}")
    print(f"Optimizer: {'MeZO (no backward)' if use_mezo else 'AdamW (backprop)'}")
    print(f"BFloat16: {use_bf16}")
    print(f"torch.compile:{use_compile}")
    print(f"Grad ckpt: {args.grad_checkpoint and not use_mezo}")
    print(f"Device: CPU ({torch.get_num_threads()} threads)")
    print(f"IPEX: {HAS_IPEX}")
    print(f"Tokenizer: splintr o200k_base ({config['vocab_size']} tokens)")
    print(f"Dataset: {args.dataset_name} / {args.dataset_split}")
    if args.dataset_config:
        print(f"Dataset config: {args.dataset_config}")
    if args.category_filter:
        print(f"Category filter: {args.category_filter}")
    if args.include_reasoning:
        print("Reasoning: INCLUDED (<|thinking|> ... <|/thinking|>)")

    model = Chimera51ForCausalLM(config)
    counts = model.count_parameters()
    print(f"Params: {counts['total']:,} (ternary: {counts['ternary']:,})")

    if use_mezo:
        mem_mb = counts['total'] * 4 * 2 / 1024 ** 2
        print(f"Memory: ~{mem_mb:.0f} MB (MeZO: 2× model only)")
    else:
        mem_mb = counts['total'] * 12 / 1024 ** 2
        print(f"Memory: ~{mem_mb:.0f} MB (AdamW: params + grads + states)")

    if args.grad_checkpoint and not use_mezo:
        model.enable_gradient_checkpointing()
        print("[OPT] Gradient checkpointing enabled")

    optimizer = None
    if not use_mezo:
        # Build AdamW once, with the decay / no-decay split, and hand that
        # same optimizer to IPEX so both paths train with identical settings.
        no_decay = ("A_log", "dt_bias", "norm", "bias", "embed", "energy_weights")
        decay_params, plain_params = [], []
        for pname, prm in model.named_parameters():
            if not prm.requires_grad:
                continue
            exempt = any(nd in pname for nd in no_decay)
            (plain_params if exempt else decay_params).append(prm)
        optimizer = torch.optim.AdamW(
            [{"params": decay_params, "weight_decay": 0.1},
             {"params": plain_params, "weight_decay": 0.0}],
            lr=args.lr, betas=(0.9, 0.95),
        )
        if HAS_IPEX:
            model, optimizer = ipex.optimize(
                model, optimizer=optimizer,
                dtype=torch.bfloat16 if use_bf16 else torch.float32,
                level='O1'
            )
            print("[OPT] IPEX optimization applied (level O1)")

    if use_compile:
        print("[OPT] Compiling model with torch.compile (inductor)...")
        model = torch.compile(model, backend="inductor", mode="default",
                              dynamic=True)
        print("[OPT] Compilation deferred (will compile on first forward pass)")

    dataset, tok = build_dataset(
        args.seq_len,
        max_samples=args.max_samples,
        max_tokens=args.max_tokens,
        split=args.dataset_split,
        dataset_name=args.dataset_name,
        dataset_config=args.dataset_config,
        text_column=args.text_column,
        category_filter=args.category_filter,
        include_reasoning=args.include_reasoning,
    )
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        drop_last=True,
        persistent_workers=args.num_workers > 0,
        prefetch_factor=2 if args.num_workers > 0 else None,
    )
    if len(loader) == 0:
        # drop_last=True discards a short final batch; with too little data
        # the training loop below would never receive a batch.
        raise ValueError(
            f"Dataset has {len(dataset)} chunks, fewer than batch_size={args.batch_size}; "
            "no complete batch can be formed."
        )

    if use_mezo:
        # NOTE(review): the model stays in train() mode, so any stochastic
        # layer makes the two MeZO forward passes differ — confirm intended.
        optimizer = MeZOOptimizer(
            model,
            lr=args.lr * 0.01,
            eps=1e-3,
            weight_decay=0.1,
            momentum=0.9,
            direction=args.mezo_direction,
            cache_directions=args.mezo_direction_cache,
        )

    def compute_loss(batch):
        # Next-token objective: inputs drop the last token, labels the first.
        ids = batch["input_ids"][:, :-1]
        labels = batch["labels"][:, 1:]
        if use_bf16:
            with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
                loss, _ = model(ids, labels=labels)
        else:
            loss, _ = model(ids, labels=labels)
        return loss

    os.makedirs(args.output_dir, exist_ok=True)

    model.train()
    step = 0
    total_loss = 0.0
    best = float('inf')
    t0 = time.time()
    toks = 0
    data_iter = iter(loader)
    warmup = min(args.warmup, args.max_steps // 10)

    # Peak / floor of the schedule for the active optimizer.
    if use_mezo:
        peak_lr, floor_lr = args.lr * 0.01, args.lr * 0.001
    else:
        peak_lr, floor_lr = args.lr, args.lr * 0.1
        optimizer.zero_grad()
    # Defined up front so logging works before the first AdamW update.
    lr = cosine_lr(0, warmup, args.max_steps, peak_lr, floor_lr)

    print(f"\n{'=' * 60}")
    print("Starting training...")
    print(f"{'=' * 60}\n")

    with open(os.path.join(args.output_dir, "log.jsonl"), "w") as log_f:
        while step < args.max_steps:
            try:
                batch = next(data_iter)
            except StopIteration:
                # Epoch finished: restart the loader.
                data_iter = iter(loader)
                batch = next(data_iter)

            if use_mezo:
                lr = cosine_lr(step, warmup, args.max_steps, peak_lr, floor_lr)
                optimizer.lr = lr

                loss_val = optimizer.step(compute_loss, batch)
                total_loss += loss_val
                toks += batch["input_ids"][:, :-1].numel()
            else:
                loss = compute_loss(batch)
                (loss / args.grad_accum).backward()
                total_loss += loss.item()
                toks += batch["input_ids"][:, :-1].numel()

                if (step + 1) % args.grad_accum == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    lr = cosine_lr(step, warmup, args.max_steps, peak_lr, floor_lr)
                    for pg in optimizer.param_groups:
                        pg['lr'] = lr
                    optimizer.step()
                    optimizer.zero_grad()

            step += 1

            if step % args.log_every == 0:
                dt = time.time() - t0
                avg = total_loss / args.log_every
                ppl = math.exp(min(avg, 20))
                tps = toks / dt if dt > 0 else 0
                # dt covers only the last log_every steps (t0 is reset
                # below), so the rate must use that interval, not `step`.
                eta = (args.max_steps - step) / (args.log_every / dt) / 3600 if dt > 0 else 0
                entry = {
                    "step": step, "loss": round(avg, 4), "ppl": round(ppl, 2),
                    "lr": round(lr, 8), "tok/s": round(tps), "eta_h": round(eta, 1),
                    "optimizer": "mezo" if use_mezo else "adamw",
                }
                print(f" step {step:>6}/{args.max_steps} | loss {avg:.4f} | "
                      f"ppl {ppl:>8.2f} | {tps:.0f} tok/s | ETA {eta:.1f}h")
                log_f.write(json.dumps(entry) + "\n")
                log_f.flush()
                if avg < best:
                    best = avg
                total_loss = 0.0
                toks = 0
                t0 = time.time()

            if step % args.save_every == 0:
                path = os.path.join(args.output_dir, f"ckpt-{step}")
                os.makedirs(path, exist_ok=True)
                torch.save({
                    "model": _unwrap_compiled(model).state_dict(),
                    "config": config,
                    "step": step,
                    "optimizer": args.optimizer,
                }, os.path.join(path, "ckpt.pt"))
                print(f" [SAVE] {path}")

    path = os.path.join(args.output_dir, "final")
    os.makedirs(path, exist_ok=True)
    torch.save({
        "model": _unwrap_compiled(model).state_dict(),
        "config": config,
        "step": step,
        "best_loss": best,
    }, os.path.join(path, "model.pt"))
    with open(os.path.join(path, "config.json"), "w") as cfg_f:
        json.dump(config, cfg_f, indent=2)
    print(f"\n{'=' * 60}")
    print(f"DONE — Best loss: {best:.4f}, PPL: {math.exp(min(best, 20)):.2f}")
    print(f"Optimizer: {'MeZO (no backward)' if use_mezo else 'AdamW'}")
    print(f"Saved: {path}")
|
|
|
|
def _build_arg_parser():
    """Create the command-line parser for the training script."""
    parser = argparse.ArgumentParser(description="Chimera 5.1 CPU-Optimized Training")
    add = parser.add_argument

    # Model
    add("--config", default="config.json")
    add("--scale", default="tiny", choices=["tiny", "small", "medium", "full"])
    add("--seq_len", type=int, default=256)

    # Optimization
    add("--optimizer", default="mezo", choices=["mezo", "adamw"],
        help="mezo: no backward pass (CPU-optimal). adamw: standard backprop.")
    add("--batch_size", type=int, default=2)
    add("--grad_accum", type=int, default=8)
    add("--lr", type=float, default=1e-3)
    add("--warmup", type=int, default=200)
    add("--max_steps", type=int, default=5000)
    add("--max_samples", type=int, default=None,
        help="Maximum number of chunks to generate")
    add("--max_tokens", type=int, default=None,
        help="Maximum total tokens to collect (pre-allocated buffer, prevents OOM on huge datasets)")

    # CPU switches; each on-by-default flag has a matching --no-* form
    add("--bf16", action="store_true", default=True,
        help="Enable BFloat16 autocast on CPU (default: True)")
    add("--no-bf16", dest="bf16", action="store_false")
    add("--compile", action="store_true", default=False,
        help="Enable torch.compile with Inductor backend")
    add("--grad_checkpoint", action="store_true", default=True,
        help="Enable gradient checkpointing (AdamW mode only)")
    add("--no-grad-checkpoint", dest="grad_checkpoint", action="store_false")
    add("--mezo_direction", choices=["rademacher", "gaussian"],
        default="rademacher",
        help="ZO perturbation distribution; rademacher is fastest on CPU")
    add("--no-mezo-direction-cache", dest="mezo_direction_cache",
        action="store_false", default=True,
        help="Regenerate directions instead of caching them for the step")

    # Dataset
    add("--dataset_name", default="roneneldan/TinyStories",
        help="HuggingFace dataset name (e.g. Roman1111111/claude-sonnet-4.6-120000x)")
    add("--dataset_config", default=None,
        help="Dataset config/subset name")
    add("--dataset_split", default="train",
        help="Dataset split to use")
    add("--text_column", default="auto",
        help="Column containing text. 'auto' detects 'messages'/'text'/'content'/'conversation'")
    add("--category_filter", default=None,
        help="Comma-separated category substrings to filter on (e.g. 'C++,python,math')")
    add("--include_reasoning", action="store_true", default=False,
        help="Include reasoning/thinking content from assistant messages as <|thinking|>...<|/thinking|>")

    # Runtime / output
    add("--num_workers", type=int, default=4)
    add("--log_every", type=int, default=10)
    add("--save_every", type=int, default=1000)
    add("--output_dir", default="./chimera_output")
    return parser


if __name__ == "__main__":
    train(_build_arg_parser().parse_args())
|
|