# embed_numeric.py
# -*- coding: utf-8 -*-
"""
Numeric embedding module for tabular transformer.

Updates in this version:
- numeric_vocab.json now includes:
    - total_numeric_tokens
    - group_token_offsets (by n_in)
- _demo_main prints total parameter count

Design:
- scalar numeric (n_in=1): 1 token
- vector numeric (n_in=L): L tokens
- per bucket (same n_in): GroupedMLP with per-column weights
  (no for-loop over columns)
    input : [B, V, n_in]
    output: [B, V*n_in, H]
- middle_size:
    - None: 1-layer
    - int : 2-layer (Linear -> GELU -> Linear)
- NumericIdEmbedding:
    - per numeric column id embedding [H]
    - broadcast across that column's n_in tokens
"""

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn

from utils import load_json, save_json, GroupedMLP


# ============================================================
# Meta parsing
# ============================================================

def infer_n_in_from_meta_item(info: Dict) -> int:
    return int(info["array_length"]) if info["is_array_valued"] else 1


def get_numeric_feature_names_and_dims_from_meta(tabular_meta: Dict) -> List[Tuple[str, int]]:
    """
    Return a list of (feature_name, n_in) for numeric features.

    Heuristic:
    - info['dataclass'] == 'numeric' is treated as numeric.
    """
    out: List[Tuple[str, int]] = []
    for name, info in tabular_meta.items():
        if info.get("dataclass") != "numeric":
            continue
        n_in = infer_n_in_from_meta_item(info)
        out.append((name, n_in))
    # deterministic: sort by n_in, then by name
    out.sort(key=lambda x: (x[1], x[0]))
    return out
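# Example (hypothetical meta dict, made up here for illustration; only the
# fields read above are meaningful):
#
#   tabular_meta = {
#       "age":    {"dataclass": "numeric", "is_array_valued": False},
#       "income": {"dataclass": "numeric", "is_array_valued": False},
#       "hist":   {"dataclass": "numeric", "is_array_valued": True, "array_length": 4},
#       "city":   {"dataclass": "categorical"},
#   }
#
#   get_numeric_feature_names_and_dims_from_meta(tabular_meta)
#   # -> [("age", 1), ("income", 1), ("hist", 4)]   (n_in asc, then name)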
# ============================================================
# Vocab/spec building
# ============================================================

@dataclass
class NumColSpec:
    """Per-feature record layout mirrored in numeric_vocab.json (informational)."""
    name: str
    col_id: int
    n_in: int
    group_index: int
    index_within_group: int


def build_numeric_vocab_spec_from_meta(tabular_meta: Dict) -> Dict:
    """
    Build the numeric_vocab.json dict.

    Output keys:
    - ordered_feature_names
    - features[name] = {col_id, n_in, group_index, index_within_group}
    - groups = [{n_in, feature_names}, ...], sorted by n_in asc
    - total_numeric_tokens
    - group_token_offsets: { "<n_in>": <token offset of that group> }

    Token order is groups by n_in asc, within a group by feature name.
    """
    feats = get_numeric_feature_names_and_dims_from_meta(tabular_meta)
    if not feats:
        raise ValueError("No numeric features found (dataclass=='numeric').")

    # group by n_in
    groups_map: Dict[int, List[str]] = {}
    for name, n_in in feats:
        groups_map.setdefault(n_in, []).append(name)
    for n_in in groups_map:
        groups_map[n_in] = sorted(groups_map[n_in])

    group_nins = sorted(groups_map.keys())
    groups: List[Dict] = []
    ordered_feature_names: List[str] = []
    for n_in in group_nins:
        names = groups_map[n_in]
        groups.append({"n_in": int(n_in), "feature_names": names})
        ordered_feature_names.extend(names)

    # build per-feature mapping
    name_to_group: Dict[str, Tuple[int, int]] = {}
    for gi, g in enumerate(groups):
        for idx, nm in enumerate(g["feature_names"]):
            name_to_group[nm] = (gi, idx)

    features: Dict[str, Dict] = {}
    for col_id, nm in enumerate(ordered_feature_names):
        gi, idx = name_to_group[nm]
        n_in = int(groups[gi]["n_in"])
        features[nm] = {
            "col_id": int(col_id),
            "n_in": int(n_in),
            "group_index": int(gi),
            "index_within_group": int(idx),
        }

    # total tokens + per-group token offsets
    total_numeric_tokens = 0
    group_token_offsets: Dict[str, int] = {}
    running = 0
    for g in groups:
        n_in = int(g["n_in"])
        group_token_offsets[str(n_in)] = int(running)
        V = len(g["feature_names"])
        running += V * n_in
        total_numeric_tokens += V * n_in

    spec = {
        "ordered_feature_names": ordered_feature_names,
        "features": features,
        "groups": groups,
        "total_numeric_tokens": int(total_numeric_tokens),
        "group_token_offsets": group_token_offsets,  # keys are strings to be JSON-friendly
    }
    return spec
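# Worked example (continuing the hypothetical meta above),
# build_numeric_vocab_spec_from_meta would produce:
#
#   groups               = [{"n_in": 1, "feature_names": ["age", "income"]},
#                           {"n_in": 4, "feature_names": ["hist"]}]
#   total_numeric_tokens = 2*1 + 1*4 = 6
#   group_token_offsets  = {"1": 0, "4": 2}
#
# i.e. the concatenated token order is: age, income, hist[0..3].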
# ============================================================
# Core modules
# ============================================================

class NumericIdEmbedding(nn.Module):
    """
    Per-numeric-column ID embedding in the GLOBAL numeric namespace.
    Broadcasts each global column id vector across its n_in tokens.
    """

    def __init__(self, num_numeric_cols: int, hidden_size: int):
        super().__init__()
        self.num_numeric_cols = int(num_numeric_cols)
        self.hidden_size = int(hidden_size)
        self.emb = nn.Embedding(self.num_numeric_cols, self.hidden_size)

    def forward(self, global_col_ids: torch.LongTensor, batch_size: int, n_in: int) -> torch.Tensor:
        """
        global_col_ids: [V] in the global numeric namespace
        returns: [B, V*n_in, H]
        """
        if global_col_ids.dim() != 1:
            raise ValueError(f"global_col_ids must be [V], got {tuple(global_col_ids.shape)}")
        V = global_col_ids.numel()
        n_in = int(n_in)
        id_vec = self.emb(global_col_ids)  # [V, H]
        id_vec = id_vec.view(1, V, 1, self.hidden_size).expand(batch_size, V, n_in, self.hidden_size)
        return id_vec.reshape(batch_size, V * n_in, self.hidden_size)

    def init_weights(self, std: float = 0.02):
        nn.init.normal_(self.emb.weight, std=std)


class NumericMaskEmbedding(nn.Module):
    """
    Per-bucket numeric mask embedding, local to one (n_in) group / bucket.

    Parameter shape: [num_bucket_cols, n_in, H], so a missing numeric
    column is represented by (bucket-local column index, sub-token index).
    """

    def __init__(self, num_bucket_cols: int, n_in: int, hidden_size: int):
        super().__init__()
        self.num_bucket_cols = int(num_bucket_cols)
        self.n_in = int(n_in)
        self.hidden_size = int(hidden_size)
        # torch.empty leaves the parameter uninitialized; call init_weights()
        # before use (the demo below does).
        self.emb = nn.Parameter(
            torch.empty(self.num_bucket_cols, self.n_in, self.hidden_size)
        )

    def forward(self, local_col_ids: torch.LongTensor, batch_size: int) -> torch.Tensor:
        """
        local_col_ids: [V] bucket-local ids, usually 0 to V-1
        returns: [B, V*n_in, H]
        """
        if local_col_ids.dim() != 1:
            raise ValueError(f"local_col_ids must be [V], got {tuple(local_col_ids.shape)}")
        V = local_col_ids.numel()
        mask_vec = self.emb[local_col_ids]  # [V, n_in, H]
        mask_vec = mask_vec.unsqueeze(0).expand(batch_size, V, self.n_in, self.hidden_size)
        return mask_vec.reshape(batch_size, V * self.n_in, self.hidden_size)

    def init_weights(self, std: float = 0.02):
        nn.init.normal_(self.emb, std=std)
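# ------------------------------------------------------------
# Illustrative sketch (NOT used by this module): GroupedMLP is imported
# from utils, and its real implementation lives there. The class below
# only documents the interface this file assumes: per-column weights
# applied in one batched einsum (no Python loop over columns), mapping
# [B, V, n_in] -> [B, V, n_out], with middle_size=None -> 1 layer and
# middle_size=int -> Linear -> GELU -> Linear. The _Sketch name is
# hypothetical.
# ------------------------------------------------------------
class _SketchGroupedMLP(nn.Module):
    def __init__(self, n_var: int, n_in: int, n_out: int, middle_size: Optional[int] = None):
        super().__init__()
        if middle_size is None:
            # one [n_in, n_out] weight matrix per column
            self.w1 = nn.Parameter(torch.empty(n_var, n_in, n_out))
            self.b1 = nn.Parameter(torch.zeros(n_var, n_out))
            self.w2 = None
        else:
            self.w1 = nn.Parameter(torch.empty(n_var, n_in, middle_size))
            self.b1 = nn.Parameter(torch.zeros(n_var, middle_size))
            self.w2 = nn.Parameter(torch.empty(n_var, middle_size, n_out))
            self.b2 = nn.Parameter(torch.zeros(n_var, n_out))
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, V, n_in]; the einsum contracts n_in per column v
        y = torch.einsum("bvi,vio->bvo", x, self.w1) + self.b1
        if self.w2 is not None:
            y = torch.einsum("bvm,vmo->bvo", self.act(y), self.w2) + self.b2
        return y  # [B, V, n_out]

    def init_weights(self, std: float = 0.02):
        nn.init.normal_(self.w1, std=std)
        if self.w2 is not None:
            nn.init.normal_(self.w2, std=std)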
class NumericEmbedding(nn.Module):
    """
    Full numeric embedding for all numeric columns described by numeric_vocab.json.

    Forward expects bucketed input:
        values_by_nin: { n_in: x[B, V, n_in] }
    where V must match the feature count and order of that n_in group.

    Output token ordering:
        groups by n_in ascending (as stored in spec["groups"]),
        within each group by feature_names order.
    """

    def __init__(self, hidden_size: int, numeric_vocab_json: str, middle_size: Optional[int] = None):
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.middle_size = None if middle_size is None else int(middle_size)

        spec = load_json(numeric_vocab_json)
        self.ordered_feature_names: List[str] = list(spec["ordered_feature_names"])
        self.features: Dict[str, Dict] = dict(spec["features"])
        self.groups: List[Dict] = list(spec["groups"])
        self.total_numeric_tokens = int(spec.get("total_numeric_tokens", -1))

        num_cols = len(self.ordered_feature_names)

        # Global numeric namespace id embedding
        self.id_emb = NumericIdEmbedding(
            num_numeric_cols=num_cols,
            hidden_size=self.hidden_size,
        )

        # Per-group mask embedding
        self.mask_emb = nn.ModuleDict()

        # Per-group value embedding
        self.group_mlps = nn.ModuleList()
        self.group_nins: List[int] = []
        self._num_groups = len(self.groups)

        # Optional: useful for debugging / downstream checks
        self.group_sizes: List[int] = []

        # Build one block per group
        for gi, g in enumerate(self.groups):
            n_in = int(g["n_in"])
            names = list(g["feature_names"])
            V = len(names)
            self.group_nins.append(n_in)
            self.group_sizes.append(V)

            # ---- spec consistency check
            # group_index and index_within_group in features must match
            # groups[gi]["feature_names"] order
            local_ids = []
            for local_idx, nm in enumerate(names):
                f = self.features[nm]
                if int(f["group_index"]) != gi:
                    raise ValueError(
                        f"Feature {nm} has group_index={f['group_index']}, expected {gi}"
                    )
                if int(f["n_in"]) != n_in:
                    raise ValueError(
                        f"Feature {nm} has n_in={f['n_in']}, expected {n_in}"
                    )
                if int(f["index_within_group"]) != local_idx:
                    raise ValueError(
                        f"Feature {nm} has index_within_group={f['index_within_group']}, expected {local_idx}"
                    )
                local_ids.append(int(f["index_within_group"]))

            # strict check: local ids must be exactly 0 to V-1 with no gap / no duplicate
            if sorted(local_ids) != list(range(V)):
                raise ValueError(
                    f"Group gi={gi}, n_in={n_in} has invalid index_within_group set: "
                    f"got {sorted(local_ids)}, expected {list(range(V))}"
                )

            # ---- observed value path: bucket-local ordering
            self.group_mlps.append(
                GroupedMLP(
                    n_var=V,
                    n_in=n_in,
                    n_out=n_in * self.hidden_size,
                    middle_size=self.middle_size,
                )
            )

            # ---- global ids for NumericIdEmbedding
            global_col_ids = [int(self.features[nm]["col_id"]) for nm in names]
            self.register_buffer(
                f"group_global_col_ids_{gi}",
                torch.tensor(global_col_ids, dtype=torch.long),
                persistent=True,
            )

            # ---- local ids for NumericMaskEmbedding
            local_col_ids = [int(self.features[nm]["index_within_group"]) for nm in names]
            self.register_buffer(
                f"group_local_col_ids_{gi}",
                torch.tensor(local_col_ids, dtype=torch.long),
                persistent=True,
            )

            # one mask embedding per bucket
            self.mask_emb[str(n_in)] = NumericMaskEmbedding(
                num_bucket_cols=V,
                n_in=n_in,
                hidden_size=self.hidden_size,
            )

        if self.total_numeric_tokens < 0:
            self.total_numeric_tokens = sum(
                len(g["feature_names"]) * int(g["n_in"]) for g in self.groups
            )

    def init_weights(self, std: float = 0.02):
        self.id_emb.init_weights(std=std)
        for _, mask_mod in self.mask_emb.items():
            mask_mod.init_weights(std=std)
        for mlp in self.group_mlps:
            mlp.init_weights(std=std)

    def forward(
        self,
        values_by_nin: Dict[int, torch.Tensor],
        valid_positions_by_nin: Optional[Dict[int, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            values_by_nin: { n_in: x } where x is [B, V, n_in]
                Missing numeric values are assumed already filled in x.
            valid_positions_by_nin (optional): { n_in: valid_cols }
                where valid_cols is BoolTensor [B, V];
                True means this COLUMN is observed/valid.
                Note: this is a COLUMN-level mask, not token-level.
                It is expanded to token-level by repeating across n_in.

        Returns:
            tokens: [B, total_numeric_tokens, H]
            token_mask: [B, total_numeric_tokens] (1=valid, 0=missing)
        """
        outs = []
        masks = []
        batch_size = None

        for gi, n_in in enumerate(self.group_nins):
            if n_in not in values_by_nin:
                raise KeyError(f"Missing bucket input for n_in={n_in}")
            x = values_by_nin[n_in]  # [B, V, n_in]
            if x.dim() != 3 or x.size(-1) != n_in:
                raise ValueError(f"Bucket n_in={n_in} expects x [B,V,{n_in}], got {tuple(x.shape)}")

            if batch_size is None:
                batch_size = x.size(0)
            elif x.size(0) != batch_size:
                raise ValueError("All buckets must share the same batch size")

            B, V, _ = x.shape
            expected_V = self.group_sizes[gi]
            if V != expected_V:
                raise ValueError(
                    f"Bucket n_in={n_in} expects V={expected_V}, got V={V}"
                )

            # column-level valid mask [B, V]
            if valid_positions_by_nin is None:
                valid_cols = torch.ones((B, V), dtype=torch.bool, device=x.device)
            else:
                if n_in not in valid_positions_by_nin:
                    raise KeyError(f"Missing valid mask for bucket n_in={n_in}")
                valid_cols = valid_positions_by_nin[n_in]
                if valid_cols.dtype != torch.bool:
                    raise ValueError(
                        f"valid_positions_by_nin[{n_in}] must be bool tensor, got {valid_cols.dtype}"
                    )
                if valid_cols.shape != (B, V):
                    raise ValueError(
                        f"valid_positions_by_nin[{n_in}] must be [B,V]=[{B},{V}], got {tuple(valid_cols.shape)}"
                    )
                valid_cols = valid_cols.to(device=x.device)

            # ---- observed numeric value embedding
            mlp = self.group_mlps[gi]
            param = next(mlp.parameters())
            x = x.to(device=param.device, dtype=param.dtype)

            # [B, V, n_in] -> [B, V, n_in*H]
            y = mlp(x)
            # [B, V, n_in*H] -> [B, V*n_in, H]
            y_tok = y.view(B, V, n_in, self.hidden_size).reshape(B, V * n_in, self.hidden_size)

            # [B, V] -> [B, V*n_in]
            valid_tok = valid_cols.unsqueeze(-1).expand(B, V, n_in).reshape(B, V * n_in)

            # ---- missing replacement: bucket-local mask embedding
            local_col_ids = getattr(self, f"group_local_col_ids_{gi}")  # [V]
            mask_tok = self.mask_emb[str(n_in)](local_col_ids, batch_size=B)
            if (~valid_tok).any():
                y_tok = torch.where(
                    valid_tok.unsqueeze(-1),
                    y_tok,
                    mask_tok,
                )

            # ---- add global numeric column id embedding
            global_col_ids = getattr(self, f"group_global_col_ids_{gi}")  # [V]
            y_tok = y_tok + self.id_emb(global_col_ids, batch_size=B, n_in=n_in)

            token_mask = valid_tok.to(dtype=torch.long)

            outs.append(y_tok)
            masks.append(token_mask)

        tokens = torch.cat(outs, dim=1)
        token_mask = torch.cat(masks, dim=1)
        if token_mask.shape[:2] != tokens.shape[:2]:
            raise RuntimeError("token_mask shape mismatch with tokens")
        return tokens, token_mask
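# ------------------------------------------------------------
# Illustrative helper (hypothetical, not called by this module): given
# the spec dict from build_numeric_vocab_spec_from_meta, locate one
# feature's token span in the concatenated [B, total_numeric_tokens, H]
# output. Within a group, the reshape [B, V, n_in, H] -> [B, V*n_in, H]
# keeps each feature's n_in tokens contiguous, so the span follows from
# group_token_offsets plus index_within_group * n_in.
# ------------------------------------------------------------
def _sketch_feature_token_span(spec: Dict, feature_name: str) -> Tuple[int, int]:
    f = spec["features"][feature_name]
    n_in = int(f["n_in"])
    start = int(spec["group_token_offsets"][str(n_in)]) + int(f["index_within_group"]) * n_in
    return start, start + n_in  # half-open [start, end)

# Continuing the worked example above: _sketch_feature_token_span(spec, "hist")
# would return (2, 6), i.e. hist's 4 tokens sit at positions 2..5.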
# ============================================================
# DEMO
# ============================================================

def _demo_main():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--tabular_meta", type=str, default="data/tabular_meta.json")
    parser.add_argument("--numeric_vocab_json", type=str, default="data/numeric_vocab.json")
    parser.add_argument("--hidden_size", type=int, default=768)
    parser.add_argument("--middle_size", type=int, default=-1,
                        help="If <0 -> one-layer. If >=0 -> two-layer with this middle size.")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--dtype", type=str, default="float32",
                        choices=["float16", "bfloat16", "float32"])
    args = parser.parse_args()

    device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
    dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
    dtype = dtype_map[args.dtype]

    meta = load_json(args.tabular_meta)
    spec = build_numeric_vocab_spec_from_meta(meta)
    save_json(spec, args.numeric_vocab_json)
    print(f"Saved numeric vocab spec to: {args.numeric_vocab_json}")
    print("Groups (n_in -> V):", {g["n_in"]: len(g["feature_names"]) for g in spec["groups"]})
    print("total_numeric_tokens:", spec["total_numeric_tokens"])
    print("group_token_offsets:", spec["group_token_offsets"])

    middle_size = None if args.middle_size < 0 else int(args.middle_size)
    model = NumericEmbedding(
        hidden_size=args.hidden_size,
        numeric_vocab_json=args.numeric_vocab_json,
        middle_size=middle_size,
    ).to(device=device, dtype=dtype)
    model.init_weights()
    model.eval()

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters (NumericEmbedding): {total_params:,} (trainable: {trainable_params:,})")

    # create demo inputs bucketed by n_in
    B = args.batch_size
    values_by_nin: Dict[int, torch.Tensor] = {}
    valid_positions_by_nin: Dict[int, torch.Tensor] = {}
    for g in spec["groups"]:
        n_in = int(g["n_in"])
        V = len(g["feature_names"])

        # random numeric inputs
        x = torch.randn(B, V, n_in, device=device, dtype=dtype)
        values_by_nin[n_in] = x

        # Build valid mask (column-level), shape [B, V], True = valid
        valid_cols = torch.ones((B, V), dtype=torch.bool, device=device)
        # Mark first sample's first 2 columns as invalid
        num_to_invalidate = min(2, V)
        valid_cols[0, :num_to_invalidate] = False
        valid_positions_by_nin[n_in] = valid_cols

    with torch.no_grad():
        out, mask = model(values_by_nin, valid_positions_by_nin)

    print("Buckets:", {k: tuple(v.shape) for k, v in values_by_nin.items()})
    print("Output tokens:", tuple(out.shape), out.dtype, out.device)  # [B, total_numeric_tokens, H]
    print("Masks:", tuple(mask.shape), mask.dtype, mask.device)  # [B, total_numeric_tokens]

    # ---- Inspect first sample
    print("\nFirst sample mask (first 5 tokens):")
    print(mask[0, :5])
    print("\nFirst sample token L2 norms (first 5 tokens):")
    print(out[0, :5].norm(dim=-1))

    print("\nSecond sample mask (first 5 tokens):")
    print(mask[1, :5])
    print("\nSecond sample token L2 norms (first 5 tokens):")
    print(out[1, :5].norm(dim=-1))
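# Example invocation (the default paths above assume a local data/
# directory with tabular_meta.json; adjust flags to your setup):
#
#   python embed_numeric.py --hidden_size 64 --middle_size 32 --dtype float32
#
# With --middle_size < 0 the per-group MLPs are single-layer; otherwise
# they use the two-layer Linear -> GELU -> Linear form described in the
# module docstring.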
tokens):") print(mask[1, :5]) print("\nSecond sample token L2 norms (first 5 tokens):") print(out[1, :5].norm(dim=-1)) if __name__ == "__main__": _demo_main()