| |
| |
|
|
| """ |
| Numeric decoder module for tabular transformer. |
| |
| Symmetric to embed_numeric.py (bucketed by n_in): |
| - For each bucket (same n_in), we decode tokens without a Python for-loop over columns. |
| - Uses a batched per-variable MLP with per-column parameters (NOT shared across V). |
| |
| Input: |
| x_tokens: [B, total_numeric_tokens, H] |
| token order must match numeric_vocab.json: |
| groups by n_in ascending, within group by feature name, |
| and within each feature: n_in tokens. |
| |
| Output: |
| values_by_nin: Dict[int, Tensor] |
| n_in -> x_hat [B, V, n_in] |
| |
| middle_size: |
| - None: 1-layer per-variable Linear |
| - int : 2-layer per-variable MLP (Linear -> GELU -> Linear) |
| """ |
|
|
| from typing import Dict, List, Optional |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from utils import GroupedMLP, load_json |
|
|
|
|
class NumericDecoder(nn.Module):
    """
    Decode numeric tokens back to per-variable numeric values, bucketed by n_in.

    Token layout must match numeric_vocab.json: groups ordered by n_in,
    features within a group, and n_in consecutive tokens per feature.

    Input:
        x_tokens: [B, total_numeric_tokens, H]

    Output:
        values_by_nin:
            n_in -> y_hat [B, V, n_in]

        s_by_nin:
            n_in -> s [B, V]
            where s = log(sigma^2), shared across the n_in dimensions
            of each variable, intended for heteroscedastic loss computation.
    """

    def __init__(
        self,
        hidden_size: int,
        numeric_vocab_json: str,
        middle_size: Optional[int] = None,
        homoscedastic: bool = True,
    ):
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.middle_size = None if middle_size is None else int(middle_size)
        self.homoscedastic = bool(homoscedastic)

        spec = load_json(numeric_vocab_json)
        self.groups: List[Dict] = list(spec["groups"])
        self.total_numeric_tokens = int(spec["total_numeric_tokens"])
        self.group_token_offsets: Dict[str, int] = dict(spec.get("group_token_offsets", {}))

        self.group_v_decoders = nn.ModuleList()
        self.group_s_decoders = nn.ModuleList()
        self.group_nins: List[int] = []
        self.group_Vs: List[int] = []

        for group in self.groups:
            width = int(group["n_in"])
            n_vars = len(list(group["feature_names"]))

            self.group_nins.append(width)
            self.group_Vs.append(n_vars)

            # Value head: consumes all width*H token features of a variable at once
            # and emits width reconstructed values per variable.
            self.group_v_decoders.append(
                GroupedMLP(
                    n_var=n_vars,
                    n_in=width * self.hidden_size,
                    n_out=width,
                    middle_size=self.middle_size,
                )
            )

            # Heteroscedastic-only head: one log-variance scalar per variable.
            if not self.homoscedastic:
                self.group_s_decoders.append(
                    GroupedMLP(
                        n_var=n_vars,
                        n_in=self.hidden_size,
                        n_out=1,
                        middle_size=self.middle_size,
                    )
                )

        if self.homoscedastic:
            # One learnable log-variance per variable, shared across the batch.
            self.group_s_params = nn.ParameterList(
                [nn.Parameter(torch.zeros(V)) for V in self.group_Vs]
            )
        else:
            self.group_s_params = None

        # Sanity-check that the JSON offsets describe one contiguous token layout.
        running = 0
        for group in self.groups:
            n_in = int(group["n_in"])
            n_vars = len(group["feature_names"])
            key = str(n_in)

            if key not in self.group_token_offsets:
                raise ValueError(f"Missing group_token_offsets entry for n_in={n_in}")
            if int(self.group_token_offsets[key]) != running:
                raise ValueError(
                    f"group_token_offsets[{key}]={self.group_token_offsets[key]} does not match expected {running}"
                )

            running += n_vars * n_in

        if running != self.total_numeric_tokens:
            raise ValueError(
                f"total_numeric_tokens={self.total_numeric_tokens} does not match expected {running}"
            )

    def init_weights(self, std: float = 0.02):
        """Initialize value heads with the given std; log-variance heads/params start at 0."""
        for mlp in self.group_v_decoders:
            mlp.init_weights(std=std)

        if self.homoscedastic:
            for param in self.group_s_params:
                nn.init.zeros_(param)
        else:
            # std=0.0 zeroes the s decoders so training starts at log(sigma^2)=0.
            for mlp in self.group_s_decoders:
                mlp.init_weights(std=0.0)

    def forward(self, x_tokens: torch.Tensor):
        """
        Args:
            x_tokens: [B, total_numeric_tokens, H] numeric-token hidden states.

        Returns:
            (values_by_nin, s_by_nin): dicts keyed by n_in holding tensors of
            shapes [B, V, n_in] and [B, V] respectively.

        Raises:
            ValueError: if x_tokens is not 3-D or its T/H dims mismatch the spec.
        """
        if x_tokens.dim() != 3:
            raise ValueError(f"x_tokens must be [B,T,H], got {tuple(x_tokens.shape)}")

        B, T, H = x_tokens.shape
        if H != self.hidden_size:
            raise ValueError(f"hidden_size mismatch: got H={H}, expected {self.hidden_size}")
        if T != self.total_numeric_tokens:
            raise ValueError(f"token length mismatch: got T={T}, expected {self.total_numeric_tokens}")

        value_out: Dict[int, torch.Tensor] = {}
        s_out: Dict[int, torch.Tensor] = {}

        for idx, n_in in enumerate(self.group_nins):
            V = self.group_Vs[idx]
            start = int(self.group_token_offsets[str(n_in)])

            # Slice this bucket's tokens and fold each variable's n_in tokens
            # into a single feature vector of size n_in*H.
            bucket = x_tokens[:, start:start + V * n_in, :]
            per_var = bucket.reshape(B, V, n_in, H)
            flat = per_var.reshape(B, V, n_in * H)

            value_out[n_in] = self.group_v_decoders[idx](flat)

            if self.homoscedastic:
                # Broadcast the learned per-variable log-variance over the batch.
                s_out[n_in] = self.group_s_params[idx].unsqueeze(0).expand(B, -1)
            else:
                # Pool over the n_in token axis before predicting log-variance.
                pooled = per_var.mean(dim=2)
                s_out[n_in] = self.group_s_decoders[idx](pooled).squeeze(-1)

        return value_out, s_out
|
|
|
|
| |
| |
| |
|
|
def _demo_main():
    """CLI smoke test: load a numeric vocab spec, build a NumericDecoder, decode random tokens."""
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--numeric_vocab_json", type=str, default="data/numeric_vocab.json")
    cli.add_argument("--hidden_size", type=int, default=768)
    cli.add_argument("--middle_size", type=int, default=-1,
                     help="If <0 -> one-layer. If >=0 -> two-layer with this middle size.")
    cli.add_argument("--batch_size", type=int, default=4)
    cli.add_argument("--device", type=str, default=None)
    cli.add_argument("--dtype", type=str, default="float32", choices=["float16", "bfloat16", "float32"])
    opts = cli.parse_args()

    # Fall back to CUDA when available, else CPU.
    if opts.device:
        device = torch.device(opts.device)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[opts.dtype]

    spec = load_json(opts.numeric_vocab_json)
    print(f"Loaded numeric vocab spec from: {opts.numeric_vocab_json}")
    print("Groups (n_in -> V):", {int(g["n_in"]): len(g["feature_names"]) for g in spec["groups"]})
    print("total_numeric_tokens:", spec["total_numeric_tokens"])
    print("group_token_offsets:", spec["group_token_offsets"])

    # Negative middle_size means a single-layer per-variable decoder.
    mid = None if opts.middle_size < 0 else int(opts.middle_size)
    model = NumericDecoder(
        hidden_size=opts.hidden_size,
        numeric_vocab_json=opts.numeric_vocab_json,
        middle_size=mid,
    ).to(device=device, dtype=dtype)
    model.eval()

    n_total = sum(p.numel() for p in model.parameters())
    n_train = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters (NumericDecoder): {n_total:,} (trainable: {n_train:,})")

    x_tokens = torch.randn(
        opts.batch_size,
        int(spec["total_numeric_tokens"]),
        opts.hidden_size,
        device=device,
        dtype=dtype,
    )

    with torch.no_grad():
        values_by_nin, s_by_nin = model(x_tokens)

    print("Input tokens:", tuple(x_tokens.shape), x_tokens.dtype, x_tokens.device)
    print("Decoded values:", {k: tuple(v.shape) for k, v in values_by_nin.items()})
    print("Decoded s:", {k: tuple(s.shape) for k, s in s_by_nin.items()})
| |
| |
|
|
|
|
| if __name__ == "__main__": |
| _demo_main() |
|
|