"""
Chimera 5.2 — multimodal encoders (CPU-friendly, slim).

The previous draft had two latent issues:

* The vision/audio encoders projected to ``out_dim`` (e.g. 2560) which did
  not match the trunk's ``hidden_size`` after scaling, so concatenating image
  embeddings into the LM hidden stream blew up. We now project to the trunk's
  hidden size by default.
* The internal ``_EncoderBlock`` wrapped a recurrent layer expecting a
  ``cache`` argument; we now call the layer correctly and discard the cache
  (the encoder is purely parallel).

The encoders themselves remain BitLinear-friendly so they share the ternary
memory budget of the trunk.
"""

from __future__ import annotations

from typing import Optional

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from .layers import GatedDeltaNetLayer
from .quantization import BitLinear, RMSNorm


def _make_linear(use_ternary: bool):
    """Return a callable ``(in_features, out_features, **kw)`` building a linear layer.

    When ``use_ternary`` is true the ``BitLinear`` class itself is returned.
    Otherwise the result is a small wrapper that creates a bias-free
    ``nn.Linear`` and ignores any extra keyword arguments.
    """
    if use_ternary:
        return BitLinear
    return lambda i, o, **kw: nn.Linear(i, o, bias=False)


def _encoder_body_kwargs(
    section: dict,
    config: dict,
    *,
    default_hidden: int,
    default_depth: int,
) -> dict:
    """Parse the settings shared by the vision and audio encoders.

    Args:
        section: Modality-specific sub-config (``config["vision"]`` or
            ``config["audio"]``).
        config: The enclosing config; only ``hidden_size`` is read from it,
            as the fallback for the output width.
        default_hidden: Encoder width used when ``section`` has no ``hidden``.
        default_depth: Number of blocks used when ``section`` has no ``depth``.

    Returns:
        Keyword arguments for ``_EncoderBase``.
    """
    hidden = int(section.get("hidden", default_hidden))
    # One head per 64 channels, but never fewer than one head.
    num_heads = max(1, hidden // 64)
    return {
        "hidden": hidden,
        "depth": int(section.get("depth", default_depth)),
        "num_heads": num_heads,
        "head_dim": hidden // num_heads,
        # Default the encoder output to the trunk hidden_size so concatenation
        # into the LM stream is dimensionally consistent.
        "out_dim": int(section.get("out", config.get("hidden_size", hidden))),
        "use_ternary": section.get("quant", "ternary") == "ternary",
        # Checkpointing used to be hard-coded on; it stays the default.
        "use_checkpoint": bool(section.get("checkpoint", True)),
    }


class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and embed each one.

    A strided ``Conv2d`` (kernel == stride == ``patch_size``) produces one
    ``hidden_size``-wide vector per patch, followed by ``RMSNorm``.
    """

    __constants__ = ["patch_size"]

    def __init__(self, patch_size: int = 16, in_channels: int = 3, hidden_size: int = 384):
        super().__init__()
        self.patch_size = int(patch_size)
        self.proj = nn.Conv2d(
            in_channels,
            hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )
        self.norm = RMSNorm(hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` and return a normalised patch sequence.

        Assumes ``x`` is batched as ``(batch, channels, height, width)``; the
        spatial grid produced by the convolution is flattened and moved in
        front of the channel axis.
        """
        x = self.proj(x)
        # Collapse the spatial dims (dim 2 onward) and put them before channels.
        x = x.flatten(2).transpose(1, 2)
        return self.norm(x)


class _EncoderBlock(nn.Module):
    """Pre-norm residual block: ``GatedDeltaNetLayer`` mixer followed by an MLP."""

    def __init__(self, hidden: int, num_heads: int, head_dim: int, use_ternary: bool = True):
        super().__init__()
        self.norm = RMSNorm(hidden)
        self.attn = GatedDeltaNetLayer(
            hidden, num_heads, head_dim, use_ternary=use_ternary, chunk_size=64
        )
        self.mlp_norm = RMSNorm(hidden)
        linear = _make_linear(use_ternary)
        self.mlp = nn.Sequential(
            linear(hidden, hidden * 4),
            nn.GELU(),
            linear(hidden * 4, hidden),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the mixer and the MLP, each with a residual connection."""
        # The layer returns (output, cache); the encoder is purely parallel,
        # so the cache is discarded.
        attn_out, _ = self.attn(self.norm(x))
        x = x + attn_out
        return x + self.mlp(self.mlp_norm(x))


class _EncoderBase(nn.Module):
    """Shared encoder body for vision/audio.

    Runs ``depth`` ``_EncoderBlock`` layers, then projects from ``hidden`` to
    ``out_dim`` with a bias-free linear layer and applies ``RMSNorm``.
    """

    def __init__(
        self,
        hidden: int,
        depth: int,
        num_heads: int,
        head_dim: int,
        out_dim: int,
        use_ternary: bool,
        use_checkpoint: bool,
    ):
        super().__init__()
        self.layers = nn.ModuleList([
            _EncoderBlock(hidden, num_heads, head_dim, use_ternary)
            for _ in range(depth)
        ])
        self.proj = nn.Linear(hidden, out_dim, bias=False)
        self.norm = RMSNorm(out_dim)
        self.use_checkpoint = bool(use_checkpoint)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run every block over ``x`` and return the projected, normalised result."""
        for layer in self.layers:
            # Activation checkpointing only applies while training.
            if self.use_checkpoint and self.training:
                x = checkpoint(layer, x, use_reentrant=False)
            else:
                x = layer(x)
        return self.norm(self.proj(x))


class VisionEncoder(nn.Module):
    """Image encoder: ``PatchEmbed`` followed by the shared encoder body.

    Config keys read:

    * ``config["enabled"]`` (default ``True``) — when false, ``forward``
      returns ``None``.
    * ``config["hidden_size"]`` — fallback for the output width.
    * ``config["vision"]``: ``hidden`` (384), ``depth`` (12), ``patch`` (16),
      ``out`` (defaults to ``hidden_size``, else ``hidden``), ``quant``
      (``"ternary"``), ``in_channels`` (3), ``checkpoint`` (``True``).
    """

    def __init__(self, config: dict):
        super().__init__()
        v = config.get("vision", {})
        self.enabled = bool(config.get("enabled", True))
        body_kwargs = _encoder_body_kwargs(v, config, default_hidden=384, default_depth=12)
        self.patch_embed = PatchEmbed(
            patch_size=int(v.get("patch", 16)),
            # Previously fixed at PatchEmbed's default of 3 channels.
            in_channels=int(v.get("in_channels", 3)),
            hidden_size=body_kwargs["hidden"],
        )
        self.body = _EncoderBase(**body_kwargs)

    def forward(self, pixel_values: torch.Tensor) -> Optional[torch.Tensor]:
        """Encode ``pixel_values``; return ``None`` when the encoder is disabled."""
        if not self.enabled:
            return None
        return self.body(self.patch_embed(pixel_values))


class AudioEncoder(nn.Module):
    """Audio encoder: linear input projection followed by the shared encoder body.

    Config keys read:

    * ``config["enabled"]`` (default ``True``) — when false, ``forward``
      returns ``None``.
    * ``config["hidden_size"]`` — fallback for the output width.
    * ``config["audio"]``: ``hidden`` (256), ``depth`` (6), ``out`` (defaults
      to ``hidden_size``, else ``hidden``), ``quant`` (``"ternary"``),
      ``n_mels`` (80), ``checkpoint`` (``True``).
    """

    def __init__(self, config: dict):
        super().__init__()
        a = config.get("audio", {})
        self.enabled = bool(config.get("enabled", True))
        body_kwargs = _encoder_body_kwargs(a, config, default_hidden=256, default_depth=6)
        # Width of the last axis of the input features; previously fixed at 80.
        n_mels = int(a.get("n_mels", 80))
        self.input_proj = nn.Linear(n_mels, body_kwargs["hidden"], bias=False)
        self.body = _EncoderBase(**body_kwargs)

    def forward(self, mel_features: torch.Tensor) -> Optional[torch.Tensor]:
        """Encode ``mel_features``; return ``None`` when the encoder is disabled.

        The last axis of ``mel_features`` must match ``n_mels``.
        """
        if not self.enabled:
            return None
        return self.body(self.input_proj(mel_features))


__all__ = ["PatchEmbed", "VisionEncoder", "AudioEncoder"]