import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = float(eps)
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute the RMS statistic in float32 for numerical stability, then cast back.
        x_float = x.float()
        rms = x_float.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        y = (x_float / rms).to(dtype=x.dtype)
        return y * self.weight.to(dtype=x.dtype, device=x.device)
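
# Minimal usage sketch (shapes are illustrative assumptions): RMSNorm rescales the
# last dimension by the root-mean-square of its entries; there is no mean
# subtraction and no bias term.
#
#   norm = RMSNorm(dim=64)
#   y = norm(torch.randn(2, 10, 64))   # same shape and dtype as the input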


class SwiGLU(nn.Module):
    @staticmethod
    def forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
        return F.silu(gate) * up
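
# SwiGLU combines a gated branch with a linear branch: silu(gate) * up.
# Both inputs must share the same shape, e.g. (illustrative):
#
#   act = SwiGLU()
#   out = act(torch.randn(2, 16), torch.randn(2, 16))   # -> [2, 16]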


class TabularImageGQALayer(nn.Module):
    """
    Pre-norm Transformer block with:
      - Tabular tokens produce Q; tabular + image tokens produce K/V (image optional)
      - GQA: num_query_heads must be a multiple of num_kv_heads
      - Numeric and categorical features must be concatenated into a single
        tabular stream before calling this layer
      - attention_mask is [B, T_tab] (one value per tabular token) and does not
        include vision tokens
      - If vision_features is None, attention is tabular-only
      - Vision tokens are not updated (no queries are formed for vision)
    """

    def __init__(
        self,
        tabular_dim: int,
        vision_dim: int,
        num_query_heads: int,
        num_kv_heads: int,
        head_dim: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        rmsnorm_eps: float = 1e-6,
    ):
        super().__init__()

        if num_query_heads % num_kv_heads != 0:
            raise ValueError("num_query_heads must be a multiple of num_kv_heads")

        self.tabular_dim = int(tabular_dim)
        self.vision_dim = int(vision_dim)
        self.num_query_heads = int(num_query_heads)
        self.num_kv_heads = int(num_kv_heads)
        self.head_dim = int(head_dim)

        self.q_dim = self.num_query_heads * self.head_dim
        self.kv_dim = self.num_kv_heads * self.head_dim
        self.group_size = self.num_query_heads // self.num_kv_heads

        self.attn_norm = RMSNorm(self.tabular_dim, eps=rmsnorm_eps)

        # Tabular tokens produce queries, keys, and values.
        self.q_proj_tab = nn.Linear(self.tabular_dim, self.q_dim, bias=False)
        self.k_proj_tab = nn.Linear(self.tabular_dim, self.kv_dim, bias=False)
        self.v_proj_tab = nn.Linear(self.tabular_dim, self.kv_dim, bias=False)

        # Vision tokens produce keys and values only (they receive no queries).
        self.k_proj_img = nn.Linear(self.vision_dim, self.kv_dim, bias=False)
        self.v_proj_img = nn.Linear(self.vision_dim, self.kv_dim, bias=False)

        self.o_proj = nn.Linear(self.q_dim, self.tabular_dim, bias=False)

        self.attn_dropout = float(dropout)
        self.resid_dropout = float(dropout)

        # SwiGLU feed-forward block.
        self.ffn_norm = RMSNorm(self.tabular_dim, eps=rmsnorm_eps)
        ffn_dim = int(round(self.tabular_dim * float(mlp_ratio)))

        self.gate_proj = nn.Linear(self.tabular_dim, ffn_dim, bias=False)
        self.up_proj = nn.Linear(self.tabular_dim, ffn_dim, bias=False)
        self.down_proj = nn.Linear(ffn_dim, self.tabular_dim, bias=False)
        self.act = SwiGLU()

    def init_weights(self, std: float = 0.02):
        # Norm scales start at one.
        nn.init.ones_(self.attn_norm.weight)
        nn.init.ones_(self.ffn_norm.weight)

        # Attention projections.
        nn.init.normal_(self.q_proj_tab.weight, std=std)
        nn.init.normal_(self.k_proj_tab.weight, std=std)
        nn.init.normal_(self.v_proj_tab.weight, std=std)
        nn.init.normal_(self.k_proj_img.weight, std=std)
        nn.init.normal_(self.v_proj_img.weight, std=std)
        nn.init.normal_(self.o_proj.weight, std=std)

        # Feed-forward projections.
        nn.init.normal_(self.gate_proj.weight, std=std)
        nn.init.normal_(self.up_proj.weight, std=std)
        nn.init.normal_(self.down_proj.weight, std=std)
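    # Note: init_weights() is not called automatically; unless the caller invokes
    # it, the nn.Linear default initialization applies.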

    @staticmethod
    def _make_key_bias_from_mask(mask_1d: torch.Tensor, key_len: int) -> torch.Tensor:
        """
        mask_1d: [B, T_key] with 1=keep, 0=mask
        returns: [B, 1, 1, T_key] float bias with 0 for kept positions and a
        large negative value (-1e9) for masked positions
        """
        if mask_1d.dtype != torch.float32:
            mask_f = mask_1d.float()
        else:
            mask_f = mask_1d
        if mask_f.shape[1] != key_len:
            raise ValueError(f"mask_1d width mismatch: got {mask_f.shape[1]}, expected {key_len}")
        bias = (1.0 - mask_f) * -1e9
        return bias.view(mask_f.shape[0], 1, 1, key_len)
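    # Example of the conversion above (illustrative values):
    #   mask_1d = [[1, 1, 0]]  ->  bias = [[[[0.0, 0.0, -1e9]]]]
    # so masked key positions get a large negative score added before the softmax.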

    def _split_heads_q(self, x: torch.Tensor) -> torch.Tensor:
        # [B, T, Hq * d] -> [B, Hq, T, d]
        B, T, _ = x.shape
        return x.view(B, T, self.num_query_heads, self.head_dim).transpose(1, 2).contiguous()

    def _split_heads_kv(self, x: torch.Tensor) -> torch.Tensor:
        # [B, T, Hkv * d] -> [B, Hkv, T, d]
        B, T, _ = x.shape
        return x.view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2).contiguous()

    @staticmethod
    def _merge_heads_q(x: torch.Tensor) -> torch.Tensor:
        # [B, H, T, d] -> [B, T, H * d]
        B, H, T, d = x.shape
        return x.transpose(1, 2).contiguous().view(B, T, H * d)

    def forward(
        self,
        x_tab: torch.Tensor,
        attention_mask: torch.Tensor,
        vision_features: Optional[torch.Tensor] = None,
        vision_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        x_tab: [B, T_tab, tabular_dim]
        attention_mask: [B, T_tab] (1=valid tab token, 0=masked tab token). Does NOT include vision.
        vision_features: None or [B, T_img, vision_dim]
        vision_mask: None or [B, T_img] (1=valid vision token, 0=masked). Required if vision_features is not None.
        returns: updated x_tab [B, T_tab, tabular_dim]
        """
        if x_tab.dim() != 3:
            raise ValueError(f"x_tab must be [B,T,D], got {tuple(x_tab.shape)}")
        if attention_mask.dim() != 2:
            raise ValueError(f"attention_mask must be [B,T_tab], got {tuple(attention_mask.shape)}")

        B, T_tab, D = x_tab.shape
        if D != self.tabular_dim:
            raise ValueError(f"tabular_dim mismatch: got {D}, expected {self.tabular_dim}")
        if attention_mask.shape != (B, T_tab):
            raise ValueError("attention_mask shape mismatch with x_tab")
        if attention_mask.device != x_tab.device:
            attention_mask = attention_mask.to(device=x_tab.device)

        # Pre-norm attention branch: tabular tokens form the queries.
        h = self.attn_norm(x_tab)

        q_tab = self.q_proj_tab(h)
        k_tab = self.k_proj_tab(h)
        v_tab = self.v_proj_tab(h)

        q = self._split_heads_q(q_tab)
        k_tab = self._split_heads_kv(k_tab)
        v_tab = self._split_heads_kv(v_tab)

        if vision_features is None:
            # Tabular-only attention.
            k = k_tab
            v = v_tab
            key_mask = attention_mask
        else:
            if vision_features.dim() != 3:
                raise ValueError(f"vision_features must be [B,T_img,Dv], got {tuple(vision_features.shape)}")
            if vision_features.shape[0] != B:
                raise ValueError("vision_features batch mismatch")
            if vision_features.shape[2] != self.vision_dim:
                raise ValueError(f"vision_dim mismatch: got {vision_features.shape[2]}, expected {self.vision_dim}")

            if vision_mask is None:
                raise ValueError("vision_mask must be provided when vision_features is not None")
            if vision_mask.dim() != 2:
                raise ValueError(f"vision_mask must be [B,T_img], got {tuple(vision_mask.shape)}")

            T_img = vision_features.shape[1]
            if vision_mask.shape != (B, T_img):
                raise ValueError(f"vision_mask shape mismatch: expected {(B, T_img)}, got {tuple(vision_mask.shape)}")

            # Align the vision mask with the tabular mask's dtype and device.
            if vision_mask.dtype != attention_mask.dtype:
                vision_mask = vision_mask.to(dtype=attention_mask.dtype)
            if vision_mask.device != attention_mask.device:
                vision_mask = vision_mask.to(device=attention_mask.device)

            param = self.k_proj_img.weight
            vision_features = vision_features.to(device=param.device, dtype=param.dtype)
            k_img = self.k_proj_img(vision_features)
            v_img = self.v_proj_img(vision_features)
            k_img = self._split_heads_kv(k_img)
            v_img = self._split_heads_kv(v_img)

            # Keys/values are tabular tokens followed by vision tokens.
            k = torch.cat([k_tab, k_img], dim=2)
            v = torch.cat([v_tab, v_img], dim=2)

            # The key mask follows the same token ordering.
            key_mask = torch.cat([attention_mask, vision_mask], dim=1)

        # GQA: each KV head is shared by group_size query heads.
        if self.group_size != 1:
            k = k.repeat_interleave(self.group_size, dim=1)
            v = v.repeat_interleave(self.group_size, dim=1)

        T_k = k.shape[2]
        key_bias = self._make_key_bias_from_mask(key_mask, key_len=T_k)

        # Scaled dot-product attention with the additive key mask.
        scale = 1.0 / math.sqrt(self.head_dim)
        attn_scores = torch.einsum("bhtd,bhkd->bhtk", q, k) * scale
        attn_scores = attn_scores + key_bias

        attn_probs = F.softmax(attn_scores.float(), dim=-1)
        if self.attn_dropout > 0.0 and self.training:
            attn_probs = F.dropout(attn_probs, p=self.attn_dropout)
        attn_probs = attn_probs.to(v.dtype)

        attn_out = torch.einsum("bhtk,bhkd->bhtd", attn_probs, v)
        attn_out = self._merge_heads_q(attn_out)
        attn_out = self.o_proj(attn_out)

        # Zero out padded tabular query positions.
        attn_out = attn_out * attention_mask.to(attn_out.dtype).unsqueeze(-1)

        if self.resid_dropout > 0.0 and self.training:
            attn_out = F.dropout(attn_out, p=self.resid_dropout)

        x = x_tab + attn_out

        # Pre-norm SwiGLU feed-forward branch.
        h2 = self.ffn_norm(x)
        gate = self.gate_proj(h2)
        up = self.up_proj(h2)
        f = self.act(gate, up)
        f = self.down_proj(f)

        # Zero out padded tabular positions again after the FFN.
        f = f * attention_mask.to(f.dtype).unsqueeze(-1)

        if self.resid_dropout > 0.0 and self.training:
            f = F.dropout(f, p=self.resid_dropout)

        x = x + f
        return x
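
# Minimal usage sketch for the layer (all shapes and sizes below are illustrative
# assumptions; _demo_main() further down is the runnable demo):
#
#   layer = TabularImageGQALayer(
#       tabular_dim=64, vision_dim=32,
#       num_query_heads=4, num_kv_heads=2, head_dim=16,
#   )
#   x_tab = torch.randn(2, 10, 64)
#   tab_mask = torch.ones(2, 10, dtype=torch.long)
#   y = layer(x_tab, tab_mask)                      # tabular-only: [2, 10, 64]
#
#   img = torch.randn(2, 7, 32)
#   img_mask = torch.ones(2, 7, dtype=torch.long)
#   y = layer(x_tab, tab_mask, vision_features=img, vision_mask=img_mask)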


def _count_params(m: nn.Module) -> Tuple[int, int]:
    total = sum(p.numel() for p in m.parameters())
    trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
    return total, trainable


def _demo_main():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--t_tab", type=int, default=126)
    parser.add_argument("--t_img", type=int, default=256)
    parser.add_argument("--tabular_dim", type=int, default=768)
    parser.add_argument("--vision_dim", type=int, default=768)
    parser.add_argument("--num_query_heads", type=int, default=8)
    parser.add_argument("--num_kv_heads", type=int, default=2)
    parser.add_argument("--head_dim", type=int, default=128)
    parser.add_argument("--mlp_ratio", type=float, default=1.5)
    parser.add_argument("--dropout", type=float, default=0.0)
    parser.add_argument("--with_vision", action="store_true")
    parser.add_argument("--dtype", type=str, default="float32", choices=["float16", "bfloat16", "float32"])
    parser.add_argument("--device", type=str, default=None)
    args = parser.parse_args()

    device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
    dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
    dtype = dtype_map[args.dtype]

    layer = TabularImageGQALayer(
        tabular_dim=args.tabular_dim,
        vision_dim=args.vision_dim,
        num_query_heads=args.num_query_heads,
        num_kv_heads=args.num_kv_heads,
        head_dim=args.head_dim,
        mlp_ratio=args.mlp_ratio,
        dropout=args.dropout,
    ).to(device=device, dtype=dtype)

    total, trainable = _count_params(layer)
    print(f"Layer parameters: {total:,} (trainable: {trainable:,})")

    B = args.batch_size
    T_tab = args.t_tab

    x_tab = torch.randn(B, T_tab, args.tabular_dim, device=device, dtype=dtype)

    # Random valid lengths: each row keeps at least half of its tabular tokens,
    # the rest are padding (mask value 0).
    lengths = torch.randint(low=max(1, T_tab // 2), high=T_tab + 1, size=(B,), device=device)
    attention_mask = torch.zeros(B, T_tab, device=device, dtype=torch.long)
    for b in range(B):
        attention_mask[b, : int(lengths[b].item())] = 1

    if args.with_vision:
        vision = torch.randn(B, args.t_img, args.vision_dim, device=device, dtype=dtype)

        # All vision tokens are valid except the second half of the first sample.
        vision_mask = torch.ones(B, args.t_img, device=device, dtype=torch.long)
        if args.t_img > 0:
            vision_mask[0, args.t_img // 2:] = 0
    else:
        vision = None
        vision_mask = None

    print("Input x_tab:", tuple(x_tab.shape), x_tab.dtype, x_tab.device)
    print("Input attention_mask:", tuple(attention_mask.shape), attention_mask.dtype, attention_mask.device)
    print("Input vision_features:", None if vision is None else (tuple(vision.shape), vision.dtype, vision.device))
    print("Input vision_mask:",
          None if vision_mask is None else (tuple(vision_mask.shape), vision_mask.dtype, vision_mask.device))

    with torch.no_grad():
        y = layer(
            x_tab=x_tab,
            attention_mask=attention_mask,
            vision_features=vision,
            vision_mask=vision_mask,
        )

    print("Output y_tab:", tuple(y.shape), y.dtype, y.device)


if __name__ == "__main__":
    _demo_main()