# chimera/multimodal.py
# (uploaded via huggingface_hub, commit 6e408ce)
"""
Chimera 5.2 — multimodal encoders (CPU-friendly, slim).
The previous draft had two latent issues:
* The vision/audio encoders projected to ``out_dim`` (e.g. 2560) which did
not match the trunk's ``hidden_size`` after scaling, so concatenating
image embeddings into the LM hidden stream blew up. We now project to
the trunk's hidden size by default.
* The internal ``_EncoderBlock`` wrapped a recurrent layer expecting a
``cache`` argument; we now call the layer correctly and discard the
cache (the encoder is purely parallel).
The encoders themselves remain BitLinear-friendly so they share the
ternary memory budget of the trunk.
"""
from __future__ import annotations
from typing import Optional
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from .layers import GatedDeltaNetLayer
from .quantization import BitLinear, RMSNorm
def _make_linear(use_ternary: bool):
if use_ternary:
return BitLinear
return lambda i, o, **kw: nn.Linear(i, o, bias=False)
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and embed each one.

    A strided ``Conv2d`` (kernel == stride == patch size) performs the
    patchify and linear projection in a single pass; the map is then
    flattened into a ``(batch, num_patches, hidden)`` token sequence and
    RMS-normalized.
    """

    __constants__ = ["patch_size"]

    def __init__(self, patch_size: int = 16, in_channels: int = 3, hidden_size: int = 384):
        super().__init__()
        self.patch_size = int(patch_size)
        # Equal kernel and stride -> non-overlapping patch windows.
        self.proj = nn.Conv2d(in_channels, hidden_size,
                              kernel_size=self.patch_size, stride=self.patch_size)
        self.norm = RMSNorm(hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed pixel input ``x`` into a normalized patch-token sequence."""
        feat = self.proj(x)                       # (B, hidden, H/p, W/p)
        tokens = feat.flatten(2).transpose(1, 2)  # (B, num_patches, hidden)
        return self.norm(tokens)
class _EncoderBlock(nn.Module):
    """Pre-norm encoder block: gated delta-net mixer followed by a 4x GELU MLP."""

    def __init__(self, hidden: int, num_heads: int, head_dim: int,
                 use_ternary: bool = True):
        super().__init__()
        self.norm = RMSNorm(hidden)
        self.attn = GatedDeltaNetLayer(hidden, num_heads, head_dim,
                                       use_ternary=use_ternary, chunk_size=64)
        self.mlp_norm = RMSNorm(hidden)
        linear = _make_linear(use_ternary)
        self.mlp = nn.Sequential(
            linear(hidden, hidden * 4),
            nn.GELU(),
            linear(hidden * 4, hidden),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The mixer also returns a recurrence cache; the encoder runs purely
        # in parallel, so the cache is discarded.
        mixed, _cache = self.attn(self.norm(x))
        h = x + mixed
        return h + self.mlp(self.mlp_norm(h))
class _EncoderBase(nn.Module):
    """Shared encoder body for vision/audio: a stack of ``_EncoderBlock``s,
    then a projection to ``out_dim`` and a final RMSNorm."""

    def __init__(self, hidden: int, depth: int, num_heads: int, head_dim: int,
                 out_dim: int, use_ternary: bool, use_checkpoint: bool):
        super().__init__()
        blocks = [_EncoderBlock(hidden, num_heads, head_dim, use_ternary)
                  for _ in range(depth)]
        self.layers = nn.ModuleList(blocks)
        # Output projection is a plain (non-ternary) bias-free linear.
        self.proj = nn.Linear(hidden, out_dim, bias=False)
        self.norm = RMSNorm(out_dim)
        self.use_checkpoint = bool(use_checkpoint)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the block stack and project to the output width."""
        # Gradient checkpointing trades compute for memory; only worthwhile
        # (and only valid) while training.
        checkpointing = self.use_checkpoint and self.training
        for block in self.layers:
            if checkpointing:
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return self.norm(self.proj(x))
class VisionEncoder(nn.Module):
    """ViT-style image encoder producing one embedding per patch.

    The output width defaults to the trunk's ``hidden_size`` so image tokens
    can be concatenated into the LM hidden stream without a mismatch.
    """

    def __init__(self, config: dict):
        super().__init__()
        vcfg = config.get("vision", {})
        # NOTE(review): "enabled" is read from the top-level config, not the
        # "vision" sub-dict — confirm this matches the intended schema.
        self.enabled = bool(config.get("enabled", True))
        width = int(vcfg.get("hidden", 384))
        depth = int(vcfg.get("depth", 12))
        patch = int(vcfg.get("patch", 16))
        # Default output width to the trunk hidden size (see module docstring).
        out_dim = int(vcfg.get("out", config.get("hidden_size", width)))
        use_ternary = vcfg.get("quant", "ternary") == "ternary"
        num_heads = max(1, width // 64)  # target ~64-dim heads
        head_dim = width // num_heads
        self.patch_embed = PatchEmbed(patch_size=patch, hidden_size=width)
        self.body = _EncoderBase(width, depth, num_heads, head_dim,
                                 out_dim, use_ternary, use_checkpoint=True)

    def forward(self, pixel_values: torch.Tensor) -> Optional[torch.Tensor]:
        """Encode images; returns ``None`` when the encoder is disabled."""
        if not self.enabled:
            return None
        return self.body(self.patch_embed(pixel_values))
class AudioEncoder(nn.Module):
    """Audio encoder over 80-bin mel features, sharing the encoder body.

    Mel frames are linearly lifted to the encoder width, run through the
    shared block stack, and projected to the trunk's ``hidden_size`` by
    default so audio tokens align with the LM hidden stream.
    """

    def __init__(self, config: dict):
        super().__init__()
        acfg = config.get("audio", {})
        # NOTE(review): "enabled" comes from the top-level config, not the
        # "audio" sub-dict — confirm this matches the intended schema.
        self.enabled = bool(config.get("enabled", True))
        width = int(acfg.get("hidden", 256))
        depth = int(acfg.get("depth", 6))
        out_dim = int(acfg.get("out", config.get("hidden_size", width)))
        use_ternary = acfg.get("quant", "ternary") == "ternary"
        num_heads = max(1, width // 64)  # target ~64-dim heads
        head_dim = width // num_heads
        # 80 mel bins per frame -> encoder width.
        self.input_proj = nn.Linear(80, width, bias=False)
        self.body = _EncoderBase(width, depth, num_heads, head_dim,
                                 out_dim, use_ternary, use_checkpoint=True)

    def forward(self, mel_features: torch.Tensor) -> Optional[torch.Tensor]:
        """Encode mel features; returns ``None`` when the encoder is disabled."""
        if not self.enabled:
            return None
        return self.body(self.input_proj(mel_features))
# Explicit public API; private helpers (_EncoderBlock, _EncoderBase) stay internal.
__all__ = ["PatchEmbed", "VisionEncoder", "AudioEncoder"]