#!/usr/bin/env python3
"""Chimera 5.2 — CPU-first inference / text generation.
Config is source of truth. Checkpoint weights are resized to match the model.
"""
from __future__ import annotations
import argparse
import json
import os
import sys
import time
from typing import Dict, Tuple
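# Thread/affinity environment variables only take effect if they are set
# before torch (and its OpenMP/MKL backends) is imported, so this helper is
# invoked at module import time, just before `import torch` below.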
def _setup_cpu_runtime() -> None:
n = os.cpu_count() or 4
os.environ.setdefault("OMP_NUM_THREADS", str(n))
os.environ.setdefault("MKL_NUM_THREADS", str(n))
os.environ.setdefault("KMP_AFFINITY", "granularity=fine,compact,1,0")
os.environ.setdefault("KMP_BLOCKTIME", "1")
os.environ.setdefault("MALLOC_CONF", "background_thread:true,metadata_thp:auto")
_setup_cpu_runtime()
import torch
import torch.nn.functional as F
try:
torch.set_num_threads(int(os.environ.get("OMP_NUM_THREADS", os.cpu_count() or 4)))
torch.set_num_interop_threads(int(os.environ.get("CHIMERA_INTEROP_THREADS", "1")))
except RuntimeError:
pass
from chimera import Chimera51ForCausalLM, ChimeraTokenizer
from chimera.paths import DEFAULT_CONFIG_PATH
# ---------------------------------------------------------------------------
# Resize helpers: checkpoint weights -> model architecture (config is truth)
# ---------------------------------------------------------------------------
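# Policy: the overlapping slice of each tensor is copied; extra rows/columns
# in the model are filled (ones for 1D tensors, small random normal for 2D),
# and surplus rows/columns in the checkpoint are truncated. Shapes below are
# illustrative: a (1000, 512) checkpoint weight loaded into a (1200, 512)
# layer keeps rows 0..999 and re-initializes rows 1000..1199.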
@torch.no_grad()
def _resize_1d(w: torch.Tensor, target: int) -> torch.Tensor:
out = torch.ones(target, dtype=w.dtype, device=w.device)
n = min(w.numel(), target)
out[:n] = w[:n]
return out
@torch.no_grad()
def _resize_2d(w: torch.Tensor, target_shape: Tuple[int, int]) -> torch.Tensor:
to, ti = target_shape
so, si = w.shape
if (so, si) == (to, ti):
return w
out = torch.empty((to, ti), dtype=w.dtype, device=w.device)
std = float(w.std(unbiased=False).item()) if w.numel() > 1 else 0.02
std = max(min(std, 0.2), 1e-4)
out.normal_(mean=0.0, std=std)
ro, ci = min(so, to), min(si, ti)
out[:ro, :ci] = w[:ro, :ci]
return out
# ---------------------------------------------------------------------------
# Checkpoint loading
# ---------------------------------------------------------------------------
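# Config resolution order: the "config" entry embedded in the checkpoint,
# then a config.json next to the checkpoint file, then DEFAULT_CONFIG_PATH.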
def load_model(checkpoint_path: str, device: str = "cpu"):
print(f"[LOAD] Checkpoint: {checkpoint_path}")
ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
config = ckpt.get("config")
if config is None:
ckpt_dir = os.path.dirname(checkpoint_path)
cand = os.path.join(ckpt_dir, "config.json") if ckpt_dir else "config.json"
if not os.path.exists(cand):
cand = str(DEFAULT_CONFIG_PATH)
with open(cand, encoding="utf-8") as f:
config = json.load(f)
print(f"[LOAD] Config from {cand}")
else:
print("[LOAD] Config from checkpoint")
model = Chimera51ForCausalLM(config)
counts = model.count_parameters()
print(f"[LOAD] Params: {counts['total']:,} (ternary: {counts['ternary']:,})")
state = ckpt.get("model", ckpt)
model_state = model.state_dict()
# Config is source of truth: resize checkpoint tensors to match model.
resized: Dict[str, torch.Tensor] = {}
for k, v in state.items():
if k in model_state:
expected = model_state[k].shape
if v.shape != expected:
print(f"[WARN] resizing {k}: {tuple(v.shape)} -> {tuple(expected)}")
if v.ndim == 1:
v = _resize_1d(v, expected[0])
elif v.ndim == 2:
v = _resize_2d(v, expected)
else:
print(f"[SKIP] {k}: cannot resize {v.ndim}D tensor")
continue
resized[k] = v
else:
resized[k] = v
# Vocab reconciliation: if vocab mismatch, re-init embed + lm_head.
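    # Exception to "config is truth": the checkpoint's vocab size wins here,
    # presumably so token ids seen during training stay valid. embed and
    # lm_head are rebuilt at the checkpoint's row count (copying overlapping
    # rows) and config["vocab_size"] is updated to match.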
model_vocab = int(config.get("vocab_size", model.embed.num_embeddings))
if "embed.weight" in resized:
ckpt_vocab = int(resized["embed.weight"].shape[0])
if ckpt_vocab != model_vocab:
print(f"[WARN] vocab mismatch ckpt={ckpt_vocab} cfg={model_vocab}; re-init embed+head")
with torch.no_grad():
old = model.embed.weight.data
new = torch.zeros(ckpt_vocab, old.shape[1], dtype=old.dtype, device=old.device)
new[:min(old.shape[0], ckpt_vocab)] = old[:min(old.shape[0], ckpt_vocab)]
model.embed = torch.nn.Embedding(ckpt_vocab, old.shape[1])
model.embed.weight.data = new
old_h = model.lm_head.weight.data
new_h = torch.zeros(ckpt_vocab, old_h.shape[1], dtype=old_h.dtype, device=old_h.device)
new_h[:min(old_h.shape[0], ckpt_vocab)] = old_h[:min(old_h.shape[0], ckpt_vocab)]
model.lm_head = torch.nn.Linear(old_h.shape[1], ckpt_vocab, bias=False)
model.lm_head.weight.data = new_h
config["vocab_size"] = ckpt_vocab
missing, unexpected = model.load_state_dict(resized, strict=False)
if missing:
print(f"[WARN] Missing keys ({len(missing)}): {missing[:5]}...")
if unexpected:
print(f"[WARN] Unexpected keys ({len(unexpected)}): {unexpected[:5]}...")
model.to(device).eval()
model.prepare_for_inference()
step = ckpt.get("step", "?")
best_loss = ckpt.get("best_loss")
if best_loss is not None:
print(f"[LOAD] Step {step}, best_loss={best_loss:.4f}")
else:
print(f"[LOAD] Step {step}")
return model, config
# ---------------------------------------------------------------------------
# Sampling helpers
# ---------------------------------------------------------------------------
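# Sampling order: temperature scaling first, then top-k truncation (when
# top_k > 0), then nucleus (top-p) filtering over the surviving candidates.
# A temperature of 0 (or below) falls back to greedy argmax decoding.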
def _sample_next(logits: torch.Tensor, temperature: float, top_p: float,
                 top_k: int) -> int:
if logits.dim() == 1:
logits = logits.unsqueeze(0)
if temperature <= 0.0:
return int(torch.argmax(logits, dim=-1).item())
logits = logits / temperature
if top_k and top_k > 0:
k = min(top_k, logits.size(-1))
cand_logits, cand_indices = torch.topk(logits, k, dim=-1)
if top_p < 1.0:
sorted_logits, order = torch.sort(cand_logits, descending=True)
sorted_indices = cand_indices.gather(-1, order)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
remove = cum_probs > top_p
remove[..., 0] = False
sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
probs = F.softmax(sorted_logits, dim=-1)
return int(sorted_indices.gather(-1, torch.multinomial(probs, 1)).item())
probs = F.softmax(cand_logits, dim=-1)
return int(cand_indices.gather(-1, torch.multinomial(probs, 1)).item())
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
remove = cum_probs > top_p
remove[..., 0] = False
sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
probs = F.softmax(sorted_logits, dim=-1)
return int(sorted_indices.gather(-1, torch.multinomial(probs, 1)).item())
probs = F.softmax(logits, dim=-1)
return int(torch.multinomial(probs, 1).item())
# ---------------------------------------------------------------------------
# Generation loop
# ---------------------------------------------------------------------------
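# The prompt is encoded and run through a single prefill forward pass, then
# tokens are decoded one at a time by feeding the previous token back in
# along with the model's caches. When streaming, the full sequence is
# re-decoded each step and only the newly produced suffix is written to
# stdout, which keeps multi-byte/merged tokens readable.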
def generate(model: Chimera51ForCausalLM, tokenizer: ChimeraTokenizer,
prompt: str, max_tokens: int = 100, temperature: float = 0.8,
top_p: float = 0.9, top_k: int = 50, device: str = "cpu",
bf16: bool = False, stream: bool = True) -> str:
model.eval()
prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
if not prompt_ids:
prompt_ids = [tokenizer.eos_token_id]
input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=device)
print(f"\n[GEN] Prompt: {prompt!r}")
print(f"[GEN] max_tokens={max_tokens}, temp={temperature}, top_p={top_p}, top_k={top_k}")
print("=" * 60, flush=True)
if stream:
sys.stdout.write(prompt)
sys.stdout.flush()
generated = list(prompt_ids)
decoded_so_far = tokenizer.decode(generated, skip_special_tokens=False)
autocast_ctx = (torch.autocast(device_type=device.split(":")[0], dtype=torch.bfloat16)
if bf16 else _nullctx())
t0 = time.time()
with torch.inference_mode(), autocast_ctx:
out = model(input_ids, use_cache=True, logits_to_keep=1)
caches = out.caches
next_token = _sample_next(out.logits[:, -1, :].float(), temperature, top_p, top_k)
if next_token == tokenizer.eos_token_id:
return tokenizer.decode(generated, skip_special_tokens=True)
generated.append(next_token)
for _ in range(max_tokens - 1):
tok_t = torch.tensor([[next_token]], dtype=torch.long, device=device)
out = model(tok_t, caches=caches, use_cache=True, logits_to_keep=1)
caches = out.caches
next_token = _sample_next(out.logits[:, -1, :].float(), temperature, top_p, top_k)
if next_token == tokenizer.eos_token_id:
break
generated.append(next_token)
if stream:
full = tokenizer.decode(generated, skip_special_tokens=False)
if full.startswith(decoded_so_far):
sys.stdout.write(full[len(decoded_so_far):])
sys.stdout.flush()
decoded_so_far = full
elapsed = time.time() - t0
n_new = len(generated) - len(prompt_ids)
speed = n_new / elapsed if elapsed > 0 else 0.0
final = tokenizer.decode(generated, skip_special_tokens=True)
print()
print("=" * 60)
if not stream:
print(final)
print(f"[STATS] {n_new} new tokens in {elapsed:.2f}s ({speed:.1f} tok/s)")
return final
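# Minimal stand-in for contextlib.nullcontext(); used as the autocast
# context when bf16 is disabled.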
class _nullctx:
def __enter__(self):
return self
def __exit__(self, *args):
return False
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def main() -> None:
p = argparse.ArgumentParser(description="Chimera 5.2 CPU inference")
p.add_argument("--checkpoint", default="chimera_output/final/model.pt")
p.add_argument("--prompt", default="Once upon a time")
p.add_argument("--max_tokens", type=int, default=100)
p.add_argument("--temperature", type=float, default=0.8)
p.add_argument("--top_p", type=float, default=0.9)
p.add_argument("--top_k", type=int, default=50)
p.add_argument("--device", default="cpu")
p.add_argument("--bf16", action="store_true", default=True)
p.add_argument("--no-bf16", dest="bf16", action="store_false")
p.add_argument("--threads", type=int, default=None)
p.add_argument("--compile", action="store_true", default=False)
p.add_argument("--no-stream", dest="stream", action="store_false", default=True)
args = p.parse_args()
if args.threads:
torch.set_num_threads(args.threads)
os.environ["OMP_NUM_THREADS"] = str(args.threads)
os.environ["MKL_NUM_THREADS"] = str(args.threads)
if not os.path.exists(args.checkpoint):
print(f"[ERROR] Checkpoint not found: {args.checkpoint}")
return
model, config = load_model(args.checkpoint, device=args.device)
if args.compile:
print("[OPT] Compiling model with torch.compile (mode=reduce-overhead)...")
model = torch.compile(model, backend="inductor", mode="reduce-overhead")
print("[LOAD] Loading tokenizer (splintr o200k_base)...")
tokenizer = ChimeraTokenizer(pretrained="o200k_base")
print("[WARM] Warmup forward...")
with torch.inference_mode():
_ = model(torch.tensor([[tokenizer.eos_token_id]], device=args.device), logits_to_keep=1)
print("[WARM] Done.")
generate(
model, tokenizer,
prompt=args.prompt, max_tokens=args.max_tokens,
temperature=args.temperature, top_p=args.top_p, top_k=args.top_k,
device=args.device, bf16=args.bf16, stream=args.stream,
)
if __name__ == "__main__":
main()