"""
Chimera 5.2 — multimodal encoders (CPU-friendly, slim).
The previous draft had two latent issues:
* The vision/audio encoders projected to ``out_dim`` (e.g. 2560) which did
not match the trunk's ``hidden_size`` after scaling, so concatenating
image embeddings into the LM hidden stream blew up. We now project to
the trunk's hidden size by default.
* The internal ``_EncoderBlock`` wrapped a recurrent layer expecting a
``cache`` argument; we now call the layer correctly and discard the
cache (the encoder is purely parallel).
The encoders themselves remain BitLinear-friendly so they share the
ternary memory budget of the trunk.
"""
from __future__ import annotations
from typing import Optional
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from .layers import GatedDeltaNetLayer
from .quantization import BitLinear, RMSNorm
def _make_linear(use_ternary: bool):
if use_ternary:
return BitLinear
return lambda i, o, **kw: nn.Linear(i, o, bias=False)
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and embed each one.

    A strided ``Conv2d`` performs the patchify + linear projection in a single
    pass; the feature map is then flattened into a ``(batch, patches, hidden)``
    sequence and RMS-normalised.
    """
    __constants__ = ["patch_size"]

    def __init__(self, patch_size: int = 16, in_channels: int = 3, hidden_size: int = 384):
        super().__init__()
        self.patch_size = int(patch_size)
        # kernel == stride == patch_size -> patches do not overlap.
        self.proj = nn.Conv2d(
            in_channels,
            hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )
        self.norm = RMSNorm(hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        patches = self.proj(x)                    # (B, hidden, H', W')
        seq = patches.flatten(2).transpose(1, 2)  # (B, H'*W', hidden)
        return self.norm(seq)
class _EncoderBlock(nn.Module):
    """Pre-norm residual block: gated delta-net mixing followed by a GELU MLP."""

    def __init__(self, hidden: int, num_heads: int, head_dim: int,
                 use_ternary: bool = True):
        super().__init__()
        self.norm = RMSNorm(hidden)
        self.attn = GatedDeltaNetLayer(hidden, num_heads, head_dim,
                                       use_ternary=use_ternary, chunk_size=64)
        self.mlp_norm = RMSNorm(hidden)
        linear_cls = _make_linear(use_ternary)
        self.mlp = nn.Sequential(
            linear_cls(hidden, hidden * 4),
            nn.GELU(),
            linear_cls(hidden * 4, hidden),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The recurrent layer also returns a cache; the encoder runs fully in
        # parallel, so the cache is discarded.
        mixed, _cache = self.attn(self.norm(x))
        x = x + mixed
        x = x + self.mlp(self.mlp_norm(x))
        return x
class _EncoderBase(nn.Module):
    """Shared encoder body for vision/audio: stacked blocks + output projection.

    When ``use_checkpoint`` is set, blocks run under activation checkpointing
    during training to trade compute for memory.
    """

    def __init__(self, hidden: int, depth: int, num_heads: int, head_dim: int,
                 out_dim: int, use_ternary: bool, use_checkpoint: bool):
        super().__init__()
        blocks = [_EncoderBlock(hidden, num_heads, head_dim, use_ternary)
                  for _ in range(depth)]
        self.layers = nn.ModuleList(blocks)
        # Full-precision projection into the trunk's embedding space.
        self.proj = nn.Linear(hidden, out_dim, bias=False)
        self.norm = RMSNorm(out_dim)
        self.use_checkpoint = bool(use_checkpoint)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Checkpointing only pays off while gradients are being tracked.
        checkpointing = self.use_checkpoint and self.training
        for block in self.layers:
            if checkpointing:
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return self.norm(self.proj(x))
class VisionEncoder(nn.Module):
    """Patch-based image encoder producing trunk-aligned embeddings.

    Config keys read from ``config["vision"]``: ``hidden``, ``depth``,
    ``patch``, ``channels`` (input channels, default 3), ``out`` (output dim,
    defaults to the trunk ``hidden_size``), ``quant`` ("ternary" selects
    BitLinear layers).
    """

    def __init__(self, config: dict):
        super().__init__()
        v = config.get("vision", {})
        # NOTE(review): `enabled` is read from the top-level config, not the
        # `vision` sub-dict — confirm this matches the intended config schema.
        self.enabled = bool(config.get("enabled", True))
        hidden = int(v.get("hidden", 384))
        depth = int(v.get("depth", 12))
        patch = int(v.get("patch", 16))
        # Generalized: the input channel count is configurable (default RGB);
        # previously PatchEmbed's default of 3 was always used.
        in_channels = int(v.get("channels", 3))
        # Default the encoder output to the trunk hidden_size so concatenation
        # into the LM stream is dimensionally consistent.
        out_dim = int(v.get("out", config.get("hidden_size", hidden)))
        use_ternary = v.get("quant", "ternary") == "ternary"
        # Target 64-wide heads; guard against hidden < 64.
        num_heads = max(1, hidden // 64)
        head_dim = hidden // num_heads
        self.patch_embed = PatchEmbed(patch_size=patch, in_channels=in_channels,
                                      hidden_size=hidden)
        self.body = _EncoderBase(hidden, depth, num_heads, head_dim,
                                 out_dim, use_ternary, use_checkpoint=True)

    def forward(self, pixel_values: torch.Tensor) -> Optional[torch.Tensor]:
        """Encode a batch of images; returns ``None`` when disabled."""
        if not self.enabled:
            return None
        return self.body(self.patch_embed(pixel_values))
class AudioEncoder(nn.Module):
    """Mel-spectrogram encoder producing trunk-aligned embeddings.

    Config keys read from ``config["audio"]``: ``hidden``, ``depth``, ``mels``
    (input mel-bin count, default 80), ``out`` (output dim, defaults to the
    trunk ``hidden_size``), ``quant`` ("ternary" selects BitLinear layers).
    """

    def __init__(self, config: dict):
        super().__init__()
        a = config.get("audio", {})
        # NOTE(review): `enabled` is read from the top-level config, not the
        # `audio` sub-dict — confirm this matches the intended config schema.
        self.enabled = bool(config.get("enabled", True))
        hidden = int(a.get("hidden", 256))
        depth = int(a.get("depth", 6))
        out_dim = int(a.get("out", config.get("hidden_size", hidden)))
        use_ternary = a.get("quant", "ternary") == "ternary"
        # Target 64-wide heads; guard against hidden < 64.
        num_heads = max(1, hidden // 64)
        head_dim = hidden // num_heads
        # Generalized: the mel-bin count is configurable (default 80);
        # previously 80 was hard-coded.
        n_mels = int(a.get("mels", 80))
        self.input_proj = nn.Linear(n_mels, hidden, bias=False)
        self.body = _EncoderBase(hidden, depth, num_heads, head_dim,
                                 out_dim, use_ternary, use_checkpoint=True)

    def forward(self, mel_features: torch.Tensor) -> Optional[torch.Tensor]:
        """Encode (batch, frames, mels) features; returns ``None`` when disabled."""
        if not self.enabled:
            return None
        return self.body(self.input_proj(mel_features))
# Public API; `_make_linear`, `_EncoderBlock`, and `_EncoderBase` stay private.
__all__ = ["PatchEmbed", "VisionEncoder", "AudioEncoder"]