#!/usr/bin/env python3 """Chimera 5.2 — CPU-first inference / text generation. Config is source of truth. Checkpoint weights are resized to match the model. """ from __future__ import annotations import argparse import json import os import sys import time from typing import Dict, Tuple def _setup_cpu_runtime() -> None: n = os.cpu_count() or 4 os.environ.setdefault("OMP_NUM_THREADS", str(n)) os.environ.setdefault("MKL_NUM_THREADS", str(n)) os.environ.setdefault("KMP_AFFINITY", "granularity=fine,compact,1,0") os.environ.setdefault("KMP_BLOCKTIME", "1") os.environ.setdefault("MALLOC_CONF", "background_thread:true,metadata_thp:auto") _setup_cpu_runtime() import torch import torch.nn.functional as F try: torch.set_num_threads(int(os.environ.get("OMP_NUM_THREADS", os.cpu_count() or 4))) torch.set_num_interop_threads(int(os.environ.get("CHIMERA_INTEROP_THREADS", "1"))) except RuntimeError: pass sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) from chimera import Chimera51ForCausalLM, ChimeraTokenizer # --------------------------------------------------------------------------- # Resize helpers: checkpoint weights -> model architecture (config is truth) # --------------------------------------------------------------------------- @torch.no_grad() def _resize_1d(w: torch.Tensor, target: int) -> torch.Tensor: out = torch.ones(target, dtype=w.dtype, device=w.device) n = min(w.numel(), target) out[:n] = w[:n] return out @torch.no_grad() def _resize_2d(w: torch.Tensor, target_shape: Tuple[int, int]) -> torch.Tensor: to, ti = target_shape so, si = w.shape if (so, si) == (to, ti): return w out = torch.empty((to, ti), dtype=w.dtype, device=w.device) std = float(w.std(unbiased=False).item()) if w.numel() > 1 else 0.02 std = max(min(std, 0.2), 1e-4) out.normal_(mean=0.0, std=std) ro, ci = min(so, to), min(si, ti) out[:ro, :ci] = w[:ro, :ci] return out # --------------------------------------------------------------------------- # Checkpoint loading # --------------------------------------------------------------------------- def load_model(checkpoint_path: str, device: str = "cpu"): print(f"[LOAD] Checkpoint: {checkpoint_path}") ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False) config = ckpt.get("config") if config is None: ckpt_dir = os.path.dirname(checkpoint_path) cand = os.path.join(ckpt_dir, "config.json") if ckpt_dir else "config.json" if not os.path.exists(cand): cand = "config.json" with open(cand, encoding="utf-8") as f: config = json.load(f) print(f"[LOAD] Config from {cand}") else: print("[LOAD] Config from checkpoint") model = Chimera51ForCausalLM(config) counts = model.count_parameters() print(f"[LOAD] Params: {counts['total']:,} (ternary: {counts['ternary']:,})") state = ckpt.get("model", ckpt) model_state = model.state_dict() # Config is source of truth: resize checkpoint tensors to match model. resized: Dict[str, torch.Tensor] = {} for k, v in state.items(): if k in model_state: expected = model_state[k].shape if v.shape != expected: print(f"[WARN] resizing {k}: {tuple(v.shape)} -> {tuple(expected)}") if v.ndim == 1: v = _resize_1d(v, expected[0]) elif v.ndim == 2: v = _resize_2d(v, expected) else: print(f"[SKIP] {k}: cannot resize {v.ndim}D tensor") continue resized[k] = v else: resized[k] = v # Vocab reconciliation: if vocab mismatch, re-init embed + lm_head. model_vocab = int(config.get("vocab_size", model.embed.num_embeddings)) if "embed.weight" in resized: ckpt_vocab = int(resized["embed.weight"].shape[0]) if ckpt_vocab != model_vocab: print(f"[WARN] vocab mismatch ckpt={ckpt_vocab} cfg={model_vocab}; re-init embed+head") with torch.no_grad(): old = model.embed.weight.data new = torch.zeros(ckpt_vocab, old.shape[1], dtype=old.dtype, device=old.device) new[:min(old.shape[0], ckpt_vocab)] = old[:min(old.shape[0], ckpt_vocab)] model.embed = torch.nn.Embedding(ckpt_vocab, old.shape[1]) model.embed.weight.data = new old_h = model.lm_head.weight.data new_h = torch.zeros(ckpt_vocab, old_h.shape[1], dtype=old_h.dtype, device=old_h.device) new_h[:min(old_h.shape[0], ckpt_vocab)] = old_h[:min(old_h.shape[0], ckpt_vocab)] model.lm_head = torch.nn.Linear(old_h.shape[1], ckpt_vocab, bias=False) model.lm_head.weight.data = new_h config["vocab_size"] = ckpt_vocab missing, unexpected = model.load_state_dict(resized, strict=False) if missing: print(f"[WARN] Missing keys ({len(missing)}): {missing[:5]}...") if unexpected: print(f"[WARN] Unexpected keys ({len(unexpected)}): {unexpected[:5]}...") model.to(device).eval() model.prepare_for_inference() step = ckpt.get("step", "?") best_loss = ckpt.get("best_loss") if best_loss is not None: print(f"[LOAD] Step {step}, best_loss={best_loss:.4f}") else: print(f"[LOAD] Step {step}") return model, config # --------------------------------------------------------------------------- # Sampling helpers # --------------------------------------------------------------------------- def _sample_next(logits: torch.Tensor, temperature: float, top_p: float, top_k: int ) -> int: if logits.dim() == 1: logits = logits.unsqueeze(0) if temperature <= 0.0: return int(torch.argmax(logits, dim=-1).item()) logits = logits / temperature if top_k and top_k > 0: k = min(top_k, logits.size(-1)) cand_logits, cand_indices = torch.topk(logits, k, dim=-1) if top_p < 1.0: sorted_logits, order = torch.sort(cand_logits, descending=True) sorted_indices = cand_indices.gather(-1, order) cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) remove = cum_probs > top_p remove[..., 0] = False sorted_logits = sorted_logits.masked_fill(remove, float("-inf")) probs = F.softmax(sorted_logits, dim=-1) return int(sorted_indices.gather(-1, torch.multinomial(probs, 1)).item()) probs = F.softmax(cand_logits, dim=-1) return int(cand_indices.gather(-1, torch.multinomial(probs, 1)).item()) if top_p < 1.0: sorted_logits, sorted_indices = torch.sort(logits, descending=True) cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) remove = cum_probs > top_p remove[..., 0] = False sorted_logits = sorted_logits.masked_fill(remove, float("-inf")) probs = F.softmax(sorted_logits, dim=-1) return int(sorted_indices.gather(-1, torch.multinomial(probs, 1)).item()) probs = F.softmax(logits, dim=-1) return int(torch.multinomial(probs, 1).item()) # --------------------------------------------------------------------------- # Generation loop # --------------------------------------------------------------------------- def generate(model: Chimera51ForCausalLM, tokenizer: ChimeraTokenizer, prompt: str, max_tokens: int = 100, temperature: float = 0.8, top_p: float = 0.9, top_k: int = 50, device: str = "cpu", bf16: bool = False, stream: bool = True) -> str: model.eval() prompt_ids = tokenizer.encode(prompt, add_special_tokens=False) if not prompt_ids: prompt_ids = [tokenizer.eos_token_id] input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=device) print(f"\n[GEN] Prompt: {prompt!r}") print(f"[GEN] max_tokens={max_tokens}, temp={temperature}, top_p={top_p}, top_k={top_k}") print("=" * 60, flush=True) if stream: sys.stdout.write(prompt) sys.stdout.flush() generated = list(prompt_ids) decoded_so_far = tokenizer.decode(generated, skip_special_tokens=False) autocast_ctx = (torch.autocast(device_type=device.split(":")[0], dtype=torch.bfloat16) if bf16 else _nullctx()) t0 = time.time() with torch.inference_mode(), autocast_ctx: out = model(input_ids, use_cache=True, logits_to_keep=1) caches = out.caches next_token = _sample_next(out.logits[:, -1, :].float(), temperature, top_p, top_k) if next_token == tokenizer.eos_token_id: return tokenizer.decode(generated, skip_special_tokens=True) generated.append(next_token) for _ in range(max_tokens - 1): tok_t = torch.tensor([[next_token]], dtype=torch.long, device=device) out = model(tok_t, caches=caches, use_cache=True, logits_to_keep=1) caches = out.caches next_token = _sample_next(out.logits[:, -1, :].float(), temperature, top_p, top_k) if next_token == tokenizer.eos_token_id: break generated.append(next_token) if stream: full = tokenizer.decode(generated, skip_special_tokens=False) if full.startswith(decoded_so_far): sys.stdout.write(full[len(decoded_so_far):]) sys.stdout.flush() decoded_so_far = full elapsed = time.time() - t0 n_new = len(generated) - len(prompt_ids) speed = n_new / elapsed if elapsed > 0 else 0.0 final = tokenizer.decode(generated, skip_special_tokens=True) print() print("=" * 60) if not stream: print(final) print(f"[STATS] {n_new} new tokens in {elapsed:.2f}s ({speed:.1f} tok/s)") return final class _nullctx: def __enter__(self): return self def __exit__(self, *args): return False # --------------------------------------------------------------------------- # CLI # --------------------------------------------------------------------------- def main() -> None: p = argparse.ArgumentParser(description="Chimera 5.2 CPU inference") p.add_argument("--checkpoint", default="chimera_output/final/model.pt") p.add_argument("--prompt", default="Once upon a time") p.add_argument("--max_tokens", type=int, default=100) p.add_argument("--temperature", type=float, default=0.8) p.add_argument("--top_p", type=float, default=0.9) p.add_argument("--top_k", type=int, default=50) p.add_argument("--device", default="cpu") p.add_argument("--bf16", action="store_true", default=True) p.add_argument("--no-bf16", dest="bf16", action="store_false") p.add_argument("--threads", type=int, default=None) p.add_argument("--compile", action="store_true", default=False) p.add_argument("--no-stream", dest="stream", action="store_false", default=True) args = p.parse_args() if args.threads: torch.set_num_threads(args.threads) os.environ["OMP_NUM_THREADS"] = str(args.threads) os.environ["MKL_NUM_THREADS"] = str(args.threads) if not os.path.exists(args.checkpoint): print(f"[ERROR] Checkpoint not found: {args.checkpoint}") return model, config = load_model(args.checkpoint, device=args.device) if args.compile: print("[OPT] Compiling model with torch.compile (mode=reduce-overhead)...") model = torch.compile(model, backend="inductor", mode="reduce-overhead") print("[LOAD] Loading tokenizer (splintr o200k_base)...") tokenizer = ChimeraTokenizer(pretrained="o200k_base") print("[WARM] Warmup forward...") with torch.inference_mode(): _ = model(torch.tensor([[tokenizer.eos_token_id]], device=args.device), logits_to_keep=1) print("[WARM] Done.") generate( model, tokenizer, prompt=args.prompt, max_tokens=args.max_tokens, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k, device=args.device, bf16=args.bf16, stream=args.stream, ) if __name__ == "__main__": main()