Lgr54HFi committed
Commit 0a7fd59 · verified · 1 Parent(s): f1df870

Upload inference.py

Files changed (1):
  1. inference.py +45 -80
inference.py CHANGED
@@ -1,26 +1,5 @@
  #!/usr/bin/env python3
- """Chimera 5.2 — CPU-first inference / text generation.
-
- Significant CPU-friendly changes vs the previous draft:
-
- * **KV-cache aware loop** — after the first forward pass we only feed the
-   new token plus the per-layer recurrent state into the model. This makes
-   generation *O(T)* instead of *O(T²)*, the single biggest win for CPU
-   decoding.
- * **Pre-pack BitLinear weights** at startup so the first decoded token does
-   not pay the unpack/repack cost.
- * **Greedy fast path** (``temperature == 0``) skips softmax / sort entirely.
- * **Top-k constrained nucleus** — when both ``top_k`` and ``top_p`` are
-   used we sort the top-k slice only (not the full 200K vocabulary).
- * **Streaming output** — tokens are decoded incrementally so the first
-   bytes appear immediately.
-
- Usage::
-
-     python inference.py --checkpoint chimera_output/final/model.pt \\
-         --prompt "Once upon a time" --max_tokens 200
- """
-
+ """Chimera 5.2 — CPU-first inference / text generation."""
  from __future__ import annotations

  import argparse
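
The first bullet of the removed docstring is the main CPU win. The sketch below uses a toy `step` function (not Chimera's real forward signature) just to show the structural difference between re-feeding the whole prefix each iteration and feeding one token plus the cached per-layer state.

```python
import torch

def step(token_ids, cache=None):
    """Toy stand-in for one forward pass: returns a fake next token and an updated cache."""
    new_cache = ([] if cache is None else cache) + [token_ids]
    return torch.randint(0, 100, (1,)), new_cache

prompt = torch.tensor([1, 2, 3])

# No cache: every iteration re-encodes the whole prefix, so total work grows as O(T^2).
seq = prompt.clone()
for _ in range(5):
    nxt, _ = step(seq)             # input length grows each step
    seq = torch.cat([seq, nxt])

# With a cache: encode the prompt once, then feed a single token per step, O(T) total.
nxt, cache = step(prompt)
generated = [int(nxt)]
for _ in range(4):
    nxt, cache = step(nxt, cache)  # constant-size input per step
    generated.append(int(nxt))
```
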
@@ -41,11 +20,9 @@ def _setup_cpu_runtime() -> None:

  _setup_cpu_runtime()

-
  import torch
  import torch.nn.functional as F

-
  try:
      torch.set_num_threads(int(os.environ.get("OMP_NUM_THREADS", os.cpu_count() or 4)))
      torch.set_num_interop_threads(int(os.environ.get("CHIMERA_INTEROP_THREADS", "1")))
@@ -57,9 +34,13 @@ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
  from chimera import Chimera51ForCausalLM, ChimeraTokenizer


- # ---------------------------------------------------------------------------
- # Checkpoint loading
- # ---------------------------------------------------------------------------
+ def _infer_dim(state, keys, idx):
+     for k in keys:
+         for sk, t in state.items():
+             if sk.endswith(k):
+                 return int(t.shape[idx])
+     return None
+

  def load_model(checkpoint_path: str, device: str = "cpu"):
      print(f"[LOAD] Checkpoint: {checkpoint_path}")
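
The `_infer_dim` helper added in this hunk scans the state dict for the first key ending in one of the given suffixes and returns one dimension of that tensor. With a toy state dict (shapes are made up for illustration) it behaves like this:

```python
import torch

# Hypothetical checkpoint entries; only the shapes matter here.
state = {
    "model.embed.weight": torch.empty(200_000, 2048),     # [vocab_size, hidden_size]
    "model.lm_head.weight": torch.empty(200_000, 2048),
}

vocab_size = _infer_dim(state, ["embed.weight", "lm_head.weight"], 0)   # -> 200000
hidden_size = _infer_dim(state, ["embed.weight", "lm_head.weight"], 1)  # -> 2048
missing = _infer_dim(state, ["ffn.gate_proj.weight"], 0)                # -> None (no match)
```
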
@@ -77,38 +58,45 @@ def load_model(checkpoint_path: str, device: str = "cpu"):
      else:
          print("[LOAD] Config from checkpoint")

-     model = Chimera51ForCausalLM(config)
-     counts = model.count_parameters()
-     print(f"[LOAD] Params: {counts['total']:,} (ternary: {counts['ternary']:,})")
-
+     # ---- reconcile structural dims from checkpoint weights BEFORE model build ----
      state = ckpt.get("model", ckpt)

-     # Reconcile vocab mismatches in either direction without crashing.
-     model_vocab = int(config.get("vocab_size", model.embed.num_embeddings))
-     ckpt_vocab = None
-     for key in ("embed.weight", "lm_head.weight"):
-         for sk, t in state.items():
-             if sk.endswith(key):
-                 ckpt_vocab = int(t.shape[0])
-                 break
-         if ckpt_vocab is not None:
-             break
-
-     if ckpt_vocab and ckpt_vocab != model_vocab:
-         print(f"[WARN] vocab mismatch ckpt={ckpt_vocab} cfg={model_vocab}; resizing")
-         with torch.no_grad():
-             old = model.embed.weight.data
-             new = torch.zeros(ckpt_vocab, old.shape[1], dtype=old.dtype, device=old.device)
-             new[:min(old.shape[0], ckpt_vocab)] = old[:min(old.shape[0], ckpt_vocab)]
-             model.embed = torch.nn.Embedding(ckpt_vocab, old.shape[1])
-             model.embed.weight.data = new
-             old_h = model.lm_head.weight.data
-             new_h = torch.zeros(ckpt_vocab, old_h.shape[1], dtype=old_h.dtype, device=old_h.device)
-             new_h[:min(old_h.shape[0], ckpt_vocab)] = old_h[:min(old_h.shape[0], ckpt_vocab)]
-             model.lm_head = torch.nn.Linear(old_h.shape[1], ckpt_vocab, bias=False)
-             model.lm_head.weight.data = new_h
+     ckpt_vocab = _infer_dim(state, ["embed.weight", "lm_head.weight"], 0)
+     if ckpt_vocab and ckpt_vocab != config.get("vocab_size", ckpt_vocab):
+         print(f"[WARN] vocab_size mismatch ckpt={ckpt_vocab} cfg={config.get('vocab_size')}; resizing")
          config["vocab_size"] = ckpt_vocab

+     ckpt_hidden = _infer_dim(state, ["embed.weight", "lm_head.weight"], 1)
+     if ckpt_hidden and ckpt_hidden != config.get("hidden_size", ckpt_hidden):
+         print(f"[WARN] hidden_size mismatch ckpt={ckpt_hidden} cfg={config.get('hidden_size')}; resizing")
+         config["hidden_size"] = ckpt_hidden
+
+     # head_dim from any attention q_proj (shape [num_heads*head_dim, hidden_size])
+     ckpt_q = _infer_dim(state, ["layers.0.attn.q_proj.weight", "layers.1.attn.q_proj.weight"], 0)
+     if ckpt_q and ckpt_hidden:
+         head_dim_guess = config.get("head_dim")
+         num_heads_guess = config.get("num_heads", 40)
+         if head_dim_guess and ckpt_q != num_heads_guess * head_dim_guess:
+             # mismatch — try to infer actual head_dim from q_proj / num_heads
+             for nh in [1, 2, 4, 5, 8, 10, 16, 20, 32, 40, 64]:
+                 if ckpt_q % nh == 0:
+                     inferred_hd = ckpt_q // nh
+                     if ckpt_hidden % inferred_hd == 0:
+                         config["num_heads"] = nh
+                         config["head_dim"] = inferred_hd
+                         print(f"[WARN] auto-inferred num_heads={nh}, head_dim={inferred_hd} from q_proj={ckpt_q}")
+                         break
+
+     ckpt_inter = _infer_dim(state, ["layers.0.ffn.gate_proj.weight", "layers.1.ffn.gate_proj.weight"], 0)
+     if ckpt_inter and ckpt_inter != config.get("intermediate_size", ckpt_inter):
+         print(f"[WARN] intermediate_size mismatch ckpt={ckpt_inter} cfg={config.get('intermediate_size')}; resizing")
+         config["intermediate_size"] = ckpt_inter
+     # ---------------------------------------------------------------------------
+
+     model = Chimera51ForCausalLM(config)
+     counts = model.count_parameters()
+     print(f"[LOAD] Params: {counts['total']:,} (ternary: {counts['ternary']:,})")
+
      missing, unexpected = model.load_state_dict(state, strict=False)
      if missing:
          print(f"[WARN] Missing keys ({len(missing)}): {missing[:5]}...")
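
The point of reconciling dimensions before `Chimera51ForCausalLM(config)` is built is that `load_state_dict(strict=False)` tolerates missing and unexpected keys but still raises on shape mismatches, so the config has to match the checkpoint before construction. A minimal stand-alone illustration with a plain `nn.Linear` (not the Chimera model):

```python
import torch

ckpt = {"weight": torch.randn(8, 4)}          # saved from a layer with out_features=8

wrong = torch.nn.Linear(4, 6, bias=False)     # built from a stale config
try:
    wrong.load_state_dict(ckpt, strict=False) # strict=False does not forgive size mismatches
except RuntimeError as err:
    print("size mismatch:", err)

right = torch.nn.Linear(4, ckpt["weight"].shape[0], bias=False)  # fix dims first, then load
right.load_state_dict(ckpt, strict=False)
```
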
@@ -116,7 +104,7 @@ def load_model(checkpoint_path: str, device: str = "cpu"):
      print(f"[WARN] Unexpected keys ({len(unexpected)}): {unexpected[:5]}...")

      model.to(device).eval()
-     model.prepare_for_inference()  # pre-pack ternary weights
+     model.prepare_for_inference()

      step = ckpt.get("step", "?")
      best_loss = ckpt.get("best_loss")
@@ -127,22 +115,13 @@ def load_model(checkpoint_path: str, device: str = "cpu"):
      return model, config


- # ---------------------------------------------------------------------------
- # Sampling helpers
- # ---------------------------------------------------------------------------
-
  def _sample_next(logits: torch.Tensor, temperature: float, top_p: float, top_k: int
                   ) -> int:
-     """Return the next token id sampled from ``logits`` ([1, V] or [V])."""
      if logits.dim() == 1:
          logits = logits.unsqueeze(0)
-
-     # Greedy fast path.
      if temperature <= 0.0:
          return int(torch.argmax(logits, dim=-1).item())
-
      logits = logits / temperature
-
      if top_k and top_k > 0:
          k = min(top_k, logits.size(-1))
          cand_logits, cand_indices = torch.topk(logits, k, dim=-1)
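
The greedy fast path and the temperature division visible in this hunk are easy to check in isolation with toy logits:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5]])

# temperature == 0: pure argmax, no softmax or sort at all.
greedy_id = int(torch.argmax(logits, dim=-1).item())   # -> 0

# temperature > 0: dividing by T reshapes the distribution before sampling.
print(F.softmax(logits / 0.5, dim=-1))   # sharper than softmax(logits)
print(F.softmax(logits / 1.5, dim=-1))   # flatter than softmax(logits)
```
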
@@ -157,7 +136,6 @@ def _sample_next(logits: torch.Tensor, temperature: float, top_p: float, top_k:
              return int(sorted_indices.gather(-1, torch.multinomial(probs, 1)).item())
          probs = F.softmax(cand_logits, dim=-1)
          return int(cand_indices.gather(-1, torch.multinomial(probs, 1)).item())
-
      if top_p < 1.0:
          sorted_logits, sorted_indices = torch.sort(logits, descending=True)
          cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
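
For the nucleus (top-p) branch, the cumulative-probability cutoff is easiest to see with small numbers. The exact masking rule is not visible in this hunk, so the sketch below uses one common formulation (drop everything after the cumulative mass first exceeds `top_p`; real implementations shift the mask so at least one token always survives):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[3.0, 2.0, 1.0, 0.0]])
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
probs = F.softmax(sorted_logits, dim=-1)      # ~[0.64, 0.24, 0.09, 0.03]
cum_probs = torch.cumsum(probs, dim=-1)       # ~[0.64, 0.88, 0.97, 1.00]

remove = cum_probs > 0.9                      # -> [False, False, True, True]
filtered = sorted_logits.masked_fill(remove, float("-inf"))
next_id = sorted_indices.gather(-1, torch.multinomial(F.softmax(filtered, dim=-1), 1))
```
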
@@ -166,15 +144,10 @@ def _sample_next(logits: torch.Tensor, temperature: float, top_p: float, top_k:
          sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
          probs = F.softmax(sorted_logits, dim=-1)
          return int(sorted_indices.gather(-1, torch.multinomial(probs, 1)).item())
-
      probs = F.softmax(logits, dim=-1)
      return int(torch.multinomial(probs, 1).item())


- # ---------------------------------------------------------------------------
- # Generation loop
- # ---------------------------------------------------------------------------
-
  def generate(model: Chimera51ForCausalLM, tokenizer: ChimeraTokenizer,
               prompt: str, max_tokens: int = 100, temperature: float = 0.8,
               top_p: float = 0.9, top_k: int = 50, device: str = "cpu",
@@ -201,7 +174,6 @@ def generate(model: Chimera51ForCausalLM, tokenizer: ChimeraTokenizer,

      t0 = time.time()
      with torch.inference_mode(), autocast_ctx:
-         # Initial pass: feed the whole prompt and capture per-layer caches.
          out = model(input_ids, use_cache=True, logits_to_keep=1)
          caches = out.caches
          next_token = _sample_next(out.logits[:, -1, :].float(), temperature, top_p, top_k)
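
`logits_to_keep=1` on the prompt pass presumably restricts the output projection to the last position, which is all the sampler needs. The saving is the difference between projecting every prompt position and projecting one; the shapes below are made up:

```python
import torch

hidden = torch.randn(1, 128, 256)             # [batch, prompt_len, hidden_size], toy sizes
lm_head = torch.nn.Linear(256, 5000, bias=False)

full_logits = lm_head(hidden)                 # [1, 128, 5000]: wasted work when only sampling
last_logits = lm_head(hidden[:, -1:, :])      # [1, 1, 5000]: enough to pick the next token
```
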
@@ -218,7 +190,6 @@ def generate(model: Chimera51ForCausalLM, tokenizer: ChimeraTokenizer,
                  break
              generated.append(next_token)
              if stream:
-                 # Try to render only the newly produced text.
                  full = tokenizer.decode(generated, skip_special_tokens=False)
                  if full.startswith(decoded_so_far):
                      sys.stdout.write(full[len(decoded_so_far):])
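
The streaming branch re-decodes the whole `generated` list each step and prints only the suffix that has not been shown yet, since token merges can change the trailing characters between steps. The diffing pattern on its own, with canned strings standing in for `tokenizer.decode`:

```python
import sys

decoded_so_far = ""
snapshots = ["Once", "Once upon", "Once upon a time"]   # stand-ins for tokenizer.decode(generated)
for full in snapshots:
    if full.startswith(decoded_so_far):
        sys.stdout.write(full[len(decoded_so_far):])    # emit only the newly produced text
        sys.stdout.flush()
        decoded_so_far = full
```
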
@@ -241,15 +212,10 @@ def generate(model: Chimera51ForCausalLM, tokenizer: ChimeraTokenizer,
  class _nullctx:
      def __enter__(self):
          return self
-
      def __exit__(self, *args):
          return False


- # ---------------------------------------------------------------------------
- # CLI
- # ---------------------------------------------------------------------------
-
  def main() -> None:
      p = argparse.ArgumentParser(description="Chimera 5.2 CPU inference")
      p.add_argument("--checkpoint", default="chimera_output/final/model.pt")
@@ -286,8 +252,7 @@ def main() -> None:

      print("[WARM] Warmup forward...")
      with torch.inference_mode():
-         _ = model(torch.tensor([[tokenizer.eos_token_id]], device=args.device),
-                   logits_to_keep=1)
+         _ = model(torch.tensor([[tokenizer.eos_token_id]], device=args.device), logits_to_keep=1)
      print("[WARM] Done.")

      generate(
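
The warmup forward with a single EOS token pays one-time costs up front (thread-pool spin-up, allocator growth and, per the removed comment, ternary weight pre-packing) so the first real token is not slow. A generic way to observe the effect with a plain module; exact timings vary by machine:

```python
import time
import torch

layer = torch.nn.Linear(2048, 2048)
x = torch.randn(1, 2048)

with torch.inference_mode():
    t0 = time.time(); layer(x); t1 = time.time()   # first call: includes one-time setup
    t2 = time.time(); layer(x); t3 = time.time()   # steady-state cost
print(f"first call: {t1 - t0:.4f}s, second call: {t3 - t2:.4f}s")
```
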