# layer.py
# -*- coding: utf-8 -*-
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = float(eps)
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: [..., dim]
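        # RMSNorm: y = x / sqrt(mean(x^2) + eps) * weight, computed in float32
        # for numerical stability and cast back to the input dtype.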
x_float = x.float()
rms = x_float.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
y = (x_float / rms).to(dtype=x.dtype)
return y * self.weight.to(dtype=x.dtype, device=x.device)
class SwiGLU(nn.Module):
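    """SwiGLU gating: silu(gate) * up, where `gate` and `up` are separate
    linear projections of the same input, computed by the caller."""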
    def forward(self, gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
        return F.silu(gate) * up
class TabularImageGQALayer(nn.Module):
"""
Pre-norm Transformer block with:
- Tabular tokens produce Q; tabular+image produce KV (image optional)
- GQA: num_query_heads is a multiple of num_kv_heads
- Numeric+categorical must be concatenated before calling this layer (one tabular stream)
    - attention_mask is a per-sample 1D mask of shape [B, T_tab] and does not include vision tokens
- If vision_features is None, attention is tabular-only
- Vision tokens are not updated (no Q for vision)
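
    Example (a minimal sketch; dimensions are illustrative):
        layer = TabularImageGQALayer(tabular_dim=768, vision_dim=768,
                                     num_query_heads=8, num_kv_heads=2,
                                     head_dim=64)
        x_tab = torch.randn(2, 16, 768)
        mask = torch.ones(2, 16, dtype=torch.long)
        y = layer(x_tab, mask)  # -> [2, 16, 768]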
"""
def __init__(
self,
tabular_dim: int,
vision_dim: int,
num_query_heads: int,
num_kv_heads: int,
head_dim: int,
mlp_ratio: float = 4.0,
dropout: float = 0.0,
rmsnorm_eps: float = 1e-6,
):
super().__init__()
if num_query_heads % num_kv_heads != 0:
raise ValueError("num_query_heads must be a multiple of num_kv_heads")
self.tabular_dim = int(tabular_dim)
self.vision_dim = int(vision_dim)
self.num_query_heads = int(num_query_heads)
self.num_kv_heads = int(num_kv_heads)
self.head_dim = int(head_dim)
self.q_dim = self.num_query_heads * self.head_dim
self.kv_dim = self.num_kv_heads * self.head_dim
self.group_size = self.num_query_heads // self.num_kv_heads
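        # e.g. num_query_heads=8, num_kv_heads=2 -> group_size=4: each KV head
        # is shared by 4 consecutive query heads (grouped-query attention).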
self.attn_norm = RMSNorm(self.tabular_dim, eps=rmsnorm_eps)
# Tabular projections (shared for numeric+categorical stream)
self.q_proj_tab = nn.Linear(self.tabular_dim, self.q_dim, bias=False)
self.k_proj_tab = nn.Linear(self.tabular_dim, self.kv_dim, bias=False)
self.v_proj_tab = nn.Linear(self.tabular_dim, self.kv_dim, bias=False)
# Vision KV projections (separate; vision has no Q)
self.k_proj_img = nn.Linear(self.vision_dim, self.kv_dim, bias=False)
self.v_proj_img = nn.Linear(self.vision_dim, self.kv_dim, bias=False)
self.o_proj = nn.Linear(self.q_dim, self.tabular_dim, bias=False)
self.attn_dropout = float(dropout)
self.resid_dropout = float(dropout)
# FFN (LLM-style: gated MLP with SwiGLU)
self.ffn_norm = RMSNorm(self.tabular_dim, eps=rmsnorm_eps)
ffn_dim = int(round(self.tabular_dim * float(mlp_ratio)))
self.gate_proj = nn.Linear(self.tabular_dim, ffn_dim, bias=False)
self.up_proj = nn.Linear(self.tabular_dim, ffn_dim, bias=False)
self.down_proj = nn.Linear(ffn_dim, self.tabular_dim, bias=False)
self.act = SwiGLU()
def init_weights(self, std: float = 0.02):
# RMSNorm
nn.init.ones_(self.attn_norm.weight)
nn.init.ones_(self.ffn_norm.weight)
# Attention projections
nn.init.normal_(self.q_proj_tab.weight, std=std)
nn.init.normal_(self.k_proj_tab.weight, std=std)
nn.init.normal_(self.v_proj_tab.weight, std=std)
nn.init.normal_(self.k_proj_img.weight, std=std)
nn.init.normal_(self.v_proj_img.weight, std=std)
nn.init.normal_(self.o_proj.weight, std=std)
# FFN
nn.init.normal_(self.gate_proj.weight, std=std)
nn.init.normal_(self.up_proj.weight, std=std)
nn.init.normal_(self.down_proj.weight, std=std)
@staticmethod
def _make_key_bias_from_mask(mask_1d: torch.Tensor, key_len: int) -> torch.Tensor:
"""
mask_1d: [B, T_key] with 1=keep, 0=mask
        returns: [B, 1, 1, T_key] float32 bias with 0 for keep and a large
        negative value (-1e9, effectively -inf under softmax) for mask
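
        e.g. mask_1d = [[1, 1, 0]] -> bias[0, 0, 0] = [0., 0., -1e9]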
"""
        mask_f = mask_1d.float()  # no-op if already float32
if mask_f.shape[1] != key_len:
raise ValueError(f"mask_1d width mismatch: got {mask_f.shape[1]} expected {key_len}")
bias = (1.0 - mask_f) * -1e9
return bias.view(mask_f.shape[0], 1, 1, key_len)
def _split_heads_q(self, x: torch.Tensor) -> torch.Tensor:
# x: [B, T, Hq*d] -> [B, Hq, T, d]
B, T, _ = x.shape
return x.view(B, T, self.num_query_heads, self.head_dim).transpose(1, 2).contiguous()
def _split_heads_kv(self, x: torch.Tensor) -> torch.Tensor:
# x: [B, T, Hkv*d] -> [B, Hkv, T, d]
B, T, _ = x.shape
return x.view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2).contiguous()
@staticmethod
def _merge_heads_q(x: torch.Tensor) -> torch.Tensor:
# x: [B, Hq, T, d] -> [B, T, Hq*d]
B, H, T, d = x.shape
return x.transpose(1, 2).contiguous().view(B, T, H * d)
def forward(
self,
x_tab: torch.Tensor,
attention_mask: torch.Tensor,
vision_features: Optional[torch.Tensor] = None,
vision_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
x_tab: [B, T_tab, tabular_dim]
attention_mask: [B, T_tab] (1=valid tab token, 0=masked tab token). Does NOT include vision.
vision_features: None or [B, T_img, vision_dim]
vision_mask: None or [B, T_img] (1=valid vision token, 0=masked). Required if vision_features is not None.
returns: updated x_tab [B, T_tab, tabular_dim]
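
        Example calls (a minimal sketch; `img_tokens` / `img_mask` are
        illustrative names):
            y = layer(x_tab, attention_mask)  # tabular-only attention
            y = layer(x_tab, attention_mask,
                      vision_features=img_tokens, vision_mask=img_mask)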
"""
if x_tab.dim() != 3:
raise ValueError(f"x_tab must be [B,T,D], got {tuple(x_tab.shape)}")
if attention_mask.dim() != 2:
raise ValueError(f"attention_mask must be [B,T_tab], got {tuple(attention_mask.shape)}")
B, T_tab, D = x_tab.shape
if D != self.tabular_dim:
raise ValueError(f"tabular_dim mismatch: got {D}, expected {self.tabular_dim}")
if attention_mask.shape != (B, T_tab):
raise ValueError("attention_mask shape mismatch with x_tab")
if attention_mask.device != x_tab.device:
attention_mask = attention_mask.to(device=x_tab.device)
# ---- Attention block (pre-norm)
h = self.attn_norm(x_tab)
q_tab = self.q_proj_tab(h) # [B, T_tab, Hq*d]
k_tab = self.k_proj_tab(h) # [B, T_tab, Hkv*d]
v_tab = self.v_proj_tab(h) # [B, T_tab, Hkv*d]
q = self._split_heads_q(q_tab) # [B, Hq, T_tab, d]
k_tab = self._split_heads_kv(k_tab) # [B, Hkv, T_tab, d]
v_tab = self._split_heads_kv(v_tab) # [B, Hkv, T_tab, d]
if vision_features is None:
# Keys/values = tab only
k = k_tab
v = v_tab
key_mask = attention_mask # [B, T_tab]
else:
if vision_features.dim() != 3:
raise ValueError(f"vision_features must be [B,T_img,Dv], got {tuple(vision_features.shape)}")
if vision_features.shape[0] != B:
raise ValueError("vision_features batch mismatch")
if vision_features.shape[2] != self.vision_dim:
raise ValueError(f"vision_dim mismatch: got {vision_features.shape[2]}, expected {self.vision_dim}")
            # Require vision_mask so missing/padded vision tokens are masked strictly
if vision_mask is None:
raise ValueError("vision_mask must be provided when vision_features is not None")
if vision_mask.dim() != 2:
raise ValueError(f"vision_mask must be [B,T_img], got {tuple(vision_mask.shape)}")
T_img = vision_features.shape[1]
if vision_mask.shape != (B, T_img):
raise ValueError(f"vision_mask shape mismatch: expected {(B, T_img)}, got {tuple(vision_mask.shape)}")
# Ensure mask dtype matches attention_mask dtype for concatenation
if vision_mask.dtype != attention_mask.dtype:
vision_mask = vision_mask.to(dtype=attention_mask.dtype)
if vision_mask.device != attention_mask.device:
vision_mask = vision_mask.to(device=attention_mask.device)
param = self.k_proj_img.weight
vision_features = vision_features.to(device=param.device, dtype=param.dtype)
k_img = self.k_proj_img(vision_features) # [B, T_img, Hkv*d]
v_img = self.v_proj_img(vision_features) # [B, T_img, Hkv*d]
k_img = self._split_heads_kv(k_img) # [B, Hkv, T_img, d]
v_img = self._split_heads_kv(v_img) # [B, Hkv, T_img, d]
k = torch.cat([k_tab, k_img], dim=2) # [B, Hkv, T_tab+T_img, d]
v = torch.cat([v_tab, v_img], dim=2) # [B, Hkv, T_tab+T_img, d]
            # Strict key mask: concatenate the tabular and vision masks
key_mask = torch.cat([attention_mask, vision_mask], dim=1) # [B, T_tab+T_img]
# Expand KV heads to Q heads (GQA)
if self.group_size != 1:
k = k.repeat_interleave(self.group_size, dim=1) # [B, Hq, T_k, d]
v = v.repeat_interleave(self.group_size, dim=1) # [B, Hq, T_k, d]
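        # After expansion, query head h attends to original KV head
        # h // group_size, i.e. consecutive query heads share one KV head.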
T_k = k.shape[2]
key_bias = self._make_key_bias_from_mask(key_mask, key_len=T_k) # [B,1,1,T_k]
# Attention scores: [B, Hq, T_tab, T_k]
scale = 1.0 / math.sqrt(self.head_dim)
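        # scores = (q @ k^T) / sqrt(head_dim); masked keys then receive a large
        # negative bias so softmax assigns them ~zero probability.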
attn_scores = torch.einsum("bhtd,bhkd->bhtk", q, k) * scale
attn_scores = attn_scores + key_bias # broadcast
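        # Softmax in float32 to avoid overflow/underflow in half precision.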
attn_probs = F.softmax(attn_scores.float(), dim=-1)
if self.attn_dropout > 0.0 and self.training:
attn_probs = F.dropout(attn_probs, p=self.attn_dropout)
attn_probs = attn_probs.to(v.dtype)
attn_out = torch.einsum("bhtk,bhkd->bhtd", attn_probs, v) # [B,Hq,T_tab,d]
attn_out = self._merge_heads_q(attn_out) # [B,T_tab,Hq*d]
attn_out = self.o_proj(attn_out) # [B,T_tab,tab_dim]
# Query-side masking (tab only): prevents masked tab tokens from updating residual path
attn_out = attn_out * attention_mask.to(attn_out.dtype).unsqueeze(-1)
if self.resid_dropout > 0.0 and self.training:
attn_out = F.dropout(attn_out, p=self.resid_dropout)
x = x_tab + attn_out
# ---- FFN block (pre-norm)
h2 = self.ffn_norm(x)
gate = self.gate_proj(h2)
up = self.up_proj(h2)
f = self.act(gate, up)
f = self.down_proj(f)
# Query-side masking (tab only)
f = f * attention_mask.to(f.dtype).unsqueeze(-1)
if self.resid_dropout > 0.0 and self.training:
f = F.dropout(f, p=self.resid_dropout)
x = x + f
return x
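
# A minimal sketch (hypothetical, not part of this module) of stacking several
# blocks into a deeper encoder; `num_layers` is an assumed hyperparameter:
#
#     blocks = nn.ModuleList(
#         TabularImageGQALayer(tabular_dim=768, vision_dim=768,
#                              num_query_heads=8, num_kv_heads=2, head_dim=64)
#         for _ in range(num_layers)
#     )
#     for block in blocks:
#         x_tab = block(x_tab, attention_mask, vision_features, vision_mask)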
def _count_params(m: nn.Module) -> Tuple[int, int]:
total = sum(p.numel() for p in m.parameters())
trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
return total, trainable
def _demo_main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--t_tab", type=int, default=126)
parser.add_argument("--t_img", type=int, default=256)
parser.add_argument("--tabular_dim", type=int, default=768)
parser.add_argument("--vision_dim", type=int, default=768)
parser.add_argument("--num_query_heads", type=int, default=8)
parser.add_argument("--num_kv_heads", type=int, default=2)
parser.add_argument("--head_dim", type=int, default=128)
parser.add_argument("--mlp_ratio", type=float, default=1.5)
parser.add_argument("--dropout", type=float, default=0.0)
parser.add_argument("--with_vision", action="store_true")
parser.add_argument("--dtype", type=str, default="float32", choices=["float16", "bfloat16", "float32"])
parser.add_argument("--device", type=str, default=None)
args = parser.parse_args()
device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
dtype = dtype_map[args.dtype]
layer = TabularImageGQALayer(
tabular_dim=args.tabular_dim,
vision_dim=args.vision_dim,
num_query_heads=args.num_query_heads,
num_kv_heads=args.num_kv_heads,
head_dim=args.head_dim,
mlp_ratio=args.mlp_ratio,
dropout=args.dropout,
).to(device=device, dtype=dtype)
total, trainable = _count_params(layer)
print(f"Layer parameters: {total:,} (trainable: {trainable:,})")
B = args.batch_size
T_tab = args.t_tab
x_tab = torch.randn(B, T_tab, args.tabular_dim, device=device, dtype=dtype)
    # Build a typical HF-style attention mask of shape [B, T_tab]:
    # 1 for valid, 0 for masked/padded. Here we create variable valid lengths.
lengths = torch.randint(low=max(1, T_tab // 2), high=T_tab + 1, size=(B,), device=device)
attention_mask = torch.zeros(B, T_tab, device=device, dtype=torch.long)
for b in range(B):
attention_mask[b, : int(lengths[b].item())] = 1
if args.with_vision:
vision = torch.randn(B, args.t_img, args.vision_dim, device=device, dtype=dtype)
# Example vision mask: first half valid for sample 0, all valid for others
vision_mask = torch.ones(B, args.t_img, device=device, dtype=torch.long)
if args.t_img > 0:
vision_mask[0, args.t_img // 2:] = 0
else:
vision = None
vision_mask = None
print("Input x_tab:", tuple(x_tab.shape), x_tab.dtype, x_tab.device)
print("Input attention_mask:", tuple(attention_mask.shape), attention_mask.dtype, attention_mask.device)
print("Input vision_features:", None if vision is None else (tuple(vision.shape), vision.dtype, vision.device))
print("Input vision_mask:",
None if vision_mask is None else (tuple(vision_mask.shape), vision_mask.dtype, vision_mask.device))
with torch.no_grad():
y = layer(
x_tab=x_tab,
attention_mask=attention_mask,
vision_features=vision,
vision_mask=vision_mask,
)
print("Output y_tab:", tuple(y.shape), y.dtype, y.device)
if __name__ == "__main__":
_demo_main()
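
# Example invocation (assuming this file is saved as layer.py):
#   python layer.py --with_vision --dtype bfloat16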