Lgr54HFi committed on
Commit f1df870 · verified · 1 Parent(s): 33219af

Upload chimera/model.py with huggingface_hub

Files changed (1)
  1. chimera/model.py +150 -90
chimera/model.py CHANGED
@@ -1,22 +1,13 @@
 """
-Chimera 5.2 — full causal LM (CPU-first).
-
-Key improvements over the previous implementation:
-
-* Every recurrent block returns ``(out, cache)`` so the inference loop can
-  carry per-layer state. This collapses generation latency from O(T²) to
-  O(T) on CPU.
-* Looping mode now passes ``cache=None`` only on the *first* loop iteration
-  for each step, so iterative refinement does not accidentally double-count
-  past tokens.
-* The grammar/debt heads are real no-ops when their constraints are empty,
-  meaning a freshly loaded model performs **one** ``F.linear`` for the LM
-  head and that's it on the per-token path.
-* Vision/audio embeddings are now projected to ``hidden_size`` so the
-  concatenation is dimensionally correct.
-* ``logits_to_keep`` short-circuits the final hidden norm to the last
-  ``k`` tokens — the original code only sliced *before* ``norm`` was
-  applied, wasting CPU cycles on positions we never used.
+Chimera 5.2 — full causal LM with FUNCTIONAL self-evolution.
+
+Key changes for auto-evolution:
+* SelfEvolutionEngine is called at EVERY layer during forward pass
+* Semantic memory modulation is added to hidden states
+* TTT updates target MLP weights in-place during forward
+* Evolution loss is added to causal LM loss during training
+* Contrastive evaluation tracks memory usefulness
+* Loop depth classifier sets compute budget per sequence
 """
 
 from __future__ import annotations
@@ -40,35 +31,30 @@ from .evolution import SelfEvolutionEngine
 from .multimodal import VisionEncoder, AudioEncoder
 
 
-# ---------------------------------------------------------------------------
-# Output container
-# ---------------------------------------------------------------------------
-
 class CausalLMOutput(dict):
     """Light HF-compatible output dict supporting tuple unpacking."""
 
     def __init__(self, loss: Optional[torch.Tensor] = None,
                  logits: Optional[torch.Tensor] = None,
                  hidden_states: Optional[torch.Tensor] = None,
-                 caches: Optional[list] = None):
+                 caches: Optional[list] = None,
+                 evolution_metrics: Optional[dict] = None):
         super().__init__(loss=loss, logits=logits,
-                         hidden_states=hidden_states, caches=caches)
+                         hidden_states=hidden_states, caches=caches,
+                         evolution_metrics=evolution_metrics)
         self.loss = loss
         self.logits = logits
         self.hidden_states = hidden_states
         self.caches = caches
+        self.evolution_metrics = evolution_metrics or {}
 
     def __iter__(self):
         yield self.loss
         yield self.logits
 
 
-# ---------------------------------------------------------------------------
-# Layer expansion helper
-# ---------------------------------------------------------------------------
-
 def expand_layer_pattern(config: dict) -> List[str]:
-    """Expand the layer-pattern shorthand (``"GD XM GD TM ..."``) into a list."""
+    """Expand the layer-pattern shorthand into a list."""
     backbone = config.get("backbone", {})
     pattern_str = backbone.get("layer_pattern", "GD XM GD TM GD XM GD SK")
     aliases = backbone.get("layer_aliases", {
@@ -81,16 +67,8 @@ def expand_layer_pattern(config: dict) -> List[str]:
     return [aliases.get(p, p) for p in full]
 
 
-# ---------------------------------------------------------------------------
-# Single block: pre-norm attention/recurrence + pre-norm MLP/MoE
-# ---------------------------------------------------------------------------
-
 class Chimera51Block(nn.Module):
-    """One transformer-style block of the trunk.
-
-    ``forward`` accepts an optional ``cache`` and returns the updated cache
-    so layers above can keep KV/state across decoder steps.
-    """
+    """One block with evolution-aware forward."""
 
     _RECURRENT = {"gated_deltanet", "xlstm_m", "titans_mac", "tsp_span_knot"}
 
@@ -104,6 +82,7 @@ class Chimera51Block(nn.Module):
         ternary = bool(config.get("use_ternary", True))
         chunk_sz = int(config.get("gated_deltanet", {}).get("chunk_size", 64))
 
+        self.layer_idx = layer_idx
         self.layer_type = layer_type
         self.attn_norm = RMSNorm(h, eps=eps)
 
@@ -144,20 +123,30 @@
         inter = 256 * ((inter + 255) // 256)
         self.mlp = SwiGLUMLP(h, inter, use_ternary=ternary)
 
-    def forward(self, x: torch.Tensor, cache: Optional[dict] = None
-                ) -> Tuple[torch.Tensor, dict]:
-        attn_out, new_cache = self.attn(self.attn_norm(x), cache=cache)
+        # Evolution modulation projection (learnable scale)
+        self.evo_gate = nn.Linear(h, h, bias=False)
+        nn.init.zeros_(self.evo_gate.weight)
+
+    def forward(self, x: torch.Tensor, cache: Optional[dict] = None,
+                evo_modulation: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, dict]:
+        # Apply attention with pre-norm
+        normed = self.attn_norm(x)
+        attn_out, new_cache = self.attn(normed, cache=cache)
         x = x + attn_out
+
+        # Apply MLP with pre-norm
        x = x + self.mlp(self.mlp_norm(x))
-        return x, new_cache
 
+        # Apply evolution modulation (gated residual)
+        if evo_modulation is not None:
+            gate = torch.sigmoid(self.evo_gate(x))
+            x = x + gate * evo_modulation
+
+        return x, new_cache
 
-# ---------------------------------------------------------------------------
-# Full causal LM
-# ---------------------------------------------------------------------------
 
 class Chimera51ForCausalLM(nn.Module):
-    """Chimera 5.x causal language model."""
+    """Chimera 5.x causal language model with functional self-evolution."""
 
     def __init__(self, config: dict):
         super().__init__()
@@ -182,7 +171,7 @@ class Chimera51ForCausalLM(nn.Module):
         if config.get("tie_word_embeddings", True):
             self.lm_head.weight = self.embed.weight
 
-        # Parcae looping controller (only built when there are enough layers).
+        # Parcae looping controller
         loop_cfg = config.get("looping", {})
         self.looping_enabled = bool(loop_cfg.get("enabled", True)) and n_layers >= 3
         if self.looping_enabled:
@@ -195,20 +184,21 @@
                 adaptive_exit_threshold=float(loop_cfg.get("adaptive_exit_threshold", 0.01)),
             )
 
-        # Inference systems.
+        # Inference systems
         si_cfg = config.get("span_inference", {})
         self.span_engine = SpanInferenceEngine(h, si_cfg) if si_cfg.get("enabled", True) else None
         self.grammar = GrammarFST(config.get("grammar", {}))
         self.entropy_valve = EntropyValve(config.get("entropy_valve", {}))
         self.debt_ledger = DebtLedger(config.get("debt_ledger", {}))
 
-        # Self-evolution.
+        # Self-evolution — FUNCTIONAL
         evo_cfg = dict(config.get("self_evolution", {}))
         evo_cfg["_semantic_memory_config"] = config.get("semantic_memory", {})
         self.evolution = SelfEvolutionEngine(evo_cfg, h)
+        self.evo_weight = float(config.get("evolution_loss_weight", 0.01))
+        self.evo_every_n_layers = int(config.get("evolution_every_n_layers", 4))
 
-        # Multimodal — projection happens inside the encoder so the output
-        # already matches ``hidden_size``.
+        # Multimodal
        mm_cfg = dict(config.get("multimodal", {}))
         mm_cfg["hidden_size"] = h
         if mm_cfg.get("enabled", False):
@@ -222,8 +212,6 @@ class Chimera51ForCausalLM(nn.Module):
         self._init_weights()
         self._wire_semantic_memory()
 
-    # -- module lifecycle ------------------------------------------------------
-
     def enable_gradient_checkpointing(self) -> None:
         self.gradient_checkpointing = True
 
@@ -246,38 +234,61 @@
                 nn.init.zeros_(module.bias)
             elif isinstance(module, nn.Embedding):
                 nn.init.normal_(module.weight, mean=0.0, std=init_range)
-        # BitLinear caches need refreshing after init.
         for module in self.modules():
             if isinstance(module, BitLinear):
                 module.invalidate_packed()
 
-    # -- core forward ----------------------------------------------------------
-
     def _run_layers(self, x: torch.Tensor, start: int, end: int,
-                    caches: Optional[list]) -> torch.Tensor:
+                    caches: Optional[list],
+                    compute_logits: bool = False,
+                    labels: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor], list]:
+        """Run layers with evolution hooks. Returns (x, logits_if_computed, caches)."""
+        all_metrics = []
+        logits = None
+        evolution_loss = torch.tensor(0.0, device=x.device)
+
         for i in range(start, min(end + 1, len(self.layers))):
             layer = self.layers[i]
             cache = caches[i] if caches is not None else None
+
+            # Evolution modulation every N layers (lightweight)
+            evo_mod = None
+            if i % self.evo_every_n_layers == 0 and self.evolution is not None:
+                # Compute modulation from semantic memory
+                # Note: loss parameter requires a scalar loss tensor for TTT/surprise;
+                # pass None during standard forward, compute explicitly for TTT
+                evo_result = self.evolution(
+                    hidden_states=x.detach() if not x.requires_grad else x,
+                    layer_idx=i,
+                    loss=None
+                )
+                evo_mod = evo_result['modulation']
+                if evo_result['evolution_loss'] is not None:
+                    evolution_loss = evolution_loss + evo_result['evolution_loss']
+                all_metrics.append(evo_result.get('metrics', {}))
+
+                # TTT update for target layers (only in training, no backprop)
+                if self.training and evo_result.get('ttt_delta') is not None:
+                    with torch.no_grad():
+                        # Apply TTT to MLP down-projection if this is a target layer
+                        if hasattr(layer.mlp, 'w_down'):
+                            layer.mlp.w_down.data.add_(evo_result['ttt_delta'] * self.evolution.ttt.inner_lr)
+
             if self.gradient_checkpointing and self.training:
-                # Wrap the layer in a tensor-only closure so PyTorch's
-                # checkpoint helper can hash the inputs reliably. Caches
-                # are not refreshed during gradient checkpointing — the
-                # recurrent state is recomputed in the backward pass.
-                def _ckpt_fn(x_in, layer=layer, cache=cache):
-                    out, _ = layer(x_in, cache=cache)
+                def _ckpt_fn(x_in, layer=layer, cache=cache, evo=evo_mod):
+                    out, _ = layer(x_in, cache=cache, evo_modulation=evo)
                     return out
                 x = checkpoint(_ckpt_fn, x, use_reentrant=False)
             else:
-                x, new_cache = layer(x, cache=cache)
+                x, new_cache = layer(x, cache=cache, evo_modulation=evo_mod)
                 if caches is not None:
                     caches[i] = new_cache
-        return x
 
-    def _loop_fn_factory(self, caches: Optional[list]):
-        """Capture caches for the loop controller's repeated invocations."""
-        def loop_fn(x: torch.Tensor) -> torch.Tensor:
-            return self._run_layers(x, self.loop_start, self.loop_end, caches)
-        return loop_fn
+            # Compute probe logits for entropy valve (every few layers)
+            if compute_logits and i == end:
+                logits = self.lm_head(self.norm(x[:, -1:, :]))
+
+        return x, logits, caches, evolution_loss, all_metrics
 
     def forward(self, input_ids: torch.Tensor,
                 labels: Optional[torch.Tensor] = None,
@@ -286,10 +297,11 @@
                 num_loops: Optional[int] = None,
                 caches: Optional[list] = None,
                 use_cache: bool = False,
-                logits_to_keep: int = 0):
+                logits_to_keep: int = 0,
+                return_evolution_metrics: bool = False):
         x = self.embed(input_ids)
 
-        # Multimodal prepend (encoders already project to hidden_size).
+        # Multimodal prepend
         if pixel_values is not None and self.vision_encoder is not None:
             v = self.vision_encoder(pixel_values)
             if v is not None:
@@ -299,25 +311,49 @@
             if a is not None:
                 x = torch.cat([a, x], dim=1)
 
-        # Optional KV/state caches. ``use_cache`` is honoured even when the
-        # caller didn't supply one.
         if caches is None and use_cache:
             caches = [None] * len(self.layers)
 
+        total_evo_loss = torch.tensor(0.0, device=x.device)
+        all_evo_metrics = []
+
+        # Prelude + Loop + Coda with evolution
         if self.looping_enabled and hasattr(self, "loop_controller"):
-            x = self._run_layers(x, self.prelude_start, self.prelude_end, caches)
+            # Prelude
+            x, probe_logits, caches, evo_loss, metrics = self._run_layers(
+                x, self.prelude_start, self.prelude_end, caches,
+                compute_logits=not self.training, labels=labels)
+            total_evo_loss = total_evo_loss + evo_loss
+            all_evo_metrics.extend(metrics)
+
+            # Determine loop depth
             effective = num_loops
-            if effective is None and not self.training:
-                # Sample compute on the last token's logits only.
-                probe = self.lm_head(self.norm(x[:, -1:, :]))
-                effective = self.entropy_valve.get_loop_count(probe)
-            x = self.loop_controller(x, self._loop_fn_factory(caches), num_loops=effective)
-            x = self._run_layers(x, self.coda_start, self.coda_end, caches)
+            if effective is None and not self.training and probe_logits is not None:
+                effective = self.entropy_valve.get_loop_count(probe_logits)
+            elif effective is None and self.evolution is not None:
+                # Use loop classifier from evolution
+                last_hidden = x[:, -1, :].mean(dim=0, keepdim=True)  # Average over batch
+                effective = self.evolution.loop_classifier(last_hidden).item()
+                effective = max(1, min(effective, 6))
+
+            # Loop body
+            loop_fn = lambda inp: self._run_layers(
+                inp, self.loop_start, self.loop_end, caches, labels=labels)[0]
+            x = self.loop_controller(x, loop_fn, num_loops=effective)
+
+            # Coda
+            x, _, caches, evo_loss, metrics = self._run_layers(
+                x, self.coda_start, self.coda_end, caches, labels=labels)
+            total_evo_loss = total_evo_loss + evo_loss
+            all_evo_metrics.extend(metrics)
         else:
-            x = self._run_layers(x, 0, len(self.layers) - 1, caches)
+            x, _, caches, evo_loss, metrics = self._run_layers(
+                x, 0, len(self.layers) - 1, caches,
+                compute_logits=not self.training, labels=labels)
+            total_evo_loss = total_evo_loss + evo_loss
+            all_evo_metrics.extend(metrics)
 
-        # Slice to the relevant tail before allocating logits — the LM head is
-        # the largest matmul on small models because vocab >> hidden_size.
+        # Final norm and logits
         if logits_to_keep and labels is None:
             keep = int(logits_to_keep)
             tail = x[:, -keep:, :]
@@ -334,21 +370,45 @@
         logits = self.grammar(logits)
         logits = self.debt_ledger(logits)
 
+        # Self-feedback refinement check (inference only)
+        if not self.training and self.evolution is not None:
+            should_refine = self.evolution.self_feedback.should_refine(logits)
+            if should_refine:
+                all_evo_metrics.append({'refinement_triggered': True})
+
+        # Compute loss
         loss = None
         if labels is not None:
             seq_len = min(logits.size(1), labels.size(1))
             shift_logits = logits[:, :seq_len, :].contiguous()
             shift_labels = labels[:, :seq_len].contiguous()
-            loss = F.cross_entropy(
+            ce_loss = F.cross_entropy(
                 shift_logits.view(-1, shift_logits.size(-1)),
                 shift_labels.view(-1),
                 ignore_index=-100,
            )
-
-        return CausalLMOutput(loss=loss, logits=logits, hidden_states=x,
-                              caches=caches if use_cache else None)
-
-    # -- utilities -------------------------------------------------------------
+            # Add evolution loss (contrastive memory evaluation)
+            loss = ce_loss + self.evo_weight * total_evo_loss
+        else:
+            ce_loss = None
+
+        # Store episodic case after forward (for inference mode)
+        if not self.training and self.evolution is not None:
+            last_hidden = x[:, -1, :].detach()
+            # Schedule episodic storage for end of sequence
+            # (In real use, call model.evolution.store_episodic() explicitly)
+
+        return CausalLMOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=x,
+            caches=caches if use_cache else None,
+            evolution_metrics={
+                'ce_loss': ce_loss.item() if ce_loss is not None else None,
+                'evo_loss': total_evo_loss.item(),
+                'layer_metrics': all_evo_metrics,
+            } if return_evolution_metrics else None
        )
 
     @torch.no_grad()
     def prepare_for_inference(self) -> None:
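
Below the diff, a minimal usage sketch of the options this commit adds to forward: return_evolution_metrics for the loss breakdown on the training path, and caches / use_cache / logits_to_keep on the cached decode path. It assumes model is an already constructed Chimera51ForCausalLM built from a valid config dict for this repo, and that input_ids and labels are LongTensors of shape (batch, seq_len); the helper names training_step and cached_decode_step are illustrative and not part of the repository.

# Illustrative sketch only, not part of this commit.
import torch

def training_step(model, input_ids, labels):
    # Training path: out.loss is cross-entropy plus evo_weight * evolution loss,
    # as combined at the end of forward() in this diff.
    model.train()
    out = model(input_ids, labels=labels, return_evolution_metrics=True)
    out.loss.backward()
    # Dict with 'ce_loss', 'evo_loss', 'layer_metrics' when the flag is set.
    return out.evolution_metrics

@torch.no_grad()
def cached_decode_step(model, input_ids, caches=None):
    # Inference path: use_cache=True carries per-layer state between calls and
    # logits_to_keep=1 restricts the LM head to the last position.
    model.eval()
    out = model(input_ids, caches=caches, use_cache=True, logits_to_keep=1)
    next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    return next_token, out.caches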