"""Chimera 5.2 — CPU-first inference / text generation.

The config is the source of truth: checkpoint weights are resized to match the model.
"""
from __future__ import annotations

import argparse
import json
import os
import sys
import time
from typing import Dict, Tuple
def _setup_cpu_runtime() -> None:
    """Set OpenMP/MKL thread counts, affinity, and allocator hints via the environment."""
    n = os.cpu_count() or 4
    os.environ.setdefault("OMP_NUM_THREADS", str(n))
    os.environ.setdefault("MKL_NUM_THREADS", str(n))
    os.environ.setdefault("KMP_AFFINITY", "granularity=fine,compact,1,0")
    os.environ.setdefault("KMP_BLOCKTIME", "1")
    os.environ.setdefault("MALLOC_CONF", "background_thread:true,metadata_thp:auto")
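# The KMP_* variables are read by Intel's OpenMP runtime and MALLOC_CONF by jemalloc;
# with other runtimes/allocators they are ignored. They are set here, before torch is
# imported below, so the threading runtime picks them up when it initialises.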
_setup_cpu_runtime()
import torch
import torch.nn.functional as F

try:
    torch.set_num_threads(int(os.environ.get("OMP_NUM_THREADS", os.cpu_count() or 4)))
    torch.set_num_interop_threads(int(os.environ.get("CHIMERA_INTEROP_THREADS", "1")))
except RuntimeError:
    # set_num_interop_threads raises if the interop thread pool was already initialised.
    pass

from chimera import Chimera51ForCausalLM, ChimeraTokenizer
from chimera.paths import DEFAULT_CONFIG_PATH
@torch.no_grad()
def _resize_1d(w: torch.Tensor, target: int) -> torch.Tensor:
    """Grow or shrink a 1D tensor to `target`, padding any new entries with ones."""
    out = torch.ones(target, dtype=w.dtype, device=w.device)
    n = min(w.numel(), target)
    out[:n] = w[:n]
    return out
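# Example: a length-768 checkpoint vector loaded into a length-1024 slot keeps its
# first 768 values; the remaining 256 entries stay at one.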
@torch.no_grad()
def _resize_2d(w: torch.Tensor, target_shape: Tuple[int, int]) -> torch.Tensor:
    """Grow or shrink a 2D tensor, keeping the overlapping block and re-initialising the rest."""
    to, ti = target_shape
    so, si = w.shape
    if (so, si) == (to, ti):
        return w
    out = torch.empty((to, ti), dtype=w.dtype, device=w.device)
    std = float(w.std(unbiased=False).item()) if w.numel() > 1 else 0.02
    std = max(min(std, 0.2), 1e-4)
    out.normal_(mean=0.0, std=std)
    ro, ci = min(so, to), min(si, ti)
    out[:ro, :ci] = w[:ro, :ci]
    return out
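# Example: a (4096, 2048) checkpoint matrix loaded into a (6144, 2048) slot keeps the
# overlapping 4096 x 2048 block; the extra rows are drawn from N(0, std), where std is
# the checkpoint weight's own standard deviation clamped to [1e-4, 0.2].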
def load_model(checkpoint_path: str, device: str = "cpu"):
    """Load a checkpoint, resizing mismatched tensors so they fit the configured model."""
    print(f"[LOAD] Checkpoint: {checkpoint_path}")
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Config resolution: checkpoint first, then config.json next to it, then the default path.
    config = ckpt.get("config")
    if config is None:
        ckpt_dir = os.path.dirname(checkpoint_path)
        cand = os.path.join(ckpt_dir, "config.json") if ckpt_dir else "config.json"
        if not os.path.exists(cand):
            cand = str(DEFAULT_CONFIG_PATH)
        with open(cand, encoding="utf-8") as f:
            config = json.load(f)
        print(f"[LOAD] Config from {cand}")
    else:
        print("[LOAD] Config from checkpoint")

    model = Chimera51ForCausalLM(config)
    counts = model.count_parameters()
    print(f"[LOAD] Params: {counts['total']:,} (ternary: {counts['ternary']:,})")

    state = ckpt.get("model", ckpt)
    model_state = model.state_dict()

    # Resize any checkpoint tensors whose shapes disagree with the model.
    resized: Dict[str, torch.Tensor] = {}
    for k, v in state.items():
        if k in model_state:
            expected = model_state[k].shape
            if v.shape != expected:
                print(f"[WARN] resizing {k}: {tuple(v.shape)} -> {tuple(expected)}")
                if v.ndim == 1:
                    v = _resize_1d(v, expected[0])
                elif v.ndim == 2:
                    v = _resize_2d(v, expected)
                else:
                    print(f"[SKIP] {k}: cannot resize {v.ndim}D tensor")
                    continue
            resized[k] = v
        else:
            resized[k] = v

    # Vocab mismatch: rebuild the embedding and LM head at the checkpoint's vocab size.
    model_vocab = int(config.get("vocab_size", model.embed.num_embeddings))
    if "embed.weight" in resized:
        ckpt_vocab = int(resized["embed.weight"].shape[0])
        if ckpt_vocab != model_vocab:
            print(f"[WARN] vocab mismatch ckpt={ckpt_vocab} cfg={model_vocab}; re-init embed+head")
            with torch.no_grad():
                old = model.embed.weight.data
                new = torch.zeros(ckpt_vocab, old.shape[1], dtype=old.dtype, device=old.device)
                new[:min(old.shape[0], ckpt_vocab)] = old[:min(old.shape[0], ckpt_vocab)]
                model.embed = torch.nn.Embedding(ckpt_vocab, old.shape[1])
                model.embed.weight.data = new
                old_h = model.lm_head.weight.data
                new_h = torch.zeros(ckpt_vocab, old_h.shape[1], dtype=old_h.dtype, device=old_h.device)
                new_h[:min(old_h.shape[0], ckpt_vocab)] = old_h[:min(old_h.shape[0], ckpt_vocab)]
                model.lm_head = torch.nn.Linear(old_h.shape[1], ckpt_vocab, bias=False)
                model.lm_head.weight.data = new_h
            config["vocab_size"] = ckpt_vocab

    missing, unexpected = model.load_state_dict(resized, strict=False)
    if missing:
        print(f"[WARN] Missing keys ({len(missing)}): {missing[:5]}...")
    if unexpected:
        print(f"[WARN] Unexpected keys ({len(unexpected)}): {unexpected[:5]}...")

    model.to(device).eval()
    model.prepare_for_inference()

    step = ckpt.get("step", "?")
    best_loss = ckpt.get("best_loss")
    if best_loss is not None:
        print(f"[LOAD] Step {step}, best_loss={best_loss:.4f}")
    else:
        print(f"[LOAD] Step {step}")
    return model, config
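# Typical use (the checkpoint path is illustrative; it matches the CLI default below):
#   model, config = load_model("chimera_output/final/model.pt", device="cpu")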
def _sample_next(logits: torch.Tensor, temperature: float, top_p: float,
                 top_k: int) -> int:
    """Sample one token id from the last-position logits (greedy when temperature <= 0)."""
    if logits.dim() == 1:
        logits = logits.unsqueeze(0)
    if temperature <= 0.0:
        return int(torch.argmax(logits, dim=-1).item())
    logits = logits / temperature
    if top_k and top_k > 0:
        k = min(top_k, logits.size(-1))
        cand_logits, cand_indices = torch.topk(logits, k, dim=-1)
        if top_p < 1.0:
            sorted_logits, order = torch.sort(cand_logits, descending=True)
            sorted_indices = cand_indices.gather(-1, order)
            cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            remove = cum_probs > top_p
            remove[..., 0] = False
            sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
            probs = F.softmax(sorted_logits, dim=-1)
            return int(sorted_indices.gather(-1, torch.multinomial(probs, 1)).item())
        probs = F.softmax(cand_logits, dim=-1)
        return int(cand_indices.gather(-1, torch.multinomial(probs, 1)).item())
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cum_probs > top_p
        remove[..., 0] = False
        sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
        probs = F.softmax(sorted_logits, dim=-1)
        return int(sorted_indices.gather(-1, torch.multinomial(probs, 1)).item())
    probs = F.softmax(logits, dim=-1)
    return int(torch.multinomial(probs, 1).item())
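# Filtering order: temperature scaling, then optional top-k truncation, then nucleus
# (top-p) masking over whatever candidates remain; temperature <= 0 short-circuits to
# greedy argmax.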
def generate(model: Chimera51ForCausalLM, tokenizer: ChimeraTokenizer,
             prompt: str, max_tokens: int = 100, temperature: float = 0.8,
             top_p: float = 0.9, top_k: int = 50, device: str = "cpu",
             bf16: bool = False, stream: bool = True) -> str:
    """Autoregressively generate text from `prompt`, optionally streaming it to stdout."""
    model.eval()
    prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
    if not prompt_ids:
        prompt_ids = [tokenizer.eos_token_id]
    input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=device)

    print(f"\n[GEN] Prompt: {prompt!r}")
    print(f"[GEN] max_tokens={max_tokens}, temp={temperature}, top_p={top_p}, top_k={top_k}")
    print("=" * 60, flush=True)

    if stream:
        sys.stdout.write(prompt)
        sys.stdout.flush()

    generated = list(prompt_ids)
    decoded_so_far = tokenizer.decode(generated, skip_special_tokens=False)

    autocast_ctx = (torch.autocast(device_type=device.split(":")[0], dtype=torch.bfloat16)
                    if bf16 else _nullctx())

    t0 = time.time()
    with torch.inference_mode(), autocast_ctx:
        # Prefill: run the full prompt once, then decode token by token from the cache.
        out = model(input_ids, use_cache=True, logits_to_keep=1)
        caches = out.caches
        next_token = _sample_next(out.logits[:, -1, :].float(), temperature, top_p, top_k)
        if next_token == tokenizer.eos_token_id:
            return tokenizer.decode(generated, skip_special_tokens=True)
        generated.append(next_token)

        for _ in range(max_tokens - 1):
            tok_t = torch.tensor([[next_token]], dtype=torch.long, device=device)
            out = model(tok_t, caches=caches, use_cache=True, logits_to_keep=1)
            caches = out.caches
            next_token = _sample_next(out.logits[:, -1, :].float(), temperature, top_p, top_k)
            if next_token == tokenizer.eos_token_id:
                break
            generated.append(next_token)
            if stream:
                # Re-decode the whole sequence and emit only the new suffix, so multi-byte
                # tokens are never split mid-character.
                full = tokenizer.decode(generated, skip_special_tokens=False)
                if full.startswith(decoded_so_far):
                    sys.stdout.write(full[len(decoded_so_far):])
                    sys.stdout.flush()
                    decoded_so_far = full

    elapsed = time.time() - t0
    n_new = len(generated) - len(prompt_ids)
    speed = n_new / elapsed if elapsed > 0 else 0.0
    final = tokenizer.decode(generated, skip_special_tokens=True)

    print()
    print("=" * 60)
    if not stream:
        print(final)
    print(f"[STATS] {n_new} new tokens in {elapsed:.2f}s ({speed:.1f} tok/s)")
    return final
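# Example call (assumes `model` and `tokenizer` come from load_model / ChimeraTokenizer above):
#   text = generate(model, tokenizer, "Once upon a time", max_tokens=64, temperature=0.7)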
class _nullctx:
    """Minimal no-op context manager used when bf16 autocast is disabled."""

    def __enter__(self):
        return self

    def __exit__(self, *args):
        return False
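# contextlib.nullcontext() from the standard library would be a drop-in replacement.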
def main() -> None:
    p = argparse.ArgumentParser(description="Chimera 5.2 CPU inference")
    p.add_argument("--checkpoint", default="chimera_output/final/model.pt")
    p.add_argument("--prompt", default="Once upon a time")
    p.add_argument("--max_tokens", type=int, default=100)
    p.add_argument("--temperature", type=float, default=0.8)
    p.add_argument("--top_p", type=float, default=0.9)
    p.add_argument("--top_k", type=int, default=50)
    p.add_argument("--device", default="cpu")
    p.add_argument("--bf16", action="store_true", default=True)
    p.add_argument("--no-bf16", dest="bf16", action="store_false")
    p.add_argument("--threads", type=int, default=None)
    p.add_argument("--compile", action="store_true", default=False)
    p.add_argument("--no-stream", dest="stream", action="store_false", default=True)
    args = p.parse_args()
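    # Example invocation (flags illustrative; the script name depends on how this file is saved):
    #   python <this_script>.py --checkpoint chimera_output/final/model.pt \
    #       --prompt "The quick brown fox" --max_tokens 200 --threads 8 --no-bf16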
    if args.threads:
        torch.set_num_threads(args.threads)
        os.environ["OMP_NUM_THREADS"] = str(args.threads)
        os.environ["MKL_NUM_THREADS"] = str(args.threads)

    if not os.path.exists(args.checkpoint):
        print(f"[ERROR] Checkpoint not found: {args.checkpoint}")
        return

    model, config = load_model(args.checkpoint, device=args.device)

    if args.compile:
        print("[OPT] Compiling model with torch.compile (mode=reduce-overhead)...")
        model = torch.compile(model, backend="inductor", mode="reduce-overhead")

    print("[LOAD] Loading tokenizer (splintr o200k_base)...")
    tokenizer = ChimeraTokenizer(pretrained="o200k_base")

    print("[WARM] Warmup forward...")
    with torch.inference_mode():
        _ = model(torch.tensor([[tokenizer.eos_token_id]], device=args.device), logits_to_keep=1)
    print("[WARM] Done.")

    generate(
        model, tokenizer,
        prompt=args.prompt, max_tokens=args.max_tokens,
        temperature=args.temperature, top_p=args.top_p, top_k=args.top_k,
        device=args.device, bf16=args.bf16, stream=args.stream,
    )


if __name__ == "__main__":
    main()