| """ |
| Chimera 5.2 — multimodal encoders (CPU-friendly, slim). |
| |
| The previous draft had two latent issues: |
| * The vision/audio encoders projected to ``out_dim`` (e.g. 2560) which did |
| not match the trunk's ``hidden_size`` after scaling, so concatenating |
| image embeddings into the LM hidden stream blew up. We now project to |
| the trunk's hidden size by default. |
| * The internal ``_EncoderBlock`` wrapped a recurrent layer expecting a |
| ``cache`` argument; we now call the layer correctly and discard the |
| cache (the encoder is purely parallel). |
| |
| The encoders themselves remain BitLinear-friendly so they share the |
| ternary memory budget of the trunk. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| from torch.utils.checkpoint import checkpoint |
|
|
| from .layers import GatedDeltaNetLayer |
| from .quantization import BitLinear, RMSNorm |
|
|
|
|
| def _make_linear(use_ternary: bool): |
| if use_ternary: |
| return BitLinear |
| return lambda i, o, **kw: nn.Linear(i, o, bias=False) |
|
|
|
|
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and embed each one.

    A strided ``Conv2d`` performs the patchify-and-project step in a
    single op; the resulting token sequence is RMS-normalised.
    """

    __constants__ = ["patch_size"]

    def __init__(self, patch_size: int = 16, in_channels: int = 3, hidden_size: int = 384):
        super().__init__()
        self.patch_size = int(patch_size)
        # kernel == stride == patch_size -> non-overlapping patches.
        self.proj = nn.Conv2d(
            in_channels,
            hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )
        self.norm = RMSNorm(hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, C, H, W) -> (B, hidden, H/p, W/p) -> (B, N, hidden)
        patches = self.proj(x)
        tokens = patches.flatten(2).transpose(1, 2)
        return self.norm(tokens)
|
|
|
|
class _EncoderBlock(nn.Module):
    """Pre-norm encoder block: gated delta-net mixer + 4x-expansion MLP.

    The recurrent layer returns ``(output, cache)``; the cache is
    discarded because the encoder runs fully in parallel.
    """

    def __init__(self, hidden: int, num_heads: int, head_dim: int,
                 use_ternary: bool = True):
        super().__init__()
        self.norm = RMSNorm(hidden)
        self.attn = GatedDeltaNetLayer(hidden, num_heads, head_dim,
                                       use_ternary=use_ternary, chunk_size=64)
        self.mlp_norm = RMSNorm(hidden)
        linear = _make_linear(use_ternary)
        self.mlp = nn.Sequential(
            linear(hidden, hidden * 4),
            nn.GELU(),
            linear(hidden * 4, hidden),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual around the sequence mixer; cache intentionally dropped.
        mixed, _cache = self.attn(self.norm(x))
        x = x + mixed
        # Residual around the MLP.
        return x + self.mlp(self.mlp_norm(x))
|
|
|
|
class _EncoderBase(nn.Module):
    """Shared encoder body for vision/audio: block stack + output projection."""

    def __init__(self, hidden: int, depth: int, num_heads: int, head_dim: int,
                 out_dim: int, use_ternary: bool, use_checkpoint: bool):
        super().__init__()
        blocks = [
            _EncoderBlock(hidden, num_heads, head_dim, use_ternary)
            for _ in range(depth)
        ]
        self.layers = nn.ModuleList(blocks)
        # Full-precision projection into the output width (e.g. the trunk's
        # hidden size), followed by a final norm.
        self.proj = nn.Linear(hidden, out_dim, bias=False)
        self.norm = RMSNorm(out_dim)
        self.use_checkpoint = bool(use_checkpoint)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Activation checkpointing only pays off while training.
        checkpointing = self.use_checkpoint and self.training
        for block in self.layers:
            if checkpointing:
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return self.norm(self.proj(x))
|
|
|
|
class VisionEncoder(nn.Module):
    """ViT-style vision encoder producing trunk-aligned patch embeddings.

    Config keys (under ``config["vision"]`` unless noted):
        hidden      -- encoder width (default 384)
        depth       -- number of encoder blocks (default 12)
        patch       -- square patch size (default 16)
        in_channels -- input image channels (default 3; e.g. 1 for grayscale)
        out         -- output dim; falls back to top-level ``hidden_size``
                       so image tokens match the LM trunk's hidden stream
        quant       -- "ternary" enables BitLinear inside the blocks
    The top-level ``config["enabled"]`` flag gates the whole encoder.
    """

    def __init__(self, config: dict):
        super().__init__()
        v = config.get("vision", {})
        self.enabled = bool(config.get("enabled", True))
        hidden = int(v.get("hidden", 384))
        depth = int(v.get("depth", 12))
        patch = int(v.get("patch", 16))
        # Generalized: channel count is now configurable (PatchEmbed already
        # supported it but was hard-wired to RGB here); default keeps the
        # previous behaviour.
        in_channels = int(v.get("in_channels", 3))

        # Project to the trunk's hidden size by default so embeddings can be
        # concatenated directly into the LM hidden stream.
        out_dim = int(v.get("out", config.get("hidden_size", hidden)))
        use_ternary = v.get("quant", "ternary") == "ternary"
        num_heads = max(1, hidden // 64)
        head_dim = hidden // num_heads
        self.patch_embed = PatchEmbed(patch_size=patch, in_channels=in_channels,
                                      hidden_size=hidden)
        self.body = _EncoderBase(hidden, depth, num_heads, head_dim,
                                 out_dim, use_ternary, use_checkpoint=True)

    def forward(self, pixel_values: torch.Tensor) -> Optional[torch.Tensor]:
        """Encode images (B, C, H, W) -> (B, N, out_dim); ``None`` when disabled."""
        if not self.enabled:
            return None
        return self.body(self.patch_embed(pixel_values))
|
|
|
|
class AudioEncoder(nn.Module):
    """Mel-spectrogram audio encoder producing trunk-aligned embeddings.

    Config keys (under ``config["audio"]`` unless noted):
        hidden  -- encoder width (default 256)
        depth   -- number of encoder blocks (default 6)
        n_mels  -- mel filterbank size of the input features (default 80)
        out     -- output dim; falls back to top-level ``hidden_size``
                   so audio tokens match the LM trunk's hidden stream
        quant   -- "ternary" enables BitLinear inside the blocks
    The top-level ``config["enabled"]`` flag gates the whole encoder.
    """

    def __init__(self, config: dict):
        super().__init__()
        a = config.get("audio", {})
        self.enabled = bool(config.get("enabled", True))
        hidden = int(a.get("hidden", 256))
        depth = int(a.get("depth", 6))
        out_dim = int(a.get("out", config.get("hidden_size", hidden)))
        use_ternary = a.get("quant", "ternary") == "ternary"
        num_heads = max(1, hidden // 64)
        head_dim = hidden // num_heads
        # Generalized: the mel-bin count was hard-coded to 80; it is now
        # configurable while the default preserves the old behaviour.
        n_mels = int(a.get("n_mels", 80))
        self.input_proj = nn.Linear(n_mels, hidden, bias=False)
        self.body = _EncoderBase(hidden, depth, num_heads, head_dim,
                                 out_dim, use_ternary, use_checkpoint=True)

    def forward(self, mel_features: torch.Tensor) -> Optional[torch.Tensor]:
        """Encode mel features (B, T, n_mels) -> (B, T, out_dim); ``None`` when disabled."""
        if not self.enabled:
            return None
        return self.body(self.input_proj(mel_features))
|
|
|
|
| __all__ = ["PatchEmbed", "VisionEncoder", "AudioEncoder"] |
|
|