# decode_numeric.py
# -*- coding: utf-8 -*-
"""
Numeric decoder module for tabular transformer.

Symmetric to embed_numeric.py (bucketed by n_in):
- For each bucket (same n_in), we decode tokens without a Python for-loop over columns.
- Uses a batched per-variable MLP with per-column parameters (NOT shared across V).

Input:
    x_tokens: [B, total_numeric_tokens, H]
    Token order must match numeric_vocab.json: groups by n_in ascending,
    within group by feature name, and within each feature: n_in tokens.

Output:
    values_by_nin: Dict[int, Tensor]  n_in -> x_hat [B, V, n_in]

middle_size:
- None: 1-layer per-variable Linear
- int : 2-layer per-variable MLP (Linear -> GELU -> Linear)
"""
from typing import Dict, List, Optional

import torch
import torch.nn as nn

from utils import GroupedMLP, load_json


class NumericDecoder(nn.Module):
    """
    Decode numeric tokens back to numeric values, bucketed by n_in.

    Input:
        x_tokens: [B, total_numeric_tokens, H]

    Output:
        values_by_nin: n_in -> y_hat [B, V, n_in]
        s_by_nin:      n_in -> s     [B, V]
            where s = log(sigma^2), shared across the n_in dimensions of each
            variable, intended for heteroscedastic loss computation.
    """

    def __init__(
        self,
        hidden_size: int,
        numeric_vocab_json: str,
        middle_size: Optional[int] = None,
        homoscedastic: bool = True,
    ):
        """
        Build per-group decoders from the numeric vocab spec.

        Args:
            hidden_size: token embedding width H.
            numeric_vocab_json: path to numeric_vocab.json (must contain
                "groups", "total_numeric_tokens", "group_token_offsets").
            middle_size: None for a 1-layer per-variable Linear; an int for a
                2-layer per-variable MLP (see GroupedMLP).
            homoscedastic: if True, s = log(sigma^2) is a learned per-variable
                scalar parameter; if False, s is predicted per-sample by a
                small per-variable head.

        Raises:
            ValueError: if the spec's token offsets / totals are inconsistent.
        """
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.middle_size = None if middle_size is None else int(middle_size)
        self.homoscedastic = bool(homoscedastic)

        spec = load_json(numeric_vocab_json)
        self.groups: List[Dict] = list(spec["groups"])
        self.total_numeric_tokens = int(spec["total_numeric_tokens"])
        self.group_token_offsets: Dict[str, int] = dict(spec.get("group_token_offsets", {}))

        self.group_v_decoders = nn.ModuleList()
        self.group_s_decoders = nn.ModuleList()
        self.group_nins: List[int] = []
        self.group_Vs: List[int] = []

        for g in self.groups:
            n_in = int(g["n_in"])
            names = list(g["feature_names"])
            V = len(names)
            self.group_nins.append(n_in)
            self.group_Vs.append(V)

            # value decoder: [B,V,n_in*H] -> [B,V,n_in]
            self.group_v_decoders.append(
                GroupedMLP(
                    n_var=V,
                    n_in=n_in * self.hidden_size,
                    n_out=n_in,
                    middle_size=self.middle_size,
                )
            )

            # uncertainty decoder: [B,V,H] -> [B,V,1] -> [B,V]
            if not self.homoscedastic:
                self.group_s_decoders.append(
                    GroupedMLP(
                        n_var=V,
                        n_in=self.hidden_size,
                        n_out=1,
                        middle_size=self.middle_size,
                    )
                )

        if self.homoscedastic:
            # One learned log-variance per variable, per group.
            self.group_s_params = nn.ParameterList(
                [nn.Parameter(torch.zeros(V)) for V in self.group_Vs]
            )
        else:
            self.group_s_params = None

        self._validate_spec()

    def _validate_spec(self) -> None:
        """Check that group_token_offsets and total_numeric_tokens are mutually consistent."""
        running = 0
        for gi, n_in in enumerate(self.group_nins):
            V = self.group_Vs[gi]
            key = str(n_in)
            if key not in self.group_token_offsets:
                raise ValueError(f"Missing group_token_offsets entry for n_in={n_in}")
            if int(self.group_token_offsets[key]) != running:
                raise ValueError(
                    f"group_token_offsets[{key}]={self.group_token_offsets[key]} does not match expected {running}"
                )
            running += V * n_in
        if running != self.total_numeric_tokens:
            raise ValueError(
                f"total_numeric_tokens={self.total_numeric_tokens} does not match expected {running}"
            )

    def init_weights(self, std: float = 0.02):
        """
        Initialize decoder weights.

        Value decoders use the given std; the uncertainty path starts at zero
        (zero parameters in the homoscedastic case, zero-std weights for the
        per-sample heads) so s = log(sigma^2) begins at 0, i.e. sigma = 1.
        """
        for dec in self.group_v_decoders:
            dec.init_weights(std=std)
        if self.homoscedastic:
            for p in self.group_s_params:
                nn.init.zeros_(p)
        else:
            for dec in self.group_s_decoders:
                dec.init_weights(std=0.0)

    def forward(self, x_tokens: torch.Tensor):
        """
        Decode numeric tokens into values and log-variances.

        Args:
            x_tokens: [B, total_numeric_tokens, H], ordered as produced by
                the numeric embedder (groups by n_in, contiguous per group).

        Returns:
            (value_out, s_out):
                value_out: n_in -> [B, V, n_in] predicted values
                s_out:     n_in -> [B, V] log(sigma^2)

        Raises:
            ValueError: on rank, hidden-size, or token-count mismatch.
        """
        if x_tokens.dim() != 3:
            raise ValueError(f"x_tokens must be [B,T,H], got {tuple(x_tokens.shape)}")
        B, T, H = x_tokens.shape
        if H != self.hidden_size:
            raise ValueError(f"hidden_size mismatch: got H={H}, expected {self.hidden_size}")
        if T != self.total_numeric_tokens:
            raise ValueError(f"token length mismatch: got T={T}, expected {self.total_numeric_tokens}")

        value_out: Dict[int, torch.Tensor] = {}
        s_out: Dict[int, torch.Tensor] = {}

        for gi, n_in in enumerate(self.group_nins):
            key = str(n_in)
            start = int(self.group_token_offsets[key])
            V = self.group_Vs[gi]
            length = V * n_in

            xg_tok = x_tokens[:, start:start + length, :]   # [B, V*n_in, H]
            xg_tok4 = xg_tok.reshape(B, V, n_in, H)         # [B, V, n_in, H]
            xg_flat = xg_tok4.reshape(B, V, n_in * H)       # [B, V, n_in*H]

            # values: [B, V, n_in]
            y = self.group_v_decoders[gi](xg_flat)

            # s = log sigma^2: [B, V]
            if self.homoscedastic:
                # Broadcast the per-variable parameter over the batch (view, no copy).
                s = self.group_s_params[gi].unsqueeze(0).expand(B, -1)
            else:
                # Pool the n_in tokens of each variable before the uncertainty head.
                x_var = xg_tok4.mean(dim=2)                     # [B, V, H]
                s = self.group_s_decoders[gi](x_var).squeeze(-1)  # [B, V]

            value_out[n_in] = y
            s_out[n_in] = s

        return value_out, s_out


# ============================================================
# DEMO
# ============================================================
def _demo_main():
    """Smoke-test the decoder on random tokens using an existing vocab spec."""
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--numeric_vocab_json", type=str, default="data/numeric_vocab.json")
    parser.add_argument("--hidden_size", type=int, default=768)
    parser.add_argument("--middle_size", type=int, default=-1,
                        help="If <0 -> one-layer. If >=0 -> two-layer with this middle size.")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--dtype", type=str, default="float32",
                        choices=["float16", "bfloat16", "float32"])
    args = parser.parse_args()

    device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
    dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
    dtype = dtype_map[args.dtype]

    # Directly load existing numeric vocab spec
    spec = load_json(args.numeric_vocab_json)
    print(f"Loaded numeric vocab spec from: {args.numeric_vocab_json}")
    # Plain string here: no placeholders, so an f-prefix was misleading (F541).
    print("Groups (n_in -> V):", {int(g['n_in']): len(g['feature_names']) for g in spec["groups"]})
    print("total_numeric_tokens:", spec["total_numeric_tokens"])
    print("group_token_offsets:", spec["group_token_offsets"])

    middle_size = None if args.middle_size < 0 else int(args.middle_size)

    model = NumericDecoder(
        hidden_size=args.hidden_size,
        numeric_vocab_json=args.numeric_vocab_json,
        middle_size=middle_size,
    ).to(device=device, dtype=dtype)
    model.eval()

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters (NumericDecoder): {total_params:,} (trainable: {trainable_params:,})")

    B = args.batch_size
    T = int(spec["total_numeric_tokens"])
    H = args.hidden_size
    x_tokens = torch.randn(B, T, H, device=device, dtype=dtype)

    with torch.no_grad():
        values_by_nin, s_by_nin = model(x_tokens)

    print("Input tokens:", tuple(x_tokens.shape), x_tokens.dtype, x_tokens.device)
    print("Decoded values:", {k: tuple(v.shape) for k, v in values_by_nin.items()})
    print("Decoded s:", {k: tuple(s.shape) for k, s in s_by_nin.items()})
    # values_by_nin[n_in]: [B, V, n_in]
    # s_by_nin[n_in]:      [B, V]


if __name__ == "__main__":
    _demo_main()