"""
Chimera 5.2 — multimodal encoders (CPU-friendly, slim).

The previous draft had two latent issues:
* The vision/audio encoders projected to ``out_dim`` (e.g. 2560), which no
  longer matched the trunk's ``hidden_size`` after scaling, so concatenating
  image embeddings into the LM hidden stream raised a shape mismatch.  We
  now project to the trunk's hidden size by default.
* The internal ``_EncoderBlock`` wrapped a recurrent layer whose forward
  returns an extra ``cache``; we now unpack the call correctly and discard
  the cache (the encoder is purely parallel).

The encoders themselves remain BitLinear-friendly so they share the
ternary memory budget of the trunk.
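
Example (shape sketch; the 1024 trunk width is illustrative, and the
config keys mirror the defaults below)::

    cfg = {"enabled": True, "hidden_size": 1024,
           "vision": {"hidden": 384, "depth": 12, "patch": 16}}
    states = VisionEncoder(cfg)(pixel_values)
    # (B, 3, 224, 224) -> (B, 196, 1024): (224 // 16) ** 2 patch tokens,
    # projected to the trunk hidden size for the LM stream.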
"""

from __future__ import annotations

from typing import Optional

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from .layers import GatedDeltaNetLayer
from .quantization import BitLinear, RMSNorm


def _make_linear(use_ternary: bool):
    """Return a linear-layer factory: ``BitLinear`` when ternary weights are
    requested, otherwise a bias-free ``nn.Linear``.  The non-ternary path
    accepts and ignores extra keyword arguments so both factories share a
    call signature."""
    if use_ternary:
        return BitLinear
    return lambda i, o, **kw: nn.Linear(i, o, bias=False)


class PatchEmbed(nn.Module):
    __constants__ = ["patch_size"]

    def __init__(self, patch_size: int = 16, in_channels: int = 3, hidden_size: int = 384):
        super().__init__()
        self.patch_size = int(patch_size)
        self.proj = nn.Conv2d(in_channels, hidden_size,
                              kernel_size=self.patch_size, stride=self.patch_size)
        self.norm = RMSNorm(hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)                  # (B, C, H, W) -> (B, hidden, H//p, W//p)
        x = x.flatten(2).transpose(1, 2)  # -> (B, num_patches, hidden)
        return self.norm(x)


class _EncoderBlock(nn.Module):
    def __init__(self, hidden: int, num_heads: int, head_dim: int,
                 use_ternary: bool = True):
        super().__init__()
        self.norm = RMSNorm(hidden)
        self.attn = GatedDeltaNetLayer(hidden, num_heads, head_dim,
                                       use_ternary=use_ternary, chunk_size=64)
        self.mlp_norm = RMSNorm(hidden)
        L = _make_linear(use_ternary)
        self.mlp = nn.Sequential(L(hidden, hidden * 4), nn.GELU(), L(hidden * 4, hidden))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
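        # The recurrent layer returns ``(output, cache)``; the encoder is
        # purely parallel, so the cache is discarded.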
        attn_out, _ = self.attn(self.norm(x))
        x = x + attn_out
        return x + self.mlp(self.mlp_norm(x))


class _EncoderBase(nn.Module):
    """Shared encoder body for vision/audio."""

    def __init__(self, hidden: int, depth: int, num_heads: int, head_dim: int,
                 out_dim: int, use_ternary: bool, use_checkpoint: bool):
        super().__init__()
        self.layers = nn.ModuleList([
            _EncoderBlock(hidden, num_heads, head_dim, use_ternary)
            for _ in range(depth)
        ])
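        # Output projection to the trunk width stays a plain ``nn.Linear``
        # even in ternary mode.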
        self.proj = nn.Linear(hidden, out_dim, bias=False)
        self.norm = RMSNorm(out_dim)
        self.use_checkpoint = bool(use_checkpoint)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
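        # During training, activation checkpointing trades compute for
        # memory by recomputing each block in the backward pass.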
        for layer in self.layers:
            if self.use_checkpoint and self.training:
                x = checkpoint(layer, x, use_reentrant=False)
            else:
                x = layer(x)
        return self.norm(self.proj(x))


class VisionEncoder(nn.Module):
    def __init__(self, config: dict):
        super().__init__()
        v = config.get("vision", {})
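        # ``enabled`` is read from the top level of the multimodal config,
        # not from the per-modality block.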
        self.enabled = bool(config.get("enabled", True))
        hidden = int(v.get("hidden", 384))
        depth = int(v.get("depth", 12))
        patch = int(v.get("patch", 16))
        # Default the encoder output to the trunk hidden_size so concatenation
        # into the LM stream is dimensionally consistent.
        out_dim = int(v.get("out", config.get("hidden_size", hidden)))
        use_ternary = v.get("quant", "ternary") == "ternary"
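        # Target roughly 64-wide heads; guarantee at least one head.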
        num_heads = max(1, hidden // 64)
        head_dim = hidden // num_heads
        self.patch_embed = PatchEmbed(patch_size=patch, hidden_size=hidden)
        self.body = _EncoderBase(hidden, depth, num_heads, head_dim,
                                 out_dim, use_ternary, use_checkpoint=True)

    def forward(self, pixel_values: torch.Tensor) -> Optional[torch.Tensor]:
        if not self.enabled:
            return None
        return self.body(self.patch_embed(pixel_values))


class AudioEncoder(nn.Module):
    def __init__(self, config: dict):
        super().__init__()
        a = config.get("audio", {})
        self.enabled = bool(config.get("enabled", True))
        hidden = int(a.get("hidden", 256))
        depth = int(a.get("depth", 6))
        out_dim = int(a.get("out", config.get("hidden_size", hidden)))
        use_ternary = a.get("quant", "ternary") == "ternary"
        num_heads = max(1, hidden // 64)
        head_dim = hidden // num_heads
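        # 80-bin log-mel input features are assumed here.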
        self.input_proj = nn.Linear(80, hidden, bias=False)
        self.body = _EncoderBase(hidden, depth, num_heads, head_dim,
                                 out_dim, use_ternary, use_checkpoint=True)

    def forward(self, mel_features: torch.Tensor) -> Optional[torch.Tensor]:
        if not self.enabled:
            return None
        return self.body(self.input_proj(mel_features))


__all__ = ["PatchEmbed", "VisionEncoder", "AudioEncoder"]
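

if __name__ == "__main__":
    # Minimal shape smoke test -- a sketch, not part of the public API.
    # It assumes ``GatedDeltaNetLayer``/``BitLinear`` import cleanly, so
    # run it as a module from the package root (the package name is not
    # fixed here); the 1024 trunk width is illustrative.
    cfg = {
        "enabled": True,
        "hidden_size": 1024,
        "vision": {"hidden": 384, "depth": 2, "patch": 16},
        "audio": {"hidden": 256, "depth": 2},
    }
    vision, audio = VisionEncoder(cfg).eval(), AudioEncoder(cfg).eval()
    with torch.no_grad():
        img = vision(torch.randn(1, 3, 224, 224))  # -> (1, 196, 1024)
        mel = audio(torch.randn(1, 100, 80))       # -> (1, 100, 1024)
    assert img.shape == (1, 196, 1024)
    assert mel.shape == (1, 100, 1024)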