# decode_categorical.py
# -*- coding: utf-8 -*-
"""
Categorical decoder for tabular transformer.
Design (column-wise heads):
- Each categorical column corresponds to exactly 1 token.
- Each column has its own classifier head:
hidden_size -> num_classes[col]
Optionally with a small MLP:
hidden_size -> middle_size -> num_classes[col]
No loss is computed here; the caller applies CrossEntropyLoss
(optionally weighted by the log-variance term `s` the decoder also returns).
"""
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from utils import load_json, GroupedMLP
# ============================================================
# Small head builder
# ============================================================
def _make_head(
hidden_size: int,
num_classes: int,
middle_size: Optional[int],
bias: bool = True,
) -> nn.Module:
"""
Build a lightweight per-column classifier head.
"""
if middle_size is None:
return nn.Linear(hidden_size, num_classes, bias=bias)
return nn.Sequential(
nn.Linear(hidden_size, middle_size, bias=bias),
nn.GELU(),
nn.Linear(middle_size, num_classes, bias=bias),
)
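# Example (sketch): _make_head(768, 12, None) is a single
# nn.Linear(768, 12); _make_head(768, 12, middle_size=256) is
# Linear(768, 256) -> GELU -> Linear(256, 12).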
# ============================================================
# Decoder
# ============================================================
class CategoricalDecoder(nn.Module):
"""
Column-wise categorical decoder.
Design:
- Each categorical column corresponds to exactly one token.
- Each column has its own classifier head:
hidden_size -> num_classes[col]
Optionally with a small MLP:
hidden_size -> middle_size -> num_classes[col]
    - In addition, the decoder emits a per-column log-variance term `s`
      for uncertainty-weighted losses: a learned parameter shared across
      samples when homoscedastic=True, or a per-sample prediction from
      each token when homoscedastic=False.
Input:
x_cat_tokens: [B, M, H]
B = batch size
M = number of categorical columns (ordered by col_id)
H = hidden size
Outputs:
Case 1 (return_padded=False):
logits_list: List[Tensor] length M
logits_list[m]: [B, num_classes[m]]
        s: [B, M]
            Log-variance per sample and column:
                s[b, m] = log sigma^2_{b,m}
            (broadcast from the per-column parameter in the
            homoscedastic case; intended for loss weighting).
Case 2 (return_padded=True):
logits_padded: [B, M, Cmax]
Logits padded to the maximum class count across columns.
s: [B, M]
Same uncertainty prediction as above.
valid_mask: [M, Cmax]
True for valid class indices for each column.
"""
def __init__(
self,
hidden_size: int,
cat_vocab_json: str,
middle_size: Optional[int] = None,
bias: bool = True,
homoscedastic: bool = True,
):
super().__init__()
spec = load_json(cat_vocab_json)
items = sorted(spec.items(), key=lambda x: x[1]["col_id"])
col_ids: List[int] = []
num_classes: List[int] = []
for _, val in items:
col_ids.append(int(val["col_id"]))
num_classes.append(int(val["num_classes"]))
self.hidden_size = int(hidden_size)
self.num_cols = len(num_classes)
self.middle_size = middle_size
self.homoscedastic = bool(homoscedastic)
# Buffers for debugging / validation / optional padded output
self.register_buffer("cat_col_ids", torch.tensor(col_ids, dtype=torch.long), persistent=True) # [M]
self.register_buffer("num_classes", torch.tensor(num_classes, dtype=torch.long), persistent=True) # [M]
# Build per-column heads
heads = []
for c in num_classes:
head = _make_head(self.hidden_size, c, middle_size, bias=bias)
heads.append(head)
self.heads = nn.ModuleList(heads)
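        # Uncertainty term `s` (log-variance), two modes:
        # - homoscedastic: one learned scalar per column (s_param, [M]),
        #   shared across all samples and broadcast to [B, M] in forward.
        # - heteroscedastic: a grouped MLP (s_head) predicts s from each
        #   column's token, giving a true per-sample [B, M].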
if self.homoscedastic:
self.s_param = nn.Parameter(torch.zeros(self.num_cols))
self.s_head = None
else:
self.s_head = GroupedMLP(
n_var=self.num_cols,
n_in=self.hidden_size,
n_out=1,
middle_size=self.middle_size,
)
self.s_param = None
def init_weights(self, std: float = 0.02):
for head in self.heads:
for module in head.modules():
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, std=std)
if module.bias is not None:
nn.init.zeros_(module.bias)
        if self.homoscedastic:
            # Start at log-variance 0, i.e. sigma = 1, for every column.
            nn.init.zeros_(self.s_param)
        else:
            # std=0.0 initializes the grouped MLP weights to zero, so the
            # predicted log-variance also starts at 0, matching the
            # homoscedastic initialization.
            self.s_head.init_weights(std=0.0)
def _check_input(self, x_cat_tokens: torch.Tensor) -> Tuple[int, int, int]:
if x_cat_tokens.dim() != 3:
raise ValueError(f"x_cat_tokens must be [B,M,H], got {tuple(x_cat_tokens.shape)}")
B, M, H = x_cat_tokens.shape
if H != self.hidden_size:
raise ValueError(f"hidden_size mismatch: got {H}, expected {self.hidden_size}")
if M != self.num_cols:
raise ValueError(f"categorical token count mismatch: got M={M}, expected {self.num_cols}")
return B, M, H
@torch.no_grad()
def _build_valid_mask(self, device: torch.device) -> torch.Tensor:
"""
valid_mask[m, j] = True iff j < num_classes[m]
"""
M = self.num_cols
cmax = int(self.num_classes.max().item())
ar = torch.arange(cmax, device=device).view(1, cmax).expand(M, cmax)
nc = self.num_classes.view(M, 1).expand(M, cmax)
return ar < nc
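    # Example (sketch): with num_classes = [2, 3], Cmax = 3 and
    # _build_valid_mask returns
    #     [[True, True, False],
    #      [True, True, True ]]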
def forward(
self,
x_cat_tokens: torch.Tensor,
return_padded: bool = False,
pad_value: Optional[float] = None,
) -> Union[
Tuple[List[torch.Tensor], torch.Tensor],
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
]:
"""
Args:
x_cat_tokens: [B, M, H]
B = batch size
M = number of categorical columns
H = hidden size (per-column token embedding dim)
return_padded:
False:
return (logits_list, s)
True:
return (logits_padded, s, valid_mask)
            pad_value:
                Value used to fill invalid class positions in padded
                logits. Defaults to the most negative finite value of
                the input dtype, so padded slots never win an argmax.
Returns:
Case 1 (return_padded=False):
logits_list: List length M
logits_list[m]: [B, C_m]
s: [B, M]
s[b, m] = log sigma^2 for sample b, column m
Case 2 (return_padded=True):
logits_padded: [B, M, Cmax]
s: [B, M]
valid_mask: [M, Cmax]
"""
# --------------------------------------------------------
# 1) Basic shape validation
# --------------------------------------------------------
# Ensures x_cat_tokens is [B,M,H] and matches decoder config
B, M, _ = self._check_input(x_cat_tokens)
# --------------------------------------------------------
# 2) Per-column categorical logits
# --------------------------------------------------------
# We still use per-column heads because each column
# can have a different number of classes C_m.
#
# logits_list[m] shape: [B, C_m]
logits_list: List[torch.Tensor] = []
for m in range(M):
# x_cat_tokens[:, m, :] -> [B,H]
# heads[m] maps H -> C_m
logits_m = self.heads[m](x_cat_tokens[:, m, :])
logits_list.append(logits_m)
        # --------------------------------------------------------
        # 3) Uncertainty term (log-variance)
        # --------------------------------------------------------
        # Homoscedastic: broadcast the learned per-column parameter
        # s_param [M] to [B, M] (identical for every sample).
        # Heteroscedastic: s_head processes all columns at once
        # (grouped, no Python loop).
        #   Input:  [B, M, H]
        #   Output: [B, M]
        #
        # s[b, m] = log(sigma_{b,m}^2)
if self.homoscedastic:
s = self.s_param.unsqueeze(0).expand(B, -1)
else:
s = self.s_head(x_cat_tokens).squeeze(-1)
# --------------------------------------------------------
# 4) If no padded output requested
# --------------------------------------------------------
if not return_padded:
# Return:
# logits_list: List of length M
# s: [B,M]
return logits_list, s
# --------------------------------------------------------
# 5) Build padded logits tensor
# --------------------------------------------------------
# We unify different C_m into a common Cmax.
#
# logits_padded shape: [B,M,Cmax]
cmax = int(self.num_classes.max().item())
if pad_value is None:
pad_value = torch.finfo(x_cat_tokens.dtype).min
logits_padded = torch.full(
(B, M, cmax),
pad_value,
device=x_cat_tokens.device,
dtype=x_cat_tokens.dtype,
)
# Fill valid class positions per column
for m in range(M):
cm = logits_list[m].size(-1) # C_m
logits_padded[:, m, :cm] = logits_list[m]
# --------------------------------------------------------
# 6) Build validity mask
# --------------------------------------------------------
# valid_mask[m,j] = True if j < C_m
# = False otherwise
#
# Shape: [M, Cmax]
valid_class_mask = self._build_valid_mask(device=x_cat_tokens.device)
# --------------------------------------------------------
# 7) Return padded outputs
# --------------------------------------------------------
return logits_padded, s, valid_class_mask
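# ============================================================
# Example loss (sketch)
# ============================================================
# The decoder returns no loss; the caller applies CrossEntropyLoss.
# Below is one plausible way to combine it with the log-variance `s`
# (exp(-s) * CE + s, in the spirit of uncertainty-weighted losses).
# The exact weighting is an assumption, not part of this module.
def _example_heteroscedastic_ce(
    logits_list: List[torch.Tensor],  # logits_list[m]: [B, C_m]
    s: torch.Tensor,                  # [B, M], s[b, m] = log sigma^2
    targets: torch.Tensor,            # [B, M], integer class indices
) -> torch.Tensor:
    ce = nn.CrossEntropyLoss(reduction="none")
    losses = []
    for m, logits_m in enumerate(logits_list):
        # Cast to fp32 so the example also runs under fp16/bf16 on CPU.
        ce_m = ce(logits_m.float(), targets[:, m])   # [B]
        s_m = s[:, m].float()                        # [B]
        losses.append(torch.exp(-s_m) * ce_m + s_m)  # [B]
    return torch.stack(losses, dim=1).mean()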
# ============================================================
# DEMO
# ============================================================
def _demo_main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--cat_vocab_json", type=str, default="data/cat_vocab.json")
parser.add_argument("--hidden_size", type=int, default=768)
parser.add_argument("--middle_size", type=int, default=None)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--device", type=str, default=None)
parser.add_argument("--dtype", type=str, default="float32", choices=["float16", "bfloat16", "float32"])
args = parser.parse_args()
device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
dtype_map = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float32": torch.float32,
}
dtype = dtype_map[args.dtype]
# --------------------------------------------------------
# Load vocab spec
# --------------------------------------------------------
spec = load_json(args.cat_vocab_json)
items = sorted(spec.items(), key=lambda x_: x_[1]["col_id"])
M = len(items)
B = args.batch_size
H = args.hidden_size
    num_classes = [int(v["num_classes"]) for _, v in items]
    print("===== Categorical Columns =====")
    for i, (name, v) in enumerate(items):
        print(f"{i:03d} {name:20s} classes={v['num_classes']}")
print()
# --------------------------------------------------------
# Build model
# --------------------------------------------------------
model = CategoricalDecoder(
hidden_size=args.hidden_size,
cat_vocab_json=args.cat_vocab_json,
middle_size=args.middle_size,
).to(device=device, dtype=dtype)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model parameters: {total_params:,} (trainable: {trainable_params:,})")
print()
# --------------------------------------------------------
# Fake input tokens
# --------------------------------------------------------
x = torch.randn(B, M, H, device=device, dtype=dtype)
print("Input tokens shape:", tuple(x.shape))
print()
# --------------------------------------------------------
# Case 1: logits_list
# --------------------------------------------------------
print("===== Forward: logits_list mode =====")
with torch.no_grad():
logits_list, s = model(x, return_padded=False)
for m, (name, spec_item) in enumerate(items):
C = spec_item["num_classes"]
print(f"{m:03d} {name:20s} logits:", tuple(logits_list[m].shape), f"(expected {(B, C)})")
print("s shape:", tuple(s.shape))
print()
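    # --------------------------------------------------------
    # Example loss (sketch) on random targets
    # --------------------------------------------------------
    # Targets are purely illustrative; see _example_heteroscedastic_ce
    # above for the (assumed) weighting.
    targets = torch.stack(
        [torch.randint(0, c, (B,), device=device) for c in num_classes],
        dim=1,
    )  # [B, M]
    with torch.no_grad():
        loss = _example_heteroscedastic_ce(logits_list, s, targets)
    print("Example heteroscedastic CE loss:", float(loss))
    print()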
# --------------------------------------------------------
# Case 2: padded logits
# --------------------------------------------------------
print("===== Forward: padded mode =====")
with torch.no_grad():
logits_padded, s2, valid_mask = model(x, return_padded=True)
print("logits_padded:", tuple(logits_padded.shape))
print("s:", tuple(s2.shape))
print("valid_mask:", tuple(valid_mask.shape))
print()
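    # --------------------------------------------------------
    # Example (sketch): predictions from padded logits
    # --------------------------------------------------------
    # Invalid slots are filled with a very negative pad_value, so a
    # plain argmax already avoids them; masking with valid_mask makes
    # that explicit.
    masked = logits_padded.masked_fill(~valid_mask.unsqueeze(0), float("-inf"))
    preds = masked.argmax(dim=-1)  # [B, M]
    in_range = (preds < model.num_classes.unsqueeze(0)).all()
    print("preds shape:", tuple(preds.shape), "all within class range:", bool(in_range))
    print()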
# --------------------------------------------------------
# Visualize valid mask
# --------------------------------------------------------
print("===== Valid class mask (first 10 columns) =====")
cols_to_show = min(10, M)
for m in range(cols_to_show):
cm = num_classes[m]
valid = valid_mask[m].sum().item()
print(f"col {m:02d} num_classes={cm} valid_mask_sum={valid}")
print()
# --------------------------------------------------------
# Check padded logits correctness
# --------------------------------------------------------
print("===== Padded logits sanity check =====")
for m in range(cols_to_show):
cm = num_classes[m]
valid_region = logits_padded[:, m, :cm]
padded_region = logits_padded[:, m, cm:]
print(f"col {m:02d} valid region shape:", tuple(valid_region.shape))
if padded_region.numel() > 0:
print(f"col {m:02d} padded region mean:", padded_region.mean().item())
print()
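    # --------------------------------------------------------
    # Heteroscedastic variant (sketch)
    # --------------------------------------------------------
    # Same shapes as above, but s is predicted per sample by the
    # GroupedMLP s_head instead of broadcast from the shared s_param.
    model_het = CategoricalDecoder(
        hidden_size=args.hidden_size,
        cat_vocab_json=args.cat_vocab_json,
        middle_size=args.middle_size,
        homoscedastic=False,
    ).to(device=device, dtype=dtype)
    with torch.no_grad():
        _, s_het = model_het(x, return_padded=False)
    print("===== Heteroscedastic s =====")
    print("s shape:", tuple(s_het.shape), "(varies per sample, unlike the homoscedastic default)")
    print()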
print("Demo finished successfully.")
if __name__ == "__main__":
_demo_main()