| """ |
| Numeric embedding module for tabular transformer. |
| |
| Updates in this version: |
| - numeric_vocab.json now includes: |
| - total_numeric_tokens |
| - group_token_offsets (by n_in) |
| - demo_main prints total parameter count |
| |
| Design: |
| - scalar numeric (n_in=1): 1 token |
| - vector numeric (n_in=L): L tokens |
| - per bucket (same n_in): GroupedMLP with per-column weights (no for-loop over columns) |
| input : [B, V, n_in] |
| output : [B, V*n_in, H] |
| - middle_size: |
| - None: 1-layer |
| - int : 2-layer (Linear -> GELU -> Linear) |
| - NumericIdEmbedding: |
| - per numeric column id embedding [H] |
| - broadcast across that column's n_in tokens |
| """ |

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn

from utils import load_json, save_json, GroupedMLP


def infer_n_in_from_meta_item(info: Dict) -> int:
    """Return the per-column input width: array_length for array-valued features, else 1."""
    return int(info["array_length"]) if info["is_array_valued"] else 1


def get_numeric_feature_names_and_dims_from_meta(tabular_meta: Dict) -> List[Tuple[str, int]]:
    """
    Return list of (feature_name, n_in) for numeric features.

    Heuristic:
    - info['dataclass'] == 'numeric' is treated as numeric.
    """
    out: List[Tuple[str, int]] = []
    for name, info in tabular_meta.items():
        if info.get("dataclass") != "numeric":
            continue
        n_in = infer_n_in_from_meta_item(info)
        out.append((name, n_in))

    # Deterministic order: by n_in, then by feature name.
    out.sort(key=lambda x: (x[1], x[0]))
    return out
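

# Illustrative only (hypothetical column names; the real tabular_meta.json may
# carry more fields per entry): given
#   {"age":    {"dataclass": "numeric", "is_array_valued": False},
#    "embed3": {"dataclass": "numeric", "is_array_valued": True, "array_length": 3},
#    "city":   {"dataclass": "categorical"}}
# this returns [("age", 1), ("embed3", 3)] -- non-numeric columns are skipped
# and the result is sorted by (n_in, name).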


@dataclass
class NumColSpec:
    """One numeric column as recorded in numeric_vocab.json's `features` map."""

    name: str
    col_id: int
    n_in: int
    group_index: int
    index_within_group: int


def build_numeric_vocab_spec_from_meta(tabular_meta: Dict) -> Dict:
    """
    Build numeric_vocab.json dict.

    Output keys:
    - ordered_feature_names
    - features[name] = {col_id, n_in, group_index, index_within_group}
    - groups = [{n_in, feature_names}, ...] sorted by n_in asc
    - total_numeric_tokens
    - group_token_offsets: { "<n_in>": <start_token_index> }
      token order is groups by n_in asc, within group by feature name
    """
    feats = get_numeric_feature_names_and_dims_from_meta(tabular_meta)
    if not feats:
        raise ValueError("No numeric features found (dataclass=='numeric').")

    # Bucket feature names by n_in.
    groups_map: Dict[int, List[str]] = {}
    for name, n_in in feats:
        groups_map.setdefault(n_in, []).append(name)

    for n_in in groups_map:
        groups_map[n_in] = sorted(groups_map[n_in])

    group_nins = sorted(groups_map.keys())

    groups: List[Dict] = []
    ordered_feature_names: List[str] = []

    for n_in in group_nins:
        names = groups_map[n_in]
        groups.append({"n_in": int(n_in), "feature_names": names})
        ordered_feature_names.extend(names)

    # Map feature name -> (group_index, index_within_group).
    name_to_group: Dict[str, Tuple[int, int]] = {}
    for gi, g in enumerate(groups):
        for idx, nm in enumerate(g["feature_names"]):
            name_to_group[nm] = (gi, idx)

    features: Dict[str, Dict] = {}
    for col_id, nm in enumerate(ordered_feature_names):
        gi, idx = name_to_group[nm]
        n_in = int(groups[gi]["n_in"])
        features[nm] = {
            "col_id": int(col_id),
            "n_in": int(n_in),
            "group_index": int(gi),
            "index_within_group": int(idx),
        }

    # Token bookkeeping: each group contributes V * n_in tokens, laid out
    # contiguously in group order; record each group's starting token index.
    total_numeric_tokens = 0
    group_token_offsets: Dict[str, int] = {}
    running = 0
    for g in groups:
        n_in = int(g["n_in"])
        group_token_offsets[str(n_in)] = int(running)
        V = len(g["feature_names"])
        running += V * n_in
        total_numeric_tokens += V * n_in

    spec = {
        "ordered_feature_names": ordered_feature_names,
        "features": features,
        "groups": groups,
        "total_numeric_tokens": int(total_numeric_tokens),
        "group_token_offsets": group_token_offsets,
    }
    return spec
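

# Worked example (hypothetical feature names, not from the real meta): for
# numeric columns {"age": n_in=1, "height": n_in=1, "embed3": n_in=3} the spec
# has groups [{n_in: 1, feature_names: ["age", "height"]},
# {n_in: 3, feature_names: ["embed3"]}], ordered_feature_names
# ["age", "height", "embed3"], total_numeric_tokens = 2*1 + 1*3 = 5, and
# group_token_offsets = {"1": 0, "3": 2}.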


class NumericIdEmbedding(nn.Module):
    """
    Per-numeric-column ID embedding in the GLOBAL numeric namespace.
    Broadcast each global column id vector across its n_in tokens.
    """

    def __init__(self, num_numeric_cols: int, hidden_size: int):
        super().__init__()
        self.num_numeric_cols = int(num_numeric_cols)
        self.hidden_size = int(hidden_size)
        self.emb = nn.Embedding(self.num_numeric_cols, self.hidden_size)

    def forward(self, global_col_ids: torch.LongTensor, batch_size: int, n_in: int) -> torch.Tensor:
        """
        global_col_ids: [V] in global numeric namespace
        returns: [B, V*n_in, H]
        """
        if global_col_ids.dim() != 1:
            raise ValueError(f"global_col_ids must be [V], got {tuple(global_col_ids.shape)}")

        V = global_col_ids.numel()
        n_in = int(n_in)

        id_vec = self.emb(global_col_ids)
        id_vec = id_vec.view(1, V, 1, self.hidden_size).expand(batch_size, V, n_in, self.hidden_size)
        return id_vec.reshape(batch_size, V * n_in, self.hidden_size)

    def init_weights(self, std: float = 0.02):
        nn.init.normal_(self.emb.weight, std=std)
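

# Minimal usage sketch (shapes only, values illustrative): with 10 numeric
# columns and H=16,
#   id_emb = NumericIdEmbedding(num_numeric_cols=10, hidden_size=16)
#   out = id_emb(torch.tensor([3, 7]), batch_size=4, n_in=5)
# out has shape [4, 2*5, 16]; tokens 0-4 all carry column 3's id vector and
# tokens 5-9 carry column 7's.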


class NumericMaskEmbedding(nn.Module):
    """
    Per-bucket numeric mask embedding.
    Local to one (n_in) group / bucket.

    Parameter shape:
        [num_bucket_cols, n_in, H]

    So missing numeric columns are represented by:
        (bucket-local column index, sub-token index)
    """

    def __init__(self, num_bucket_cols: int, n_in: int, hidden_size: int):
        super().__init__()
        self.num_bucket_cols = int(num_bucket_cols)
        self.n_in = int(n_in)
        self.hidden_size = int(hidden_size)

        self.emb = nn.Parameter(
            torch.empty(self.num_bucket_cols, self.n_in, self.hidden_size)
        )

    def forward(self, local_col_ids: torch.LongTensor, batch_size: int) -> torch.Tensor:
        """
        local_col_ids: [V] bucket-local ids, usually 0 to V-1
        returns: [B, V*n_in, H]
        """
        if local_col_ids.dim() != 1:
            raise ValueError(f"local_col_ids must be [V], got {tuple(local_col_ids.shape)}")

        V = local_col_ids.numel()
        mask_vec = self.emb[local_col_ids]
        mask_vec = mask_vec.unsqueeze(0).expand(batch_size, V, self.n_in, self.hidden_size)
        return mask_vec.reshape(batch_size, V * self.n_in, self.hidden_size)

    def init_weights(self, std: float = 0.02):
        nn.init.normal_(self.emb, std=std)
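

# Note: NumericMaskEmbedding.emb is allocated with torch.empty, so its values
# are undefined until init_weights() is called (or a checkpoint overwrites
# them); _demo_main below calls model.init_weights() before the forward pass.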


class NumericEmbedding(nn.Module):
    """
    Full numeric embedding for all numeric columns described by numeric_vocab.json.

    Forward expects bucketed input:
        values_by_nin: { n_in: x[B, V, n_in] }
        where V must match the feature count and order of that n_in group.

    Output token ordering:
        groups by n_in ascending (as stored in spec["groups"]),
        within each group by feature_names order.
    """

    def __init__(self, hidden_size: int, numeric_vocab_json: str, middle_size: Optional[int] = None):
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.middle_size = None if middle_size is None else int(middle_size)

        spec = load_json(numeric_vocab_json)
        self.ordered_feature_names: List[str] = list(spec["ordered_feature_names"])
        self.features: Dict[str, Dict] = dict(spec["features"])
        self.groups: List[Dict] = list(spec["groups"])
        self.total_numeric_tokens = int(spec.get("total_numeric_tokens", -1))

        num_cols = len(self.ordered_feature_names)

        # One id embedding over the global numeric column namespace.
        self.id_emb = NumericIdEmbedding(
            num_numeric_cols=num_cols,
            hidden_size=self.hidden_size,
        )

        # Per-bucket mask embeddings, keyed by str(n_in).
        self.mask_emb = nn.ModuleDict()

        # Per-bucket GroupedMLPs, one per n_in group.
        self.group_mlps = nn.ModuleList()

        self.group_nins: List[int] = []
        self._num_groups = len(self.groups)

        # Number of columns (V) in each group, in group order.
        self.group_sizes: List[int] = []

        for gi, g in enumerate(self.groups):
            n_in = int(g["n_in"])
            names = list(g["feature_names"])
            V = len(names)

            self.group_nins.append(n_in)
            self.group_sizes.append(V)

            # Validate that the per-feature records agree with the group layout.
            local_ids = []
            for local_idx, nm in enumerate(names):
                f = self.features[nm]

                if int(f["group_index"]) != gi:
                    raise ValueError(
                        f"Feature {nm} has group_index={f['group_index']}, expected {gi}"
                    )
                if int(f["n_in"]) != n_in:
                    raise ValueError(
                        f"Feature {nm} has n_in={f['n_in']}, expected {n_in}"
                    )
                if int(f["index_within_group"]) != local_idx:
                    raise ValueError(
                        f"Feature {nm} has index_within_group={f['index_within_group']}, expected {local_idx}"
                    )

                local_ids.append(int(f["index_within_group"]))

            if sorted(local_ids) != list(range(V)):
                raise ValueError(
                    f"Group gi={gi}, n_in={n_in} has invalid index_within_group set: "
                    f"got {sorted(local_ids)}, expected {list(range(V))}"
                )

            # Per-column MLP for this bucket: [B, V, n_in] -> [B, V, n_in*H].
            self.group_mlps.append(
                GroupedMLP(
                    n_var=V,
                    n_in=n_in,
                    n_out=n_in * self.hidden_size,
                    middle_size=self.middle_size,
                )
            )

            # Global column ids for this group's features (used by the id embedding).
            global_col_ids = [int(self.features[nm]["col_id"]) for nm in names]
            self.register_buffer(
                f"group_global_col_ids_{gi}",
                torch.tensor(global_col_ids, dtype=torch.long),
                persistent=True,
            )

            # Bucket-local column ids (used by the mask embedding).
            local_col_ids = [int(self.features[nm]["index_within_group"]) for nm in names]
            self.register_buffer(
                f"group_local_col_ids_{gi}",
                torch.tensor(local_col_ids, dtype=torch.long),
                persistent=True,
            )

            # Mask embedding for missing columns in this bucket.
            self.mask_emb[str(n_in)] = NumericMaskEmbedding(
                num_bucket_cols=V,
                n_in=n_in,
                hidden_size=self.hidden_size,
            )

        # Older vocab files may not store total_numeric_tokens; recompute it.
        if self.total_numeric_tokens < 0:
            self.total_numeric_tokens = sum(
                len(g["feature_names"]) * int(g["n_in"]) for g in self.groups
            )

    def init_weights(self, std: float = 0.02):
        self.id_emb.init_weights(std=std)

        for _, mask_mod in self.mask_emb.items():
            mask_mod.init_weights(std=std)

        for mlp in self.group_mlps:
            mlp.init_weights(std=std)

    def forward(
        self,
        values_by_nin: Dict[int, torch.Tensor],
        valid_positions_by_nin: Optional[Dict[int, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            values_by_nin:
                { n_in: x } where x is [B, V, n_in]
                Missing numeric values are assumed already filled in x.

            valid_positions_by_nin (optional):
                { n_in: valid_cols } where valid_cols is BoolTensor [B, V]
                True means this COLUMN is observed/valid.

                Note:
                    This is a COLUMN-level mask, not token-level.
                    It is expanded to token level by repeating across n_in.

        Returns:
            tokens: [B, total_numeric_tokens, H]
            token_mask: [B, total_numeric_tokens] (1=valid, 0=missing)
        """
        outs = []
        masks = []
        batch_size = None

        for gi, n_in in enumerate(self.group_nins):
            if n_in not in values_by_nin:
                raise KeyError(f"Missing bucket input for n_in={n_in}")

            x = values_by_nin[n_in]
            if x.dim() != 3 or x.size(-1) != n_in:
                raise ValueError(f"Bucket n_in={n_in} expects x [B,V,{n_in}], got {tuple(x.shape)}")

            if batch_size is None:
                batch_size = x.size(0)
            elif x.size(0) != batch_size:
                raise ValueError("All buckets must share the same batch size")

            B, V, _ = x.shape

            expected_V = self.group_sizes[gi]
            if V != expected_V:
                raise ValueError(
                    f"Bucket n_in={n_in} expects V={expected_V}, got V={V}"
                )

            # Column-level validity mask: default to all-observed.
            if valid_positions_by_nin is None:
                valid_cols = torch.ones((B, V), dtype=torch.bool, device=x.device)
            else:
                if n_in not in valid_positions_by_nin:
                    raise KeyError(f"Missing valid mask for bucket n_in={n_in}")

                valid_cols = valid_positions_by_nin[n_in]
                if valid_cols.dtype != torch.bool:
                    raise ValueError(
                        f"valid_positions_by_nin[{n_in}] must be bool tensor, got {valid_cols.dtype}"
                    )
                if valid_cols.shape != (B, V):
                    raise ValueError(
                        f"valid_positions_by_nin[{n_in}] must be [B,V]=[{B},{V}], got {tuple(valid_cols.shape)}"
                    )
                valid_cols = valid_cols.to(device=x.device)

            # Match the bucket MLP's device/dtype before projecting.
            mlp = self.group_mlps[gi]
            param = next(mlp.parameters())
            x = x.to(device=param.device, dtype=param.dtype)

            # [B, V, n_in] -> [B, V, n_in*H]
            y = mlp(x)

            # [B, V, n_in*H] -> [B, V*n_in, H]
            y_tok = y.view(B, V, n_in, self.hidden_size).reshape(B, V * n_in, self.hidden_size)

            # Expand the column-level mask to token level.
            valid_tok = valid_cols.unsqueeze(-1).expand(B, V, n_in).reshape(B, V * n_in)

            # Mask embedding for this bucket's columns.
            local_col_ids = getattr(self, f"group_local_col_ids_{gi}")
            mask_tok = self.mask_emb[str(n_in)](local_col_ids, batch_size=B)

            # Replace tokens of missing columns with the learned mask embedding.
            if (~valid_tok).any():
                y_tok = torch.where(
                    valid_tok.unsqueeze(-1),
                    y_tok,
                    mask_tok,
                )

            # Add the per-column id embedding (broadcast across n_in tokens).
            global_col_ids = getattr(self, f"group_global_col_ids_{gi}")
            y_tok = y_tok + self.id_emb(global_col_ids, batch_size=B, n_in=n_in)

            token_mask = valid_tok.to(dtype=torch.long)

            outs.append(y_tok)
            masks.append(token_mask)

        tokens = torch.cat(outs, dim=1)
        token_mask = torch.cat(masks, dim=1)

        if token_mask.shape[:2] != tokens.shape[:2]:
            raise RuntimeError("token_mask shape mismatch with tokens")

        return tokens, token_mask


def _demo_main():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--tabular_meta", type=str, default="data/tabular_meta.json")
    parser.add_argument("--numeric_vocab_json", type=str, default="data/numeric_vocab.json")
    parser.add_argument("--hidden_size", type=int, default=768)
    parser.add_argument("--middle_size", type=int, default=-1,
                        help="If <0 -> one-layer. If >=0 -> two-layer with this middle size.")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--dtype", type=str, default="float32", choices=["float16", "bfloat16", "float32"])
    args = parser.parse_args()

    device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
    dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
    dtype = dtype_map[args.dtype]

    meta = load_json(args.tabular_meta)

    # Build and persist the numeric vocab spec.
    spec = build_numeric_vocab_spec_from_meta(meta)
    save_json(spec, args.numeric_vocab_json)
    print(f"Saved numeric vocab spec to: {args.numeric_vocab_json}")
    print("Groups (n_in -> V):", {g["n_in"]: len(g["feature_names"]) for g in spec["groups"]})
    print("total_numeric_tokens:", spec["total_numeric_tokens"])
    print("group_token_offsets:", spec["group_token_offsets"])

    middle_size = None if args.middle_size < 0 else int(args.middle_size)
    model = NumericEmbedding(
        hidden_size=args.hidden_size,
        numeric_vocab_json=args.numeric_vocab_json,
        middle_size=middle_size,
    ).to(device=device, dtype=dtype)
    model.init_weights()
    model.eval()

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters (NumericEmbedding): {total_params:,} (trainable: {trainable_params:,})")

    # Build random bucketed inputs plus a column-level validity mask per bucket.
    B = args.batch_size
    values_by_nin: Dict[int, torch.Tensor] = {}
    valid_positions_by_nin: Dict[int, torch.Tensor] = {}

    for g in spec["groups"]:
        n_in = int(g["n_in"])
        V = len(g["feature_names"])

        # Random values for this bucket: [B, V, n_in].
        x = torch.randn(B, V, n_in, device=device, dtype=dtype)
        values_by_nin[n_in] = x

        # All columns observed by default.
        valid_cols = torch.ones((B, V), dtype=torch.bool, device=device)

        # Mark the first couple of columns of sample 0 as missing.
        num_to_invalidate = min(2, V)
        valid_cols[0, :num_to_invalidate] = False

        valid_positions_by_nin[n_in] = valid_cols

    with torch.no_grad():
        out, mask = model(values_by_nin, valid_positions_by_nin)

    print("Buckets:", {k: tuple(v.shape) for k, v in values_by_nin.items()})
    print("Output tokens:", tuple(out.shape), out.dtype, out.device)
    print("Masks:", tuple(mask.shape), mask.dtype, mask.device)

    # Sample 0 had its first columns masked; sample 1 is fully observed.
    print("\nFirst sample mask (first 5 tokens):")
    print(mask[0, :5])

    print("\nFirst sample token L2 norms (first 5 tokens):")
    print(out[0, :5].norm(dim=-1))

    print("\nSecond sample mask (first 5 tokens):")
    print(mask[1, :5])

    print("\nSecond sample token L2 norms (first 5 tokens):")
    print(out[1, :5].norm(dim=-1))


if __name__ == "__main__":
    _demo_main()