| """ |
| Chimera 5.2 — full causal LM with FUNCTIONAL self-evolution. |
| |
| Key changes for auto-evolution: |
| * SelfEvolutionEngine is called at EVERY layer during forward pass |
| * Semantic memory modulation is added to hidden states |
| * TTT updates target MLP weights in-place during forward |
| * Evolution loss is added to causal LM loss during training |
| * Contrastive evaluation tracks memory usefulness |
| * Loop depth classifier sets compute budget per sequence |
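
Minimal usage sketch (the config keys shown are illustrative, not exhaustive):

    model = Chimera51ForCausalLM.from_config_file("config.json")
    out = model(input_ids, labels=input_ids, return_evolution_metrics=True)
    loss, logits = out                # CausalLMOutput unpacks as a pair
    print(out.evolution_metrics)      # {'ce_loss': ..., 'evo_loss': ..., ...}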
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
from typing import List, Optional, Tuple
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.checkpoint import checkpoint |
|
|
| from .quantization import BitLinear, RMSNorm |
| from .layers import (GatedDeltaNetLayer, MLSTMLayer, TitansMACLayer, |
| TSPSpanKnotLayer, SwiGLUMLP) |
| from .moe import MoELayer |
| from .looping import ParcaeLoopController |
| from .inference import (SpanInferenceEngine, GrammarFST, EntropyValve, |
| DebtLedger, BraidState) |
| from .evolution import SelfEvolutionEngine |
| from .multimodal import VisionEncoder, AudioEncoder |
|
|
|
|
| class CausalLMOutput(dict): |
| """Light HF-compatible output dict supporting tuple unpacking.""" |
|
|
| def __init__(self, loss: Optional[torch.Tensor] = None, |
| logits: Optional[torch.Tensor] = None, |
| hidden_states: Optional[torch.Tensor] = None, |
| caches: Optional[list] = None, |
| evolution_metrics: Optional[dict] = None): |
| super().__init__(loss=loss, logits=logits, |
| hidden_states=hidden_states, caches=caches, |
| evolution_metrics=evolution_metrics) |
| self.loss = loss |
| self.logits = logits |
| self.hidden_states = hidden_states |
| self.caches = caches |
| self.evolution_metrics = evolution_metrics or {} |
|
|
| def __iter__(self): |
| yield self.loss |
| yield self.logits |
|
|
|
|
| def expand_layer_pattern(config: dict) -> List[str]: |
| """Expand the layer-pattern shorthand into a list.""" |
| backbone = config.get("backbone", {}) |
| pattern_str = backbone.get("layer_pattern", "GD XM GD TM GD XM GD SK") |
| aliases = backbone.get("layer_aliases", { |
| "GD": "gated_deltanet", "XM": "xlstm_m", |
| "TM": "titans_mac", "SK": "tsp_span_knot", |
| }) |
| pattern = pattern_str.split() |
| n_layers = int(config.get("num_hidden_layers", 28)) |
| full = (pattern * (n_layers // len(pattern) + 1))[:n_layers] |
| return [aliases.get(p, p) for p in full] |
|
|
|
|
| class Chimera51Block(nn.Module): |
| """One block with evolution-aware forward.""" |
|
|
| _RECURRENT = {"gated_deltanet", "xlstm_m", "titans_mac", "tsp_span_knot"} |
|
|
| def __init__(self, config: dict, layer_type: str, layer_idx: int, |
| use_moe: bool = False): |
| super().__init__() |
| h = int(config["hidden_size"]) |
| eps = float(config.get("rms_norm_eps", 1e-6)) |
| heads = int(config["num_heads"]) |
| head_dim = int(config["head_dim"]) |
| ternary = bool(config.get("use_ternary", True)) |
| chunk_sz = int(config.get("gated_deltanet", {}).get("chunk_size", 64)) |
|
|
| self.layer_idx = layer_idx |
| self.layer_type = layer_type |
| self.attn_norm = RMSNorm(h, eps=eps) |
|
|
| if layer_type == "gated_deltanet": |
| self.attn = GatedDeltaNetLayer(h, heads, head_dim, norm_eps=eps, |
| chunk_size=chunk_sz, use_ternary=ternary) |
| elif layer_type == "xlstm_m": |
| mem_h = config.get("xlstm", {}).get("memory_size_per_head", [head_dim, head_dim]) |
| self.attn = MLSTMLayer(h, heads, int(mem_h[0]), norm_eps=eps, |
| use_ternary=ternary) |
| elif layer_type == "titans_mac": |
| tc = config.get("titans", {}) |
| self.attn = TitansMACLayer(h, heads, head_dim, |
| memory_depth=int(tc.get("memory_depth", 2)), |
| persistent_slots=int(tc.get("persistent_memory_slots", 64)), |
| local_window=int(tc.get("local_window_size", 1024)), |
| norm_eps=eps, use_ternary=ternary) |
| elif layer_type == "tsp_span_knot": |
| self.attn = TSPSpanKnotLayer(h, heads, head_dim, norm_eps=eps, |
| chunk_size=chunk_sz, use_ternary=ternary) |
| else: |
| raise ValueError(f"Unknown layer type: {layer_type}") |
|
|
| self.mlp_norm = RMSNorm(h, eps=eps) |
| self.use_moe = bool(use_moe) |
| if self.use_moe: |
| moe_cfg = config.get("backbone", {}).get("moe", {}) |
| self.mlp = MoELayer( |
| hidden_size=h, |
| moe_intermediate_size=int(moe_cfg.get("moe_intermediate_size", h * 2)), |
| n_routed_experts=int(moe_cfg.get("n_routed_experts", 16)), |
| n_shared_experts=int(moe_cfg.get("n_shared_experts", 1)), |
| num_experts_per_tok=int(moe_cfg.get("num_experts_per_tok", 2)), |
| use_ternary=ternary, |
| ) |
| else: |
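            # Dense path: SwiGLU with the usual ~8/3 expansion, rounded up to
            # a multiple of 256 for hardware-friendly matmul shapes.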
| inter = int(config.get("intermediate_size", int(h * 8 / 3))) |
| inter = 256 * ((inter + 255) // 256) |
| self.mlp = SwiGLUMLP(h, inter, use_ternary=ternary) |
|
|
| |
        # Gate over the evolution modulation. Zero-initialised weights give a
        # constant sigmoid(0) = 0.5 gate at the start of training; the gate
        # then learns where (and how strongly) to admit the signal.
        self.evo_gate = nn.Linear(h, h, bias=False)
        nn.init.zeros_(self.evo_gate.weight)
|
|
| def forward(self, x: torch.Tensor, cache: Optional[dict] = None, |
| evo_modulation: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, dict]: |
| |
        # Token mixing (recurrent sequence layer) with residual connection.
        normed = self.attn_norm(x)
        attn_out, new_cache = self.attn(normed, cache=cache)
        x = x + attn_out
|
|
| |
        # Channel mixing (dense SwiGLU or MoE) with residual connection.
        x = x + self.mlp(self.mlp_norm(x))
|
|
| |
        # Gated injection of the evolution engine's modulation signal.
        if evo_modulation is not None:
            gate = torch.sigmoid(self.evo_gate(x))
            x = x + gate * evo_modulation
|
|
| return x, new_cache |
|
|
|
|
| class Chimera51ForCausalLM(nn.Module): |
| """Chimera 5.x causal language model with functional self-evolution.""" |
|
|
| def __init__(self, config: dict): |
| super().__init__() |
| self.config = config |
| h = int(config["hidden_size"]) |
| vocab = int(config["vocab_size"]) |
| n_layers = int(config["num_hidden_layers"]) |
| eps = float(config.get("rms_norm_eps", 1e-6)) |
|
|
| self.embed = nn.Embedding(vocab, h) |
| layer_types = expand_layer_pattern(config) |
| moe_layers = set(int(i) for i in config.get("backbone", {}).get("moe", {}).get("layers", [])) |
|
|
| self.layers = nn.ModuleList([ |
| Chimera51Block(config, layer_types[i], i, use_moe=(i in moe_layers)) |
| for i in range(n_layers) |
| ]) |
|
|
| self.norm = RMSNorm(h, eps=eps) |
| self.lm_head = nn.Linear(h, vocab, bias=False) |
|
|
| if config.get("tie_word_embeddings", True): |
| self.lm_head.weight = self.embed.weight |
|
|
| |
        # Parcae adaptive looping: the stack splits into a prelude, a looped
        # middle block, and a coda (all layer ranges are inclusive).
        loop_cfg = config.get("looping", {})
| self.looping_enabled = bool(loop_cfg.get("enabled", True)) and n_layers >= 3 |
| if self.looping_enabled: |
| self.prelude_start, self.prelude_end = loop_cfg.get("prelude", [0, min(3, n_layers - 1)]) |
            self.loop_start, self.loop_end = loop_cfg.get(
                "loop", [min(4, n_layers - 1), max(4, n_layers - 5)])  # end just before the coda
| self.coda_start, self.coda_end = loop_cfg.get("coda", [max(0, n_layers - 4), n_layers - 1]) |
| self.loop_controller = ParcaeLoopController( |
| h, loop_range=tuple(loop_cfg.get("loop_range", [1, 6])), |
| loop_default=int(loop_cfg.get("loop_default", 2)), |
| adaptive_exit_threshold=float(loop_cfg.get("adaptive_exit_threshold", 0.01)), |
| ) |
|
|
| |
        # Inference-time logit shaping: span proposals, grammar constraints,
        # entropy-based loop budgeting, and debt-ledger adjustments.
        si_cfg = config.get("span_inference", {})
| self.span_engine = SpanInferenceEngine(h, si_cfg) if si_cfg.get("enabled", True) else None |
| self.grammar = GrammarFST(config.get("grammar", {})) |
| self.entropy_valve = EntropyValve(config.get("entropy_valve", {})) |
| self.debt_ledger = DebtLedger(config.get("debt_ledger", {})) |
|
|
| |
        # Functional self-evolution: the engine is queried every
        # ``evolution_every_n_layers`` layers in _run_layers, and its
        # auxiliary loss joins the LM loss with weight ``evolution_loss_weight``.
        evo_cfg = dict(config.get("self_evolution", {}))
| evo_cfg["_semantic_memory_config"] = config.get("semantic_memory", {}) |
| self.evolution = SelfEvolutionEngine(evo_cfg, h) |
| self.evo_weight = float(config.get("evolution_loss_weight", 0.01)) |
| self.evo_every_n_layers = int(config.get("evolution_every_n_layers", 4)) |
|
|
| |
        # Optional multimodal front-ends; encoded tokens are prepended to the
        # text embeddings in forward().
        mm_cfg = dict(config.get("multimodal", {}))
| mm_cfg["hidden_size"] = h |
| if mm_cfg.get("enabled", False): |
| self.vision_encoder = VisionEncoder(mm_cfg) |
| self.audio_encoder = AudioEncoder(mm_cfg) |
| else: |
| self.vision_encoder = None |
| self.audio_encoder = None |
|
|
| self.gradient_checkpointing = False |
| self._init_weights() |
| self._wire_semantic_memory() |
|
|
| def enable_gradient_checkpointing(self) -> None: |
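        """Trade memory for compute: checkpoint each block during training."""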
| self.gradient_checkpointing = True |
|
|
| def disable_gradient_checkpointing(self) -> None: |
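        """Restore the standard, non-checkpointed forward pass."""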
| self.gradient_checkpointing = False |
|
|
| def _wire_semantic_memory(self) -> None: |
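        """Share the engine's semantic memory with every layer that accepts it."""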
| mem = self.evolution.semantic_memory |
| for layer in self.layers: |
| if hasattr(layer.attn, "set_semantic_memory"): |
| layer.attn.set_semantic_memory(mem) |
|
|
| def _init_weights(self) -> None: |
| init_range = float(self.config.get("initializer_range", 0.006)) |
| for module in self.modules(): |
| if isinstance(module, (nn.Linear, BitLinear)): |
| if module.weight is not None: |
| nn.init.normal_(module.weight, mean=0.0, std=init_range) |
| if getattr(module, "bias", None) is not None: |
| nn.init.zeros_(module.bias) |
| elif isinstance(module, nn.Embedding): |
| nn.init.normal_(module.weight, mean=0.0, std=init_range) |
        # Fresh weights invalidate any previously packed ternary tensors.
        for module in self.modules():
            if isinstance(module, BitLinear):
                module.invalidate_packed()
|
|
    def _run_layers(self, x: torch.Tensor, start: int, end: int,
                    caches: Optional[list],
                    compute_logits: bool = False,
                    labels: Optional[torch.Tensor] = None,
                    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
                               Optional[list], torch.Tensor, List[dict]]:
        """Run layers ``start..end`` (inclusive) with evolution hooks.

        Returns ``(x, probe_logits, caches, evolution_loss, metrics)``;
        ``probe_logits`` is only set when ``compute_logits`` is True.
        ``labels`` is accepted for call-site symmetry but unused here.
        """
| all_metrics = [] |
| logits = None |
| evolution_loss = torch.tensor(0.0, device=x.device) |
|
|
| for i in range(start, min(end + 1, len(self.layers))): |
| layer = self.layers[i] |
| cache = caches[i] if caches is not None else None |
|
|
| |
            # Periodic evolution hook: ask the engine for a semantic-memory
            # modulation signal plus an auxiliary evolution loss.
            evo_mod = None
            if i % self.evo_every_n_layers == 0 and self.evolution is not None:
                # Pass a stop-gradient copy when gradients are live, so the
                # evolution loss trains the engine without perturbing the
                # backbone's gradient flow.
                evo_result = self.evolution(
                    hidden_states=x.detach() if x.requires_grad else x,
                    layer_idx=i,
                    loss=None,
                )
| evo_mod = evo_result['modulation'] |
| if evo_result['evolution_loss'] is not None: |
| evolution_loss = evolution_loss + evo_result['evolution_loss'] |
| all_metrics.append(evo_result.get('metrics', {})) |
|
|
| |
                # Test-time training: apply the engine's proposed weight delta
                # to the dense MLP down-projection in-place, outside autograd.
                # MoE blocks expose experts rather than ``w_down`` and are
                # skipped by the hasattr guard.
                if self.training and evo_result.get('ttt_delta') is not None:
                    with torch.no_grad():
                        if hasattr(layer.mlp, 'w_down'):
                            layer.mlp.w_down.data.add_(
                                evo_result['ttt_delta'] * self.evolution.ttt.inner_lr)
|
|
            if self.gradient_checkpointing and self.training:
                # Checkpointed path recomputes the layer at backward time.
                # The layer cache is deliberately dropped: caches are only
                # consumed at inference, where checkpointing is disabled.
                def _ckpt_fn(x_in, layer=layer, cache=cache, evo=evo_mod):
                    out, _ = layer(x_in, cache=cache, evo_modulation=evo)
                    return out
                x = checkpoint(_ckpt_fn, x, use_reentrant=False)
| else: |
| x, new_cache = layer(x, cache=cache, evo_modulation=evo_mod) |
| if caches is not None: |
| caches[i] = new_cache |
|
|
| |
            # Cheap probe: next-token logits from the final position only,
            # used by the entropy valve to size the loop budget.
            if compute_logits and i == end:
                logits = self.lm_head(self.norm(x[:, -1:, :]))
|
|
| return x, logits, caches, evolution_loss, all_metrics |
|
|
| def forward(self, input_ids: torch.Tensor, |
| labels: Optional[torch.Tensor] = None, |
| pixel_values: Optional[torch.Tensor] = None, |
| mel_features: Optional[torch.Tensor] = None, |
| num_loops: Optional[int] = None, |
| caches: Optional[list] = None, |
| use_cache: bool = False, |
| logits_to_keep: int = 0, |
| return_evolution_metrics: bool = False): |
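        """Embed inputs, optionally prepend multimodal tokens, run the layer
        stack (with adaptive looping and evolution hooks), shape the logits,
        and add the weighted evolution loss to CE when ``labels`` is given."""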
| x = self.embed(input_ids) |
|
|
| |
        # Encode and prepend modality tokens ahead of the text embeddings.
        if pixel_values is not None and self.vision_encoder is not None:
| v = self.vision_encoder(pixel_values) |
| if v is not None: |
| x = torch.cat([v, x], dim=1) |
| if mel_features is not None and self.audio_encoder is not None: |
| a = self.audio_encoder(mel_features) |
| if a is not None: |
| x = torch.cat([a, x], dim=1) |
|
|
| if caches is None and use_cache: |
| caches = [None] * len(self.layers) |
|
|
| total_evo_loss = torch.tensor(0.0, device=x.device) |
| all_evo_metrics = [] |
|
|
| |
        if self.looping_enabled and hasattr(self, "loop_controller"):
            # 1) Prelude layers; at inference, also probe next-token logits
            #    so the entropy valve can pick a loop budget.
            x, probe_logits, caches, evo_loss, metrics = self._run_layers(
                x, self.prelude_start, self.prelude_end, caches,
                compute_logits=not self.training, labels=labels)
| total_evo_loss = total_evo_loss + evo_loss |
| all_evo_metrics.extend(metrics) |
|
|
| |
            # 2) Choose the loop count: explicit override, else the entropy
            #    valve at inference, else the learned loop-depth classifier.
            effective = num_loops
            if effective is None and not self.training and probe_logits is not None:
                effective = self.entropy_valve.get_loop_count(probe_logits)
            elif effective is None and self.evolution is not None:
                # Batch-mean summary of the final position feeds the classifier.
                last_hidden = x[:, -1, :].mean(dim=0, keepdim=True)
                effective = int(self.evolution.loop_classifier(last_hidden).item())
                effective = max(1, min(effective, 6))  # clamp to the default loop_range
|
|
| |
            # 3) Looped middle block; only the hidden states (index [0]) are
            #    kept from looped passes, so evolution losses/metrics are not
            #    double-counted across loop iterations.
            def loop_fn(inp):
                return self._run_layers(
                    inp, self.loop_start, self.loop_end, caches, labels=labels)[0]
            x = self.loop_controller(x, loop_fn, num_loops=effective)
|
|
| |
            # 4) Coda layers.
            x, _, caches, evo_loss, metrics = self._run_layers(
                x, self.coda_start, self.coda_end, caches, labels=labels)
| total_evo_loss = total_evo_loss + evo_loss |
| all_evo_metrics.extend(metrics) |
| else: |
            # No looping: run the whole stack once.
            x, _, caches, evo_loss, metrics = self._run_layers(
                x, 0, len(self.layers) - 1, caches,
                compute_logits=not self.training, labels=labels)
| total_evo_loss = total_evo_loss + evo_loss |
| all_evo_metrics.extend(metrics) |
|
|
| |
        # Decode logits; at inference, ``logits_to_keep`` restricts the LM
        # head to the trailing positions to save memory.
        if logits_to_keep and labels is None:
| keep = int(logits_to_keep) |
| tail = x[:, -keep:, :] |
| tail = self.norm(tail) |
| if self.span_engine is not None: |
| tail = self.span_engine(tail) |
| logits = self.lm_head(tail) |
| else: |
| x = self.norm(x) |
| if self.span_engine is not None: |
| x = self.span_engine(x) |
| logits = self.lm_head(x) |
|
|
        # Grammar-FST and debt-ledger shaping of the logits.
        logits = self.grammar(logits)
        logits = self.debt_ledger(logits)
|
|
| |
        # Self-feedback: should_refine() can flag the output for refinement;
        # the flag is only recorded in the metrics here, no extra compute.
        if not self.training and self.evolution is not None:
            should_refine = self.evolution.self_feedback.should_refine(logits)
            if should_refine:
                all_evo_metrics.append({'refinement_triggered': True})
|
|
| |
        # Training loss: next-token cross-entropy plus the weighted evolution
        # loss. Multimodal tokens are prepended, so logits and labels are
        # aligned on their trailing positions before the causal shift.
        loss = None
        if labels is not None:
            text_len = min(logits.size(1), labels.size(1))
            text_logits = logits[:, -text_len:, :]
            text_labels = labels[:, -text_len:]
            # Standard causal shift: position t predicts token t + 1.
            shift_logits = text_logits[:, :-1, :].contiguous()
            shift_labels = text_labels[:, 1:].contiguous()
            ce_loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )
            loss = ce_loss + self.evo_weight * total_evo_loss
        else:
            ce_loss = None
|
|
| |
        # Contrastive memory-usefulness evaluation (see module docstring).
        # NOTE: ``contrastive_eval`` is an assumed hook name on the engine;
        # the hasattr guard keeps this a no-op until the engine provides it.
        if not self.training and self.evolution is not None:
            last_hidden = x[:, -1, :].detach()
            if hasattr(self.evolution, "contrastive_eval"):
                self.evolution.contrastive_eval(last_hidden)
|
|
| return CausalLMOutput( |
| loss=loss, |
| logits=logits, |
| hidden_states=x, |
| caches=caches if use_cache else None, |
| evolution_metrics={ |
| 'ce_loss': ce_loss.item() if ce_loss is not None else None, |
| 'evo_loss': total_evo_loss.item(), |
| 'layer_metrics': all_evo_metrics, |
| } if return_evolution_metrics else None |
| ) |
|
|
| @torch.no_grad() |
| def prepare_for_inference(self) -> None: |
| """Pre-pack every BitLinear so the first generation step is fast.""" |
| for module in self.modules(): |
| if isinstance(module, BitLinear): |
| module.prepare_for_inference() |
|
|
| def get_mode_config(self, mode: str = "balanced") -> dict: |
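        """Return the preset for ``mode``, falling back to "balanced"."""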
| modes = self.config.get("modes", {}) |
| return modes.get(mode, modes.get("balanced", {})) |
|
|
| def count_parameters(self) -> dict: |
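        """Parameter counts: total, ternary (BitLinear), and full-precision rest."""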
| total = sum(p.numel() for p in self.parameters()) |
| ternary = sum(p.numel() for _, m in self.named_modules() |
| if isinstance(m, BitLinear) for p in m.parameters()) |
| return {"total": total, "ternary": ternary, "fp32": total - ternary} |
|
|
| @classmethod |
| def from_config_file(cls, path: str) -> "Chimera51ForCausalLM": |
| with open(path, "r", encoding="utf-8") as fh: |
| config = json.load(fh) |
| return cls(config) |
|
|
|
|
| __all__ = ["Chimera51ForCausalLM", "Chimera51Block", "CausalLMOutput", |
| "expand_layer_pattern"] |
|
|