Lgr54HFi committed
Commit dda344d · verified · 1 parent: 6d5c935

Skip SpanEngine/Grammar/DebtLedger during training (inference-only ops on 200K logits)
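In short: the SpanInferenceEngine pass now also requires `not self.training`, and the GrammarFST/DebtLedger passes move behind the same guard, so a training step never spends extra element-wise work on the [batch, seq, ~200K-vocab] logits tensor it would immediately discard. A minimal self-contained sketch of the pattern this commit applies (the `_inference_only_ops` stand-in below is illustrative, not this repo's actual GrammarFST/DebtLedger code):

# Sketch only: gate expensive vocab-wide post-processing on nn.Module's
# built-in `training` flag, which is exactly the guard this commit adds.
import torch
import torch.nn as nn

class LogitsPostProcessDemo(nn.Module):
    def __init__(self, hidden=8, vocab=200_000):
        super().__init__()
        self.lm_head = nn.Linear(hidden, vocab, bias=False)

    def _inference_only_ops(self, logits):
        # Stand-in for grammar masking / debt-ledger adjustments: any extra
        # full pass over a [batch, seq, vocab] tensor that training ignores.
        return logits.log_softmax(dim=-1)

    def forward(self, x):
        logits = self.lm_head(x)
        if not self.training:  # model.train() / model.eval() flips this flag
            logits = self._inference_only_ops(logits)
        return logits

demo = LogitsPostProcessDemo()
x = torch.randn(1, 4, 8)
demo.train(); raw = demo(x)      # training: raw logits only
demo.eval(); decoded = demo(x)   # inference: extra vocab-wide pass runs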

Files changed (1)
1. chimera/model.py  +37 -109
chimera/model.py CHANGED
@@ -32,16 +32,10 @@ from .multimodal import VisionEncoder, AudioEncoder
 
 
 class CausalLMOutput(dict):
-    """Light HF-compatible output dict supporting tuple unpacking."""
-
-    def __init__(self, loss: Optional[torch.Tensor] = None,
-                 logits: Optional[torch.Tensor] = None,
-                 hidden_states: Optional[torch.Tensor] = None,
-                 caches: Optional[list] = None,
-                 evolution_metrics: Optional[dict] = None):
-        super().__init__(loss=loss, logits=logits,
-                         hidden_states=hidden_states, caches=caches,
-                         evolution_metrics=evolution_metrics)
+    def __init__(self, loss=None, logits=None, hidden_states=None,
+                 caches=None, evolution_metrics=None):
+        super().__init__(loss=loss, logits=logits, hidden_states=hidden_states,
+                         caches=caches, evolution_metrics=evolution_metrics)
         self.loss = loss
         self.logits = logits
         self.hidden_states = hidden_states
@@ -53,8 +47,7 @@ class CausalLMOutput(dict):
         yield self.logits
 
 
-def expand_layer_pattern(config: dict) -> List[str]:
-    """Expand the layer-pattern shorthand into a list."""
+def expand_layer_pattern(config):
     backbone = config.get("backbone", {})
     pattern_str = backbone.get("layer_pattern", "GD XM GD TM GD XM GD SK")
     aliases = backbone.get("layer_aliases", {
@@ -68,12 +61,9 @@ def expand_layer_pattern(config: dict) -> List[str]:
 
 
 class Chimera51Block(nn.Module):
-    """One block with evolution-aware forward."""
-
     _RECURRENT = {"gated_deltanet", "xlstm_m", "titans_mac", "tsp_span_knot"}
 
-    def __init__(self, config: dict, layer_type: str, layer_idx: int,
-                 use_moe: bool = False):
+    def __init__(self, config, layer_type, layer_idx, use_moe=False):
         super().__init__()
         h = int(config["hidden_size"])
         eps = float(config.get("rms_norm_eps", 1e-6))
@@ -87,22 +77,18 @@ class Chimera51Block(nn.Module):
         self.attn_norm = RMSNorm(h, eps=eps)
 
         if layer_type == "gated_deltanet":
-            self.attn = GatedDeltaNetLayer(h, heads, head_dim, norm_eps=eps,
-                                           chunk_size=chunk_sz, use_ternary=ternary)
+            self.attn = GatedDeltaNetLayer(h, heads, head_dim, norm_eps=eps, chunk_size=chunk_sz, use_ternary=ternary)
         elif layer_type == "xlstm_m":
             mem_h = config.get("xlstm", {}).get("memory_size_per_head", [head_dim, head_dim])
-            self.attn = MLSTMLayer(h, heads, int(mem_h[0]), norm_eps=eps,
-                                   use_ternary=ternary)
+            self.attn = MLSTMLayer(h, heads, int(mem_h[0]), norm_eps=eps, use_ternary=ternary)
         elif layer_type == "titans_mac":
             tc = config.get("titans", {})
-            self.attn = TitansMACLayer(h, heads, head_dim,
-                                       memory_depth=int(tc.get("memory_depth", 2)),
+            self.attn = TitansMACLayer(h, heads, head_dim, memory_depth=int(tc.get("memory_depth", 2)),
                                        persistent_slots=int(tc.get("persistent_memory_slots", 64)),
                                        local_window=int(tc.get("local_window_size", 1024)),
                                        norm_eps=eps, use_ternary=ternary)
         elif layer_type == "tsp_span_knot":
-            self.attn = TSPSpanKnotLayer(h, heads, head_dim, norm_eps=eps,
-                                         chunk_size=chunk_sz, use_ternary=ternary)
+            self.attn = TSPSpanKnotLayer(h, heads, head_dim, norm_eps=eps, chunk_size=chunk_sz, use_ternary=ternary)
         else:
             raise ValueError(f"Unknown layer type: {layer_type}")
 
@@ -110,45 +96,33 @@ class Chimera51Block(nn.Module):
         self.use_moe = bool(use_moe)
         if self.use_moe:
             moe_cfg = config.get("backbone", {}).get("moe", {})
-            self.mlp = MoELayer(
-                hidden_size=h,
+            self.mlp = MoELayer(hidden_size=h,
                 moe_intermediate_size=int(moe_cfg.get("moe_intermediate_size", h * 2)),
                 n_routed_experts=int(moe_cfg.get("n_routed_experts", 16)),
                 n_shared_experts=int(moe_cfg.get("n_shared_experts", 1)),
                 num_experts_per_tok=int(moe_cfg.get("num_experts_per_tok", 2)),
-                use_ternary=ternary,
-            )
+                use_ternary=ternary)
         else:
             inter = int(config.get("intermediate_size", int(h * 8 / 3)))
             inter = 256 * ((inter + 255) // 256)
             self.mlp = SwiGLUMLP(h, inter, use_ternary=ternary)
 
-        # Evolution modulation projection (learnable scale)
         self.evo_gate = nn.Linear(h, h, bias=False)
         nn.init.zeros_(self.evo_gate.weight)
 
-    def forward(self, x: torch.Tensor, cache: Optional[dict] = None,
-                evo_modulation: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, dict]:
-        # Apply attention with pre-norm
+    def forward(self, x, cache=None, evo_modulation=None):
        normed = self.attn_norm(x)
         attn_out, new_cache = self.attn(normed, cache=cache)
         x = x + attn_out
-
-        # Apply MLP with pre-norm
         x = x + self.mlp(self.mlp_norm(x))
-
-        # Apply evolution modulation (gated residual)
         if evo_modulation is not None:
             gate = torch.sigmoid(self.evo_gate(x))
             x = x + gate * evo_modulation
-
         return x, new_cache
 
 
 class Chimera51ForCausalLM(nn.Module):
-    """Chimera 5.x causal language model with functional self-evolution."""
-
-    def __init__(self, config: dict):
+    def __init__(self, config):
         super().__init__()
         self.config = config
         h = int(config["hidden_size"])
@@ -159,19 +133,15 @@ class Chimera51ForCausalLM(nn.Module):
         self.embed = nn.Embedding(vocab, h)
         layer_types = expand_layer_pattern(config)
         moe_layers = set(int(i) for i in config.get("backbone", {}).get("moe", {}).get("layers", []))
-
         self.layers = nn.ModuleList([
             Chimera51Block(config, layer_types[i], i, use_moe=(i in moe_layers))
             for i in range(n_layers)
         ])
-
         self.norm = RMSNorm(h, eps=eps)
         self.lm_head = nn.Linear(h, vocab, bias=False)
-
         if config.get("tie_word_embeddings", True):
             self.lm_head.weight = self.embed.weight
 
-        # Parcae looping controller
         loop_cfg = config.get("looping", {})
         self.looping_enabled = bool(loop_cfg.get("enabled", True)) and n_layers >= 3
         if self.looping_enabled:
@@ -181,24 +151,20 @@ class Chimera51ForCausalLM(nn.Module):
             self.loop_controller = ParcaeLoopController(
                 h, loop_range=tuple(loop_cfg.get("loop_range", [1, 6])),
                 loop_default=int(loop_cfg.get("loop_default", 2)),
-                adaptive_exit_threshold=float(loop_cfg.get("adaptive_exit_threshold", 0.01)),
-            )
+                adaptive_exit_threshold=float(loop_cfg.get("adaptive_exit_threshold", 0.01)))
 
-        # Inference systems
         si_cfg = config.get("span_inference", {})
         self.span_engine = SpanInferenceEngine(h, si_cfg) if si_cfg.get("enabled", True) else None
         self.grammar = GrammarFST(config.get("grammar", {}))
         self.entropy_valve = EntropyValve(config.get("entropy_valve", {}))
         self.debt_ledger = DebtLedger(config.get("debt_ledger", {}))
 
-        # Self-evolution — FUNCTIONAL
         evo_cfg = dict(config.get("self_evolution", {}))
         evo_cfg["_semantic_memory_config"] = config.get("semantic_memory", {})
         self.evolution = SelfEvolutionEngine(evo_cfg, h)
         self.evo_weight = float(config.get("evolution_loss_weight", 0.01))
         self.evo_every_n_layers = int(config.get("evolution_every_n_layers", 4))
 
-        # Multimodal
         mm_cfg = dict(config.get("multimodal", {}))
         mm_cfg["hidden_size"] = h
         if mm_cfg.get("enabled", False):
@@ -212,19 +178,19 @@ class Chimera51ForCausalLM(nn.Module):
         self._init_weights()
         self._wire_semantic_memory()
 
-    def enable_gradient_checkpointing(self) -> None:
+    def enable_gradient_checkpointing(self):
         self.gradient_checkpointing = True
 
-    def disable_gradient_checkpointing(self) -> None:
+    def disable_gradient_checkpointing(self):
         self.gradient_checkpointing = False
 
-    def _wire_semantic_memory(self) -> None:
+    def _wire_semantic_memory(self):
         mem = self.evolution.semantic_memory
         for layer in self.layers:
             if hasattr(layer.attn, "set_semantic_memory"):
                 layer.attn.set_semantic_memory(mem)
 
-    def _init_weights(self) -> None:
+    def _init_weights(self):
         init_range = float(self.config.get("initializer_range", 0.006))
         for module in self.modules():
             if isinstance(module, (nn.Linear, BitLinear)):
@@ -238,11 +204,7 @@ class Chimera51ForCausalLM(nn.Module):
             if isinstance(module, BitLinear):
                 module.invalidate_packed()
 
-    def _run_layers(self, x: torch.Tensor, start: int, end: int,
-                    caches: Optional[list],
-                    compute_logits: bool = False,
-                    labels: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor], list]:
-        """Run layers with evolution hooks. Returns (x, logits_if_computed, caches)."""
+    def _run_layers(self, x, start, end, caches, compute_logits=False, labels=None):
         all_metrics = []
         logits = None
         evolution_loss = torch.tensor(0.0, device=x.device)
@@ -250,21 +212,15 @@ class Chimera51ForCausalLM(nn.Module):
         for i in range(start, min(end + 1, len(self.layers))):
             layer = self.layers[i]
             cache = caches[i] if caches is not None else None
-
-            # Evolution modulation every N layers (lightweight)
             evo_mod = None
             if i % self.evo_every_n_layers == 0 and self.evolution is not None:
                 evo_result = self.evolution(
                     hidden_states=x.detach() if not x.requires_grad else x,
-                    layer_idx=i,
-                    loss=None
-                )
+                    layer_idx=i, loss=None)
                 evo_mod = evo_result['modulation']
                 if evo_result['evolution_loss'] is not None:
                     evolution_loss = evolution_loss + evo_result['evolution_loss']
                 all_metrics.append(evo_result.get('metrics', {}))
-
-                # TTT update for target layers (only in training, no backprop)
                 if self.training and evo_result.get('ttt_delta') is not None:
                     with torch.no_grad():
                         if hasattr(layer.mlp, 'w_down'):
@@ -280,24 +236,16 @@ class Chimera51ForCausalLM(nn.Module):
             if caches is not None:
                 caches[i] = new_cache
 
-            # Compute probe logits for entropy valve (every few layers)
             if compute_logits and i == end:
                 logits = self.lm_head(self.norm(x[:, -1:, :]))
 
         return x, logits, caches, evolution_loss, all_metrics
 
-    def forward(self, input_ids: torch.Tensor,
-                labels: Optional[torch.Tensor] = None,
-                pixel_values: Optional[torch.Tensor] = None,
-                mel_features: Optional[torch.Tensor] = None,
-                num_loops: Optional[int] = None,
-                caches: Optional[list] = None,
-                use_cache: bool = False,
-                logits_to_keep: int = 0,
-                return_evolution_metrics: bool = False):
+    def forward(self, input_ids, labels=None, pixel_values=None,
+                mel_features=None, num_loops=None, caches=None,
+                use_cache=False, logits_to_keep=0, return_evolution_metrics=False):
         x = self.embed(input_ids)
 
-        # Multimodal prepend
         if pixel_values is not None and self.vision_encoder is not None:
             v = self.vision_encoder(pixel_values)
             if v is not None:
@@ -313,31 +261,23 @@ class Chimera51ForCausalLM(nn.Module):
         total_evo_loss = torch.tensor(0.0, device=x.device)
         all_evo_metrics = []
 
-        # Prelude + Loop + Coda with evolution
         if self.looping_enabled and hasattr(self, "loop_controller"):
-            # Prelude
             x, probe_logits, caches, evo_loss, metrics = self._run_layers(
                 x, self.prelude_start, self.prelude_end, caches,
                 compute_logits=not self.training, labels=labels)
             total_evo_loss = total_evo_loss + evo_loss
             all_evo_metrics.extend(metrics)
 
-            # Determine loop depth
             effective = num_loops
             if effective is None and not self.training and probe_logits is not None:
                 effective = self.entropy_valve.get_loop_count(probe_logits)
             elif effective is None:
-                # FIX: During training, use the loop_controller.loop_default directly
-                # instead of running the loop classifier (which calls .item() and is
-                # expensive). The ProgressiveLoopScheduler already sets loop_default.
                 effective = self.loop_controller.loop_default
 
-            # Loop body
             loop_fn = lambda inp: self._run_layers(
                 inp, self.loop_start, self.loop_end, caches, labels=labels)[0]
             x = self.loop_controller(x, loop_fn, num_loops=effective)
 
-            # Coda
             x, _, caches, evo_loss, metrics = self._run_layers(
                 x, self.coda_start, self.coda_end, caches, labels=labels)
             total_evo_loss = total_evo_loss + evo_loss
@@ -349,30 +289,29 @@ class Chimera51ForCausalLM(nn.Module):
             total_evo_loss = total_evo_loss + evo_loss
             all_evo_metrics.extend(metrics)
 
-        # Final norm and logits
         if logits_to_keep and labels is None:
             keep = int(logits_to_keep)
             tail = x[:, -keep:, :]
             tail = self.norm(tail)
-            if self.span_engine is not None:
+            if self.span_engine is not None and not self.training:
                 tail = self.span_engine(tail)
             logits = self.lm_head(tail)
         else:
             x = self.norm(x)
-            if self.span_engine is not None:
+            if self.span_engine is not None and not self.training:
                 x = self.span_engine(x)
             logits = self.lm_head(x)
 
-        logits = self.grammar(logits)
-        logits = self.debt_ledger(logits)
+        # Inference-only post-processing on 200K-dim logits — skip during training
+        if not self.training:
+            logits = self.grammar(logits)
+            logits = self.debt_ledger(logits)
 
-        # Self-feedback refinement check (inference only)
         if not self.training and self.evolution is not None:
             should_refine = self.evolution.self_feedback.should_refine(logits)
             if should_refine:
                 all_evo_metrics.append({'refinement_triggered': True})
 
-        # Compute loss
         loss = None
         if labels is not None:
             seq_len = min(logits.size(1), labels.size(1))
@@ -380,49 +319,38 @@ class Chimera51ForCausalLM(nn.Module):
             shift_labels = labels[:, :seq_len].contiguous()
             ce_loss = F.cross_entropy(
                 shift_logits.view(-1, shift_logits.size(-1)),
-                shift_labels.view(-1),
-                ignore_index=-100,
-            )
-            # Add evolution loss (contrastive memory evaluation)
+                shift_labels.view(-1), ignore_index=-100)
             loss = ce_loss + self.evo_weight * total_evo_loss
         else:
             ce_loss = None
 
-        # Store episodic case after forward (for inference mode)
-        if not self.training and self.evolution is not None:
-            last_hidden = x[:, -1, :].detach()
-
         return CausalLMOutput(
-            loss=loss,
-            logits=logits,
-            hidden_states=x,
+            loss=loss, logits=logits, hidden_states=x,
             caches=caches if use_cache else None,
             evolution_metrics={
                 'ce_loss': ce_loss.item() if ce_loss is not None else None,
                 'evo_loss': total_evo_loss.item(),
                 'layer_metrics': all_evo_metrics,
-            } if return_evolution_metrics else None
-        )
+            } if return_evolution_metrics else None)
 
     @torch.no_grad()
-    def prepare_for_inference(self) -> None:
-        """Pre-pack every BitLinear so the first generation step is fast."""
+    def prepare_for_inference(self):
         for module in self.modules():
             if isinstance(module, BitLinear):
                 module.prepare_for_inference()
 
-    def get_mode_config(self, mode: str = "balanced") -> dict:
+    def get_mode_config(self, mode="balanced"):
         modes = self.config.get("modes", {})
         return modes.get(mode, modes.get("balanced", {}))
 
-    def count_parameters(self) -> dict:
+    def count_parameters(self):
         total = sum(p.numel() for p in self.parameters())
         ternary = sum(p.numel() for _, m in self.named_modules()
                       if isinstance(m, BitLinear) for p in m.parameters())
         return {"total": total, "ternary": ternary, "fp32": total - ternary}
 
     @classmethod
-    def from_config_file(cls, path: str) -> "Chimera51ForCausalLM":
+    def from_config_file(cls, path):
         with open(path, "r", encoding="utf-8") as fh:
             config = json.load(fh)
         return cls(config)
 
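Since the guard is just PyTorch's standard `nn.Module.training` flag, callers need no new arguments. A hypothetical usage sketch (the import path follows this repo's layout; the config path and tensors are placeholders):

import torch
from chimera.model import Chimera51ForCausalLM  # assumed import path

model = Chimera51ForCausalLM.from_config_file("config.json")  # placeholder path
input_ids = torch.randint(0, 1000, (1, 128))

model.train()                                 # SpanEngine/Grammar/DebtLedger skipped
out = model(input_ids, labels=input_ids)      # loss = CE + weighted evolution loss

model.eval()
model.prepare_for_inference()                 # pre-pack BitLinear weights first
with torch.no_grad():
    out = model(input_ids, logits_to_keep=1)  # full inference logits pipeline runs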