# layer.py
# -*- coding: utf-8 -*-
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = float(eps)
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [..., dim]
        x_float = x.float()
        rms = x_float.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        y = (x_float / rms).to(dtype=x.dtype)
        return y * self.weight.to(dtype=x.dtype, device=x.device)


class SwiGLU(nn.Module):
    @staticmethod
    def forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
        return nn.functional.silu(gate) * up


class TabularImageGQALayer(nn.Module):
    """
    Pre-norm Transformer block with:
    - Tabular tokens produce Q; tabular+image produce KV (image optional)
    - GQA: num_query_heads is a multiple of num_kv_heads
    - Numeric+categorical must be concatenated before calling this layer (one tabular stream)
    - attention_mask is 1D [B, T_tab] and does not include vision tokens
    - If vision_features is None, attention is tabular-only
    - Vision tokens are not updated (no Q for vision)
    """

    def __init__(
        self,
        tabular_dim: int,
        vision_dim: int,
        num_query_heads: int,
        num_kv_heads: int,
        head_dim: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        rmsnorm_eps: float = 1e-6,
    ):
        super().__init__()
        if num_query_heads % num_kv_heads != 0:
            raise ValueError("num_query_heads must be a multiple of num_kv_heads")

        self.tabular_dim = int(tabular_dim)
        self.vision_dim = int(vision_dim)
        self.num_query_heads = int(num_query_heads)
        self.num_kv_heads = int(num_kv_heads)
        self.head_dim = int(head_dim)

        self.q_dim = self.num_query_heads * self.head_dim
        self.kv_dim = self.num_kv_heads * self.head_dim
        self.group_size = self.num_query_heads // self.num_kv_heads

        self.attn_norm = RMSNorm(self.tabular_dim, eps=rmsnorm_eps)

        # Tabular projections (shared for numeric+categorical stream)
        self.q_proj_tab = nn.Linear(self.tabular_dim, self.q_dim, bias=False)
        self.k_proj_tab = nn.Linear(self.tabular_dim, self.kv_dim, bias=False)
        self.v_proj_tab = nn.Linear(self.tabular_dim, self.kv_dim, bias=False)

        # Vision KV projections (separate; vision has no Q)
        self.k_proj_img = nn.Linear(self.vision_dim, self.kv_dim, bias=False)
        self.v_proj_img = nn.Linear(self.vision_dim, self.kv_dim, bias=False)

        self.o_proj = nn.Linear(self.q_dim, self.tabular_dim, bias=False)

        self.attn_dropout = float(dropout)
        self.resid_dropout = float(dropout)

        # FFN (LLM-style: gated MLP with SwiGLU)
        self.ffn_norm = RMSNorm(self.tabular_dim, eps=rmsnorm_eps)
        ffn_dim = int(round(self.tabular_dim * float(mlp_ratio)))
        self.gate_proj = nn.Linear(self.tabular_dim, ffn_dim, bias=False)
        self.up_proj = nn.Linear(self.tabular_dim, ffn_dim, bias=False)
        self.down_proj = nn.Linear(ffn_dim, self.tabular_dim, bias=False)
        self.act = SwiGLU()

    def init_weights(self, std: float = 0.02):
        # RMSNorm
        nn.init.ones_(self.attn_norm.weight)
        nn.init.ones_(self.ffn_norm.weight)
        # Attention projections
        nn.init.normal_(self.q_proj_tab.weight, std=std)
        nn.init.normal_(self.k_proj_tab.weight, std=std)
        nn.init.normal_(self.v_proj_tab.weight, std=std)
        nn.init.normal_(self.k_proj_img.weight, std=std)
        nn.init.normal_(self.v_proj_img.weight, std=std)
        nn.init.normal_(self.o_proj.weight, std=std)
        # FFN
        nn.init.normal_(self.gate_proj.weight, std=std)
        nn.init.normal_(self.up_proj.weight, std=std)
        nn.init.normal_(self.down_proj.weight, std=std)

    @staticmethod
    def _make_key_bias_from_mask(mask_1d: torch.Tensor, key_len: int) -> torch.Tensor:
        """
        mask_1d: [B, T_key] with 1=keep, 0=mask
        returns: [B, 1, 1, T_key] float bias with 0 for keep and -inf for mask
        """
        if mask_1d.dtype != torch.float32:
            mask_f = mask_1d.float()
        else:
            mask_f = mask_1d
        if mask_f.shape[1] != key_len:
            raise ValueError(f"mask_1d width mismatch: got {mask_f.shape[1]} expected {key_len}")
        bias = (1.0 - mask_f) * -1e9
        return bias.view(mask_f.shape[0], 1, 1, key_len)

    def _split_heads_q(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, T, Hq*d] -> [B, Hq, T, d]
        B, T, _ = x.shape
        return x.view(B, T, self.num_query_heads, self.head_dim).transpose(1, 2).contiguous()

    def _split_heads_kv(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, T, Hkv*d] -> [B, Hkv, T, d]
        B, T, _ = x.shape
        return x.view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2).contiguous()

    @staticmethod
    def _merge_heads_q(x: torch.Tensor) -> torch.Tensor:
        # x: [B, Hq, T, d] -> [B, T, Hq*d]
        B, H, T, d = x.shape
        return x.transpose(1, 2).contiguous().view(B, T, H * d)

    def forward(
        self,
        x_tab: torch.Tensor,
        attention_mask: torch.Tensor,
        vision_features: Optional[torch.Tensor] = None,
        vision_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        x_tab: [B, T_tab, tabular_dim]
        attention_mask: [B, T_tab] (1=valid tab token, 0=masked tab token). Does NOT include vision.
        vision_features: None or [B, T_img, vision_dim]
        vision_mask: None or [B, T_img] (1=valid vision token, 0=masked). Required if vision_features is not None.
        returns: updated x_tab [B, T_tab, tabular_dim]
        """
        if x_tab.dim() != 3:
            raise ValueError(f"x_tab must be [B,T,D], got {tuple(x_tab.shape)}")
        if attention_mask.dim() != 2:
            raise ValueError(f"attention_mask must be [B,T_tab], got {tuple(attention_mask.shape)}")

        B, T_tab, D = x_tab.shape
        if D != self.tabular_dim:
            raise ValueError(f"tabular_dim mismatch: got {D}, expected {self.tabular_dim}")
        if attention_mask.shape != (B, T_tab):
            raise ValueError("attention_mask shape mismatch with x_tab")
        if attention_mask.device != x_tab.device:
            attention_mask = attention_mask.to(device=x_tab.device)

        # ---- Attention block (pre-norm)
        h = self.attn_norm(x_tab)

        q_tab = self.q_proj_tab(h)  # [B, T_tab, Hq*d]
        k_tab = self.k_proj_tab(h)  # [B, T_tab, Hkv*d]
        v_tab = self.v_proj_tab(h)  # [B, T_tab, Hkv*d]

        q = self._split_heads_q(q_tab)      # [B, Hq, T_tab, d]
        k_tab = self._split_heads_kv(k_tab)  # [B, Hkv, T_tab, d]
        v_tab = self._split_heads_kv(v_tab)  # [B, Hkv, T_tab, d]

        if vision_features is None:
            # Keys/values = tab only
            k = k_tab
            v = v_tab
            key_mask = attention_mask  # [B, T_tab]
        else:
            if vision_features.dim() != 3:
                raise ValueError(f"vision_features must be [B,T_img,Dv], got {tuple(vision_features.shape)}")
            if vision_features.shape[0] != B:
                raise ValueError("vision_features batch mismatch")
            if vision_features.shape[2] != self.vision_dim:
                raise ValueError(f"vision_dim mismatch: got {vision_features.shape[2]}, expected {self.vision_dim}")

            # Require vision_mask for strict missing handling
            if vision_mask is None:
                raise ValueError("vision_mask must be provided when vision_features is not None")
            if vision_mask.dim() != 2:
                raise ValueError(f"vision_mask must be [B,T_img], got {tuple(vision_mask.shape)}")
            T_img = vision_features.shape[1]
            if vision_mask.shape != (B, T_img):
                raise ValueError(f"vision_mask shape mismatch: expected {(B, T_img)}, got {tuple(vision_mask.shape)}")

            # Ensure mask dtype matches attention_mask dtype for concatenation
            if vision_mask.dtype != attention_mask.dtype:
                vision_mask = vision_mask.to(dtype=attention_mask.dtype)
            if vision_mask.device != attention_mask.device:
                vision_mask = vision_mask.to(device=attention_mask.device)

            param = self.k_proj_img.weight
            vision_features = vision_features.to(device=param.device, dtype=param.dtype)

            k_img = self.k_proj_img(vision_features)  # [B, T_img, Hkv*d]
            v_img = self.v_proj_img(vision_features)  # [B, T_img, Hkv*d]
            k_img = self._split_heads_kv(k_img)  # [B, Hkv, T_img, d]
            v_img = self._split_heads_kv(v_img)  # [B, Hkv, T_img, d]

            k = torch.cat([k_tab, k_img], dim=2)  # [B, Hkv, T_tab+T_img, d]
            v = torch.cat([v_tab, v_img], dim=2)  # [B, Hkv, T_tab+T_img, d]

            # STRICT key mask: tab_mask + vision_mask
            key_mask = torch.cat([attention_mask, vision_mask], dim=1)  # [B, T_tab+T_img]

        # Expand KV heads to Q heads (GQA)
        if self.group_size != 1:
            k = k.repeat_interleave(self.group_size, dim=1)  # [B, Hq, T_k, d]
            v = v.repeat_interleave(self.group_size, dim=1)  # [B, Hq, T_k, d]

        T_k = k.shape[2]
        key_bias = self._make_key_bias_from_mask(key_mask, key_len=T_k)  # [B,1,1,T_k]

        # Attention scores: [B, Hq, T_tab, T_k]
        scale = 1.0 / math.sqrt(self.head_dim)
        attn_scores = torch.einsum("bhtd,bhkd->bhtk", q, k) * scale
        attn_scores = attn_scores + key_bias  # broadcast

        attn_probs = F.softmax(attn_scores.float(), dim=-1)
        if self.attn_dropout > 0.0 and self.training:
            attn_probs = F.dropout(attn_probs, p=self.attn_dropout)
        attn_probs = attn_probs.to(v.dtype)

        attn_out = torch.einsum("bhtk,bhkd->bhtd", attn_probs, v)  # [B,Hq,T_tab,d]
        attn_out = self._merge_heads_q(attn_out)  # [B,T_tab,Hq*d]
        attn_out = self.o_proj(attn_out)  # [B,T_tab,tab_dim]

        # Query-side masking (tab only): prevents masked tab tokens from updating residual path
        attn_out = attn_out * attention_mask.to(attn_out.dtype).unsqueeze(-1)

        if self.resid_dropout > 0.0 and self.training:
            attn_out = F.dropout(attn_out, p=self.resid_dropout)

        x = x_tab + attn_out

        # ---- FFN block (pre-norm)
        h2 = self.ffn_norm(x)
        gate = self.gate_proj(h2)
        up = self.up_proj(h2)
        f = self.act(gate, up)
        f = self.down_proj(f)

        # Query-side masking (tab only)
        f = f * attention_mask.to(f.dtype).unsqueeze(-1)

        if self.resid_dropout > 0.0 and self.training:
            f = F.dropout(f, p=self.resid_dropout)

        x = x + f
        return x


def _count_params(m: nn.Module) -> Tuple[int, int]:
    total = sum(p.numel() for p in m.parameters())
    trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
    return total, trainable


def _demo_main():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--t_tab", type=int, default=126)
    parser.add_argument("--t_img", type=int, default=256)
    parser.add_argument("--tabular_dim", type=int, default=768)
    parser.add_argument("--vision_dim", type=int, default=768)
    parser.add_argument("--num_query_heads", type=int, default=8)
    parser.add_argument("--num_kv_heads", type=int, default=2)
    parser.add_argument("--head_dim", type=int, default=128)
    parser.add_argument("--mlp_ratio", type=float, default=1.5)
    parser.add_argument("--dropout", type=float, default=0.0)
    parser.add_argument("--with_vision", action="store_true")
    parser.add_argument("--dtype", type=str, default="float32", choices=["float16", "bfloat16", "float32"])
    parser.add_argument("--device", type=str, default=None)
    args = parser.parse_args()

    device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
    dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
    dtype = dtype_map[args.dtype]

    layer = TabularImageGQALayer(
        tabular_dim=args.tabular_dim,
        vision_dim=args.vision_dim,
        num_query_heads=args.num_query_heads,
        num_kv_heads=args.num_kv_heads,
        head_dim=args.head_dim,
        mlp_ratio=args.mlp_ratio,
        dropout=args.dropout,
    ).to(device=device, dtype=dtype)

    total, trainable = _count_params(layer)
    print(f"Layer parameters: {total:,} (trainable: {trainable:,})")

    B = args.batch_size
    T_tab = args.t_tab
    x_tab = torch.randn(B, T_tab, args.tabular_dim, device=device, dtype=dtype)

    # Build a typical HF-style 1D attention mask: 1 for valid, 0 for masked/padded.
    # Here we create variable valid lengths.
    lengths = torch.randint(low=max(1, T_tab // 2), high=T_tab + 1, size=(B,), device=device)
    attention_mask = torch.zeros(B, T_tab, device=device, dtype=torch.long)
    for b in range(B):
        attention_mask[b, : int(lengths[b].item())] = 1

    if args.with_vision:
        vision = torch.randn(B, args.t_img, args.vision_dim, device=device, dtype=dtype)
        # Example vision mask: first half valid for sample 0, all valid for others
        vision_mask = torch.ones(B, args.t_img, device=device, dtype=torch.long)
        if args.t_img > 0:
            vision_mask[0, args.t_img // 2:] = 0
    else:
        vision = None
        vision_mask = None

    print("Input x_tab:", tuple(x_tab.shape), x_tab.dtype, x_tab.device)
    print("Input attention_mask:", tuple(attention_mask.shape), attention_mask.dtype, attention_mask.device)
    print("Input vision_features:", None if vision is None else (tuple(vision.shape), vision.dtype, vision.device))
    print("Input vision_mask:", None if vision_mask is None else (tuple(vision_mask.shape), vision_mask.dtype, vision_mask.device))

    with torch.no_grad():
        y = layer(
            x_tab=x_tab,
            attention_mask=attention_mask,
            vision_features=vision,
            vision_mask=vision_mask,
        )

    print("Output y_tab:", tuple(y.shape), y.dtype, y.device)


if __name__ == "__main__":
    _demo_main()
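

# -----------------------------------------------------------------------------
# Hedged usage sketch (appended after the demo entry point for illustration):
# one way to stack several TabularImageGQALayer blocks into a small encoder.
# `TabularImageEncoderSketch` and its final RMSNorm are assumptions made here
# for illustration; they are not part of the layer API defined above.
class TabularImageEncoderSketch(nn.Module):
    def __init__(self, num_layers: int, **layer_kwargs):
        super().__init__()
        # Each block shares the same hyperparameters; pass tabular_dim, vision_dim,
        # num_query_heads, num_kv_heads, head_dim, etc. as keyword arguments.
        self.layers = nn.ModuleList(
            [TabularImageGQALayer(**layer_kwargs) for _ in range(num_layers)]
        )
        # Final norm on the tabular stream, mirroring common pre-norm stacks.
        self.final_norm = RMSNorm(layer_kwargs["tabular_dim"])

    def forward(self, x_tab, attention_mask, vision_features=None, vision_mask=None):
        # Vision tokens act as read-only context for every block; only x_tab is updated.
        for layer in self.layers:
            x_tab = layer(x_tab, attention_mask, vision_features, vision_mask)
        return self.final_norm(x_tab)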