# embed_numeric.py
# -*- coding: utf-8 -*-
"""
Numeric embedding module for tabular transformer.
Updates in this version:
- numeric_vocab.json now includes:
- total_numeric_tokens
- group_token_offsets (by n_in)
- _demo_main prints the total parameter count
Design:
- scalar numeric (n_in=1): 1 token
- vector numeric (n_in=L): L tokens
- per bucket (same n_in): GroupedMLP with per-column weights (no for-loop over columns)
input : [B, V, n_in]
output : [B, V*n_in, H]
- middle_size:
- None: 1-layer
- int : 2-layer (Linear -> GELU -> Linear)
- NumericIdEmbedding:
- per numeric column id embedding [H]
- broadcast across that column's n_in tokens
"""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from utils import load_json, save_json, GroupedMLP
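# NOTE: GroupedMLP is provided by utils and not shown in this file. For
# orientation only, a minimal sketch of the interface this module assumes
# (hypothetical -- the real implementation may differ): per-column weights
# applied in one batched einsum, so all V columns of a bucket are embedded
# without a Python loop.
#
#     class GroupedMLP(nn.Module):
#         """Map x [B, V, n_in] -> y [B, V, n_out] with per-column weights."""
#         def __init__(self, n_var, n_in, n_out, middle_size=None):
#             super().__init__()
#             mid = n_out if middle_size is None else middle_size
#             self.w1 = nn.Parameter(torch.empty(n_var, n_in, mid))
#             self.b1 = nn.Parameter(torch.zeros(n_var, mid))
#             self.two_layer = middle_size is not None
#             if self.two_layer:  # Linear -> GELU -> Linear
#                 self.w2 = nn.Parameter(torch.empty(n_var, mid, n_out))
#                 self.b2 = nn.Parameter(torch.zeros(n_var, n_out))
#         def forward(self, x):  # x: [B, V, n_in]
#             y = torch.einsum("bvi,vio->bvo", x, self.w1) + self.b1
#             if self.two_layer:
#                 y = torch.einsum("bvm,vmo->bvo", nn.functional.gelu(y), self.w2) + self.b2
#             return y  # [B, V, n_out]
#         def init_weights(self, std=0.02):
#             nn.init.normal_(self.w1, std=std)
#             if self.two_layer:
#                 nn.init.normal_(self.w2, std=std)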
# ============================================================
# Meta parsing
# ============================================================
def infer_n_in_from_meta_item(info: Dict) -> int:
    return int(info["array_length"]) if info.get("is_array_valued", False) else 1
def get_numeric_feature_names_and_dims_from_meta(tabular_meta: Dict) -> List[Tuple[str, int]]:
"""
Return list of (feature_name, n_in) for numeric features.
Heuristic:
- info['dataclass'] == 'numeric' is treated as numeric.
"""
out: List[Tuple[str, int]] = []
for name, info in tabular_meta.items():
if info.get("dataclass") != "numeric":
continue
n_in = infer_n_in_from_meta_item(info)
out.append((name, n_in))
# deterministic: group by n_in then name
out.sort(key=lambda x: (x[1], x[0]))
return out
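# Illustrative example (hypothetical meta entries; key names match the
# accessors above):
#
#     _meta = {
#         "ph":   {"dataclass": "numeric", "is_array_valued": False},
#         "cec":  {"dataclass": "numeric", "is_array_valued": True, "array_length": 4},
#         "site": {"dataclass": "categorical"},
#     }
#     get_numeric_feature_names_and_dims_from_meta(_meta)
#     # -> [("ph", 1), ("cec", 4)]   ("site" skipped; sorted by n_in, then name)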
# ============================================================
# Vocab/spec building
# ============================================================
@dataclass
class NumColSpec:
name: str
col_id: int
n_in: int
group_index: int
index_within_group: int
def build_numeric_vocab_spec_from_meta(tabular_meta: Dict) -> Dict:
"""
Build numeric_vocab.json dict.
Output keys:
- ordered_feature_names
- features[name] = {col_id, n_in, group_index, index_within_group}
- groups = [{n_in, feature_names}, ...] sorted by n_in asc
- total_numeric_tokens
- group_token_offsets: { "<n_in>": <start_token_index> }
    Token order: groups by n_in ascending; within each group, by feature name.
"""
feats = get_numeric_feature_names_and_dims_from_meta(tabular_meta)
if not feats:
raise ValueError("No numeric features found (dataclass=='numeric').")
# group by n_in
groups_map: Dict[int, List[str]] = {}
for name, n_in in feats:
groups_map.setdefault(n_in, []).append(name)
for n_in in groups_map:
groups_map[n_in] = sorted(groups_map[n_in])
group_nins = sorted(groups_map.keys())
groups: List[Dict] = []
ordered_feature_names: List[str] = []
for n_in in group_nins:
names = groups_map[n_in]
groups.append({"n_in": int(n_in), "feature_names": names})
ordered_feature_names.extend(names)
# build per-feature mapping
name_to_group: Dict[str, Tuple[int, int]] = {}
for gi, g in enumerate(groups):
for idx, nm in enumerate(g["feature_names"]):
name_to_group[nm] = (gi, idx)
features: Dict[str, Dict] = {}
for col_id, nm in enumerate(ordered_feature_names):
gi, idx = name_to_group[nm]
n_in = int(groups[gi]["n_in"])
features[nm] = {
"col_id": int(col_id),
"n_in": int(n_in),
"group_index": int(gi),
"index_within_group": int(idx),
}
# total tokens + group token offsets
total_numeric_tokens = 0
group_token_offsets: Dict[str, int] = {}
running = 0
for g in groups:
n_in = int(g["n_in"])
group_token_offsets[str(n_in)] = int(running)
V = len(g["feature_names"])
running += V * n_in
total_numeric_tokens += V * n_in
spec = {
"ordered_feature_names": ordered_feature_names,
"features": features,
"groups": groups,
"total_numeric_tokens": int(total_numeric_tokens),
"group_token_offsets": group_token_offsets, # keys are strings to be JSON-friendly
}
return spec
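# Worked example (hypothetical features): for numeric features a (n_in=1),
# b (n_in=1), c (n_in=3), the spec comes out as
#
#     ordered_feature_names = ["a", "b", "c"]              # col_ids 0, 1, 2
#     groups = [{"n_in": 1, "feature_names": ["a", "b"]},
#               {"n_in": 3, "feature_names": ["c"]}]
#     total_numeric_tokens = 2*1 + 1*3 = 5
#     group_token_offsets  = {"1": 0, "3": 2}              # bucket start indices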
# ============================================================
# Core modules
# ============================================================
class NumericIdEmbedding(nn.Module):
"""
Per-numeric-column ID embedding in the GLOBAL numeric namespace.
Broadcast each global column id vector across its n_in tokens.
"""
def __init__(self, num_numeric_cols: int, hidden_size: int):
super().__init__()
self.num_numeric_cols = int(num_numeric_cols)
self.hidden_size = int(hidden_size)
self.emb = nn.Embedding(self.num_numeric_cols, self.hidden_size)
def forward(self, global_col_ids: torch.LongTensor, batch_size: int, n_in: int) -> torch.Tensor:
"""
global_col_ids: [V] in global numeric namespace
returns: [B, V*n_in, H]
"""
if global_col_ids.dim() != 1:
raise ValueError(f"global_col_ids must be [V], got {tuple(global_col_ids.shape)}")
V = global_col_ids.numel()
n_in = int(n_in)
id_vec = self.emb(global_col_ids) # [V, H]
id_vec = id_vec.view(1, V, 1, self.hidden_size).expand(batch_size, V, n_in, self.hidden_size)
return id_vec.reshape(batch_size, V * n_in, self.hidden_size)
def init_weights(self, std: float = 0.02):
nn.init.normal_(self.emb.weight, std=std)
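# NumericIdEmbedding token-layout example: with global_col_ids = [7, 9] (V=2)
# and n_in=3, the output rows are ordered column-major over sub-tokens,
#     [emb(7), emb(7), emb(7), emb(9), emb(9), emb(9)]     # -> [B, 6, H]
# i.e. one id vector per column, repeated across that column's n_in tokens,
# matching the value-token order produced by NumericEmbedding below.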
class NumericMaskEmbedding(nn.Module):
"""
Per-bucket numeric mask embedding.
Local to one (n_in) group / bucket.
Parameter shape:
[num_bucket_cols, n_in, H]
So missing numeric columns are represented by:
(bucket-local column index, sub-token index)
"""
def __init__(self, num_bucket_cols: int, n_in: int, hidden_size: int):
super().__init__()
self.num_bucket_cols = int(num_bucket_cols)
self.n_in = int(n_in)
self.hidden_size = int(hidden_size)
        self.emb = nn.Parameter(
            torch.empty(self.num_bucket_cols, self.n_in, self.hidden_size)
        )
        # torch.empty leaves memory uninitialized; set a sane default here so
        # the module is usable even if init_weights() is never called.
        nn.init.normal_(self.emb, std=0.02)
def forward(self, local_col_ids: torch.LongTensor, batch_size: int) -> torch.Tensor:
"""
local_col_ids: [V] bucket-local ids, usually 0 to V-1
returns: [B, V*n_in, H]
"""
if local_col_ids.dim() != 1:
raise ValueError(f"local_col_ids must be [V], got {tuple(local_col_ids.shape)}")
V = local_col_ids.numel()
mask_vec = self.emb[local_col_ids] # [V, n_in, H]
mask_vec = mask_vec.unsqueeze(0).expand(batch_size, V, self.n_in, self.hidden_size)
return mask_vec.reshape(batch_size, V * self.n_in, self.hidden_size)
def init_weights(self, std: float = 0.02):
nn.init.normal_(self.emb, std=std)
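# NumericMaskEmbedding example: with num_bucket_cols=4, n_in=2 and
# local_col_ids=[0, 1, 2, 3], forward gathers emb[[0, 1, 2, 3]] -> [4, 2, H]
# and tiles it over the batch to [B, 8, H]. Unlike NumericIdEmbedding, every
# (column, sub-token) pair has its own learned vector instead of one vector
# broadcast across sub-tokens.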
class NumericEmbedding(nn.Module):
"""
Full numeric embedding for all numeric columns described by numeric_vocab.json.
Forward expects bucketed input:
values_by_nin: { n_in: x[B, V, n_in] }
where V must match the feature count and order of that n_in group.
Output token ordering:
groups by n_in ascending (as stored in spec["groups"]),
within each group by feature_names order.
"""
def __init__(self, hidden_size: int, numeric_vocab_json: str, middle_size: Optional[int] = None):
super().__init__()
self.hidden_size = int(hidden_size)
self.middle_size = None if middle_size is None else int(middle_size)
spec = load_json(numeric_vocab_json)
self.ordered_feature_names: List[str] = list(spec["ordered_feature_names"])
self.features: Dict[str, Dict] = dict(spec["features"])
self.groups: List[Dict] = list(spec["groups"])
self.total_numeric_tokens = int(spec.get("total_numeric_tokens", -1))
num_cols = len(self.ordered_feature_names)
# Global numeric namespace id embedding
self.id_emb = NumericIdEmbedding(
num_numeric_cols=num_cols,
hidden_size=self.hidden_size,
)
# Per-group mask embedding
self.mask_emb = nn.ModuleDict()
# Per-group value embedding
self.group_mlps = nn.ModuleList()
self.group_nins: List[int] = []
self._num_groups = len(self.groups)
# Optional: useful for debugging / downstream checks
self.group_sizes: List[int] = []
# Build one block per group
for gi, g in enumerate(self.groups):
n_in = int(g["n_in"])
names = list(g["feature_names"])
V = len(names)
self.group_nins.append(n_in)
self.group_sizes.append(V)
# ---- spec consistency check
# group_index and index_within_group in features must match groups[gi]["feature_names"] order
local_ids = []
for local_idx, nm in enumerate(names):
f = self.features[nm]
if int(f["group_index"]) != gi:
raise ValueError(
f"Feature {nm} has group_index={f['group_index']}, expected {gi}"
)
if int(f["n_in"]) != n_in:
raise ValueError(
f"Feature {nm} has n_in={f['n_in']}, expected {n_in}"
)
if int(f["index_within_group"]) != local_idx:
raise ValueError(
f"Feature {nm} has index_within_group={f['index_within_group']}, expected {local_idx}"
)
local_ids.append(int(f["index_within_group"]))
# strict check: local ids must be exactly 0 to V-1 with no gap / no duplicate
if sorted(local_ids) != list(range(V)):
raise ValueError(
f"Group gi={gi}, n_in={n_in} has invalid index_within_group set: "
f"got {sorted(local_ids)}, expected {list(range(V))}"
)
# ---- observed value path: bucket-local ordering
self.group_mlps.append(
GroupedMLP(
n_var=V,
n_in=n_in,
n_out=n_in * self.hidden_size,
middle_size=self.middle_size,
)
)
# ---- global ids for NumericIdEmbedding
global_col_ids = [int(self.features[nm]["col_id"]) for nm in names]
self.register_buffer(
f"group_global_col_ids_{gi}",
torch.tensor(global_col_ids, dtype=torch.long),
persistent=True,
)
# ---- local ids for NumericMaskEmbedding
local_col_ids = [int(self.features[nm]["index_within_group"]) for nm in names]
self.register_buffer(
f"group_local_col_ids_{gi}",
torch.tensor(local_col_ids, dtype=torch.long),
persistent=True,
)
# one mask embedding per bucket
self.mask_emb[str(n_in)] = NumericMaskEmbedding(
num_bucket_cols=V,
n_in=n_in,
hidden_size=self.hidden_size,
)
if self.total_numeric_tokens < 0:
self.total_numeric_tokens = sum(
len(g["feature_names"]) * int(g["n_in"]) for g in self.groups
)
def init_weights(self, std: float = 0.02):
self.id_emb.init_weights(std=std)
for _, mask_mod in self.mask_emb.items():
mask_mod.init_weights(std=std)
for mlp in self.group_mlps:
mlp.init_weights(std=std)
def forward(
self,
values_by_nin: Dict[int, torch.Tensor],
valid_positions_by_nin: Optional[Dict[int, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
values_by_nin:
{ n_in: x } where x is [B, V, n_in]
Missing numeric values are assumed already filled in x.
valid_positions_by_nin (optional):
{ n_in: valid_cols } where valid_cols is BoolTensor [B, V]
True means this COLUMN is observed/valid.
Note:
This is COLUMN-level mask, not token-level.
It is expanded to token-level by repeating across n_in.
Returns:
tokens: [B, total_numeric_tokens, H]
token_mask: [B, total_numeric_tokens] (1=valid, 0=missing)
"""
outs = []
masks = []
batch_size = None
for gi, n_in in enumerate(self.group_nins):
if n_in not in values_by_nin:
raise KeyError(f"Missing bucket input for n_in={n_in}")
x = values_by_nin[n_in] # [B, V, n_in]
if x.dim() != 3 or x.size(-1) != n_in:
raise ValueError(f"Bucket n_in={n_in} expects x [B,V,{n_in}], got {tuple(x.shape)}")
if batch_size is None:
batch_size = x.size(0)
elif x.size(0) != batch_size:
raise ValueError("All buckets must share the same batch size")
B, V, _ = x.shape
expected_V = self.group_sizes[gi]
if V != expected_V:
raise ValueError(
f"Bucket n_in={n_in} expects V={expected_V}, got V={V}"
)
# column-level valid mask [B, V]
if valid_positions_by_nin is None:
valid_cols = torch.ones((B, V), dtype=torch.bool, device=x.device)
else:
if n_in not in valid_positions_by_nin:
raise KeyError(f"Missing valid mask for bucket n_in={n_in}")
valid_cols = valid_positions_by_nin[n_in]
if valid_cols.dtype != torch.bool:
raise ValueError(
f"valid_positions_by_nin[{n_in}] must be bool tensor, got {valid_cols.dtype}"
)
if valid_cols.shape != (B, V):
raise ValueError(
f"valid_positions_by_nin[{n_in}] must be [B,V]=[{B},{V}], got {tuple(valid_cols.shape)}"
)
valid_cols = valid_cols.to(device=x.device)
# ---- observed numeric value embedding
mlp = self.group_mlps[gi]
param = next(mlp.parameters())
x = x.to(device=param.device, dtype=param.dtype)
# [B, V, n_in] -> [B, V, n_in*H]
y = mlp(x)
# [B, V, n_in*H] -> [B, V*n_in, H]
y_tok = y.view(B, V, n_in, self.hidden_size).reshape(B, V * n_in, self.hidden_size)
# [B, V] -> [B, V*n_in]
valid_tok = valid_cols.unsqueeze(-1).expand(B, V, n_in).reshape(B, V * n_in)
            # ---- missing replacement: bucket-local mask embedding
            # (only materialized when at least one column is missing)
            if (~valid_tok).any():
                local_col_ids = getattr(self, f"group_local_col_ids_{gi}")  # [V]
                mask_tok = self.mask_emb[str(n_in)](local_col_ids, batch_size=B)
                y_tok = torch.where(
                    valid_tok.unsqueeze(-1),
                    y_tok,
                    mask_tok,
                )
# ---- add global numeric column id embedding
global_col_ids = getattr(self, f"group_global_col_ids_{gi}") # [V]
y_tok = y_tok + self.id_emb(global_col_ids, batch_size=B, n_in=n_in)
token_mask = valid_tok.to(dtype=torch.long)
outs.append(y_tok)
masks.append(token_mask)
tokens = torch.cat(outs, dim=1)
token_mask = torch.cat(masks, dim=1)
if token_mask.shape[:2] != tokens.shape[:2]:
raise RuntimeError("token_mask shape mismatch with tokens")
return tokens, token_mask
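# Mask-expansion example: for a bucket with n_in=3 and a sample whose
# valid_cols row is [True, False], the token-level mask is [1,1,1, 0,0,0];
# the three tokens of the missing column are replaced by that column's mask
# embedding before the global id embedding is added.
#
# Minimal usage sketch (assumes a numeric_vocab.json with an n_in=1 bucket of
# two columns and an n_in=3 bucket of one column, as in the worked example
# above):
#
#     emb = NumericEmbedding(hidden_size=768, numeric_vocab_json="data/numeric_vocab.json")
#     emb.init_weights()
#     tokens, token_mask = emb(
#         values_by_nin={1: torch.randn(8, 2, 1), 3: torch.randn(8, 1, 3)},
#     )
#     # tokens: [8, 5, 768]; token_mask: [8, 5], all ones since no mask was passed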
# ============================================================
# DEMO
# ============================================================
def _demo_main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--tabular_meta", type=str, default="data/tabular_meta.json")
parser.add_argument("--numeric_vocab_json", type=str, default="data/numeric_vocab.json")
parser.add_argument("--hidden_size", type=int, default=768)
parser.add_argument("--middle_size", type=int, default=-1,
help="If <0 -> one-layer. If >=0 -> two-layer with this middle size.")
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--device", type=str, default=None)
parser.add_argument("--dtype", type=str, default="float32", choices=["float16", "bfloat16", "float32"])
args = parser.parse_args()
device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
dtype = dtype_map[args.dtype]
meta = load_json(args.tabular_meta)
spec = build_numeric_vocab_spec_from_meta(meta)
save_json(spec, args.numeric_vocab_json)
print(f"Saved numeric vocab spec to: {args.numeric_vocab_json}")
print(f"Groups (n_in -> V):", {g["n_in"]: len(g["feature_names"]) for g in spec["groups"]})
print("total_numeric_tokens:", spec["total_numeric_tokens"])
print("group_token_offsets:", spec["group_token_offsets"])
middle_size = None if args.middle_size < 0 else int(args.middle_size)
model = NumericEmbedding(
hidden_size=args.hidden_size,
numeric_vocab_json=args.numeric_vocab_json,
middle_size=middle_size,
).to(device=device, dtype=dtype)
model.init_weights()
model.eval()
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters (NumericEmbedding): {total_params:,} (trainable: {trainable_params:,})")
# create demo inputs bucketed by n_in
B = args.batch_size
values_by_nin: Dict[int, torch.Tensor] = {}
valid_positions_by_nin: Dict[int, torch.Tensor] = {}
for g in spec["groups"]:
n_in = int(g["n_in"])
V = len(g["feature_names"])
# random numeric inputs
x = torch.randn(B, V, n_in, device=device, dtype=dtype)
values_by_nin[n_in] = x
# Build valid mask (column-level)
# shape: [B, V], True = valid
valid_cols = torch.ones((B, V), dtype=torch.bool, device=device)
# Mark first sample's first 2 columns as invalid
num_to_invalidate = min(2, V)
valid_cols[0, :num_to_invalidate] = False
valid_positions_by_nin[n_in] = valid_cols
with torch.no_grad():
out, mask = model(values_by_nin, valid_positions_by_nin)
print("Buckets:", {k: tuple(v.shape) for k, v in values_by_nin.items()})
print("Output tokens:", tuple(out.shape), out.dtype, out.device) # [B, total_numeric_tokens, H]
print("Masks:", tuple(mask.shape), mask.dtype, mask.device) # [B, total_numeric_tokens]
# ---- Inspect first sample
print("\nFirst sample mask (first 5 tokens):")
print(mask[0, :5])
print("\nFirst sample token L2 norms (first 5 tokens):")
print(out[0, :5].norm(dim=-1))
print("\nSecond sample mask (first 5 tokens):")
print(mask[1, :5])
print("\nSecond sample token L2 norms (first 5 tokens):")
print(out[1, :5].norm(dim=-1))
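    # ---- Cross-check token layout against group_token_offsets:
    # sample 0 had its first min(2, V) columns invalidated in every bucket, so
    # each bucket's leading min(2, V)*n_in tokens must be masked out (0).
    for g in spec["groups"]:
        n_in = int(g["n_in"])
        offset = int(spec["group_token_offsets"][str(n_in)])
        n_masked = min(2, len(g["feature_names"])) * n_in
        assert mask[0, offset:offset + n_masked].sum().item() == 0, \
            f"bucket n_in={n_in}: expected {n_masked} masked tokens at offset {offset}"
    print("\ngroup_token_offsets consistent with columns masked in sample 0.")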
if __name__ == "__main__":
_demo_main()