Lgr54HFi commited on
Commit
bc0ec84
·
verified ·
1 Parent(s): 0a7fd59

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +87 -43
inference.py CHANGED
@@ -1,5 +1,8 @@
1
  #!/usr/bin/env python3
2
- """Chimera 5.2 — CPU-first inference / text generation."""
 
 
 
3
  from __future__ import annotations
4
 
5
  import argparse
@@ -7,6 +10,7 @@ import json
7
  import os
8
  import sys
9
  import time
 
10
 
11
 
12
  def _setup_cpu_runtime() -> None:
@@ -34,13 +38,36 @@ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
34
  from chimera import Chimera51ForCausalLM, ChimeraTokenizer
35
 
36
 
37
- def _infer_dim(state, keys, idx):
38
- for k in keys:
39
- for sk, t in state.items():
40
- if sk.endswith(k):
41
- return int(t.shape[idx])
42
- return None
 
 
 
 
 
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  def load_model(checkpoint_path: str, device: str = "cpu"):
46
  print(f"[LOAD] Checkpoint: {checkpoint_path}")
@@ -58,46 +85,51 @@ def load_model(checkpoint_path: str, device: str = "cpu"):
58
  else:
59
  print("[LOAD] Config from checkpoint")
60
 
61
- # ---- reconcile structural dims from checkpoint weights BEFORE model build ----
62
- state = ckpt.get("model", ckpt)
63
-
64
- ckpt_vocab = _infer_dim(state, ["embed.weight", "lm_head.weight"], 0)
65
- if ckpt_vocab and ckpt_vocab != config.get("vocab_size", ckpt_vocab):
66
- print(f"[WARN] vocab_size mismatch ckpt={ckpt_vocab} cfg={config.get('vocab_size')}; resizing")
67
- config["vocab_size"] = ckpt_vocab
68
-
69
- ckpt_hidden = _infer_dim(state, ["embed.weight", "lm_head.weight"], 1)
70
- if ckpt_hidden and ckpt_hidden != config.get("hidden_size", ckpt_hidden):
71
- print(f"[WARN] hidden_size mismatch ckpt={ckpt_hidden} cfg={config.get('hidden_size')}; resizing")
72
- config["hidden_size"] = ckpt_hidden
73
-
74
- # head_dim from any attention q_proj (shape [num_heads*head_dim, hidden_size])
75
- ckpt_q = _infer_dim(state, ["layers.0.attn.q_proj.weight", "layers.1.attn.q_proj.weight"], 0)
76
- if ckpt_q and ckpt_hidden:
77
- head_dim_guess = config.get("head_dim")
78
- num_heads_guess = config.get("num_heads", 40)
79
- if head_dim_guess and ckpt_q != num_heads_guess * head_dim_guess:
80
- # mismatch — try to infer actual head_dim from q_proj / num_heads
81
- for nh in [1, 2, 4, 5, 8, 10, 16, 20, 32, 40, 64]:
82
- if ckpt_q % nh == 0:
83
- inferred_hd = ckpt_q // nh
84
- if ckpt_hidden % inferred_hd == 0:
85
- config["num_heads"] = nh
86
- config["head_dim"] = inferred_hd
87
- print(f"[WARN] auto-inferred num_heads={nh}, head_dim={inferred_hd} from q_proj={ckpt_q}")
88
- break
89
-
90
- ckpt_inter = _infer_dim(state, ["layers.0.ffn.gate_proj.weight", "layers.1.ffn.gate_proj.weight"], 0)
91
- if ckpt_inter and ckpt_inter != config.get("intermediate_size", ckpt_inter):
92
- print(f"[WARN] intermediate_size mismatch ckpt={ckpt_inter} cfg={config.get('intermediate_size')}; resizing")
93
- config["intermediate_size"] = ckpt_inter
94
- # ---------------------------------------------------------------------------
95
-
96
  model = Chimera51ForCausalLM(config)
97
  counts = model.count_parameters()
98
  print(f"[LOAD] Params: {counts['total']:,} (ternary: {counts['ternary']:,})")
99
 
100
- missing, unexpected = model.load_state_dict(state, strict=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  if missing:
102
  print(f"[WARN] Missing keys ({len(missing)}): {missing[:5]}...")
103
  if unexpected:
@@ -115,6 +147,10 @@ def load_model(checkpoint_path: str, device: str = "cpu"):
115
  return model, config
116
 
117
 
 
 
 
 
118
  def _sample_next(logits: torch.Tensor, temperature: float, top_p: float, top_k: int
119
  ) -> int:
120
  if logits.dim() == 1:
@@ -148,6 +184,10 @@ def _sample_next(logits: torch.Tensor, temperature: float, top_p: float, top_k:
148
  return int(torch.multinomial(probs, 1).item())
149
 
150
 
 
 
 
 
151
  def generate(model: Chimera51ForCausalLM, tokenizer: ChimeraTokenizer,
152
  prompt: str, max_tokens: int = 100, temperature: float = 0.8,
153
  top_p: float = 0.9, top_k: int = 50, device: str = "cpu",
@@ -216,6 +256,10 @@ class _nullctx:
216
  return False
217
 
218
 
 
 
 
 
219
  def main() -> None:
220
  p = argparse.ArgumentParser(description="Chimera 5.2 CPU inference")
221
  p.add_argument("--checkpoint", default="chimera_output/final/model.pt")
 
1
  #!/usr/bin/env python3
2
+ """Chimera 5.2 — CPU-first inference / text generation.
3
+
4
+ Config is source of truth. Checkpoint weights are resized to match the model.
5
+ """
6
  from __future__ import annotations
7
 
8
  import argparse
 
10
  import os
11
  import sys
12
  import time
13
+ from typing import Dict, Tuple
14
 
15
 
16
  def _setup_cpu_runtime() -> None:
 
38
  from chimera import Chimera51ForCausalLM, ChimeraTokenizer
39
 
40
 
41
+ # ---------------------------------------------------------------------------
42
+ # Resize helpers: checkpoint weights -> model architecture (config is truth)
43
+ # ---------------------------------------------------------------------------
44
+
45
+ @torch.no_grad()
46
+ def _resize_1d(w: torch.Tensor, target: int) -> torch.Tensor:
47
+ out = torch.ones(target, dtype=w.dtype, device=w.device)
48
+ n = min(w.numel(), target)
49
+ out[:n] = w[:n]
50
+ return out
51
+
52
 
53
+ @torch.no_grad()
54
+ def _resize_2d(w: torch.Tensor, target_shape: Tuple[int, int]) -> torch.Tensor:
55
+ to, ti = target_shape
56
+ so, si = w.shape
57
+ if (so, si) == (to, ti):
58
+ return w
59
+ out = torch.empty((to, ti), dtype=w.dtype, device=w.device)
60
+ std = float(w.std(unbiased=False).item()) if w.numel() > 1 else 0.02
61
+ std = max(min(std, 0.2), 1e-4)
62
+ out.normal_(mean=0.0, std=std)
63
+ ro, ci = min(so, to), min(si, ti)
64
+ out[:ro, :ci] = w[:ro, :ci]
65
+ return out
66
+
67
+
68
+ # ---------------------------------------------------------------------------
69
+ # Checkpoint loading
70
+ # ---------------------------------------------------------------------------
71
 
72
  def load_model(checkpoint_path: str, device: str = "cpu"):
73
  print(f"[LOAD] Checkpoint: {checkpoint_path}")
 
85
  else:
86
  print("[LOAD] Config from checkpoint")
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  model = Chimera51ForCausalLM(config)
89
  counts = model.count_parameters()
90
  print(f"[LOAD] Params: {counts['total']:,} (ternary: {counts['ternary']:,})")
91
 
92
+ state = ckpt.get("model", ckpt)
93
+ model_state = model.state_dict()
94
+
95
+ # Config is source of truth: resize checkpoint tensors to match model.
96
+ resized: Dict[str, torch.Tensor] = {}
97
+ for k, v in state.items():
98
+ if k in model_state:
99
+ expected = model_state[k].shape
100
+ if v.shape != expected:
101
+ print(f"[WARN] resizing {k}: {tuple(v.shape)} -> {tuple(expected)}")
102
+ if v.ndim == 1:
103
+ v = _resize_1d(v, expected[0])
104
+ elif v.ndim == 2:
105
+ v = _resize_2d(v, expected)
106
+ else:
107
+ print(f"[SKIP] {k}: cannot resize {v.ndim}D tensor")
108
+ continue
109
+ resized[k] = v
110
+ else:
111
+ resized[k] = v
112
+
113
+ # Vocab reconciliation: if vocab mismatch, re-init embed + lm_head.
114
+ model_vocab = int(config.get("vocab_size", model.embed.num_embeddings))
115
+ if "embed.weight" in resized:
116
+ ckpt_vocab = int(resized["embed.weight"].shape[0])
117
+ if ckpt_vocab != model_vocab:
118
+ print(f"[WARN] vocab mismatch ckpt={ckpt_vocab} cfg={model_vocab}; re-init embed+head")
119
+ with torch.no_grad():
120
+ old = model.embed.weight.data
121
+ new = torch.zeros(ckpt_vocab, old.shape[1], dtype=old.dtype, device=old.device)
122
+ new[:min(old.shape[0], ckpt_vocab)] = old[:min(old.shape[0], ckpt_vocab)]
123
+ model.embed = torch.nn.Embedding(ckpt_vocab, old.shape[1])
124
+ model.embed.weight.data = new
125
+ old_h = model.lm_head.weight.data
126
+ new_h = torch.zeros(ckpt_vocab, old_h.shape[1], dtype=old_h.dtype, device=old_h.device)
127
+ new_h[:min(old_h.shape[0], ckpt_vocab)] = old_h[:min(old_h.shape[0], ckpt_vocab)]
128
+ model.lm_head = torch.nn.Linear(old_h.shape[1], ckpt_vocab, bias=False)
129
+ model.lm_head.weight.data = new_h
130
+ config["vocab_size"] = ckpt_vocab
131
+
132
+ missing, unexpected = model.load_state_dict(resized, strict=False)
133
  if missing:
134
  print(f"[WARN] Missing keys ({len(missing)}): {missing[:5]}...")
135
  if unexpected:
 
147
  return model, config
148
 
149
 
150
+ # ---------------------------------------------------------------------------
151
+ # Sampling helpers
152
+ # ---------------------------------------------------------------------------
153
+
154
  def _sample_next(logits: torch.Tensor, temperature: float, top_p: float, top_k: int
155
  ) -> int:
156
  if logits.dim() == 1:
 
184
  return int(torch.multinomial(probs, 1).item())
185
 
186
 
187
+ # ---------------------------------------------------------------------------
188
+ # Generation loop
189
+ # ---------------------------------------------------------------------------
190
+
191
  def generate(model: Chimera51ForCausalLM, tokenizer: ChimeraTokenizer,
192
  prompt: str, max_tokens: int = 100, temperature: float = 0.8,
193
  top_p: float = 0.9, top_k: int = 50, device: str = "cpu",
 
256
  return False
257
 
258
 
259
+ # ---------------------------------------------------------------------------
260
+ # CLI
261
+ # ---------------------------------------------------------------------------
262
+
263
  def main() -> None:
264
  p = argparse.ArgumentParser(description="Chimera 5.2 CPU inference")
265
  p.add_argument("--checkpoint", default="chimera_output/final/model.pt")