# decode_categorical.py # -*- coding: utf-8 -*- """ Categorical decoder for tabular transformer. Design (column-wise heads): - Each categorical column corresponds to exactly 1 token. - Each column has its own classifier head: hidden_size -> num_classes[col] Optionally with a small MLP: hidden_size -> middle_size -> num_classes[col] No loss is included here (caller will apply CrossEntropyLoss). """ from typing import List, Optional, Tuple, Union import torch import torch.nn as nn from utils import load_json, GroupedMLP # ============================================================ # Small head builder # ============================================================ def _make_head( hidden_size: int, num_classes: int, middle_size: Optional[int], bias: bool = True, ) -> nn.Module: """ Build a lightweight per-column classifier head. """ if middle_size is None: return nn.Linear(hidden_size, num_classes, bias=bias) return nn.Sequential( nn.Linear(hidden_size, middle_size, bias=bias), nn.GELU(), nn.Linear(middle_size, num_classes, bias=bias), ) # ============================================================ # Decoder # ============================================================ class CategoricalDecoder(nn.Module): """ Column-wise categorical decoder. Design: - Each categorical column corresponds to exactly one token. - Each column has its own classifier head: hidden_size -> num_classes[col] Optionally with a small MLP: hidden_size -> middle_size -> num_classes[col] - In addition, the decoder predicts a per-sample, per-column log-variance term `s` used for heteroscedastic loss weighting. Input: x_cat_tokens: [B, M, H] B = batch size M = number of categorical columns (ordered by col_id) H = hidden size Outputs: Case 1 (return_padded=False): logits_list: List[Tensor] length M logits_list[m]: [B, num_classes[m]] s: [B, M] Predicted log-variance per sample and column: s[b, m] = log sigma^2_{b,m} Intended for heteroscedastic loss weighting. Case 2 (return_padded=True): logits_padded: [B, M, Cmax] Logits padded to the maximum class count across columns. s: [B, M] Same uncertainty prediction as above. valid_mask: [M, Cmax] True for valid class indices for each column. """ def __init__( self, hidden_size: int, cat_vocab_json: str, middle_size: Optional[int] = None, bias: bool = True, homoscedastic: bool = True, ): super().__init__() spec = load_json(cat_vocab_json) items = sorted(spec.items(), key=lambda x: x[1]["col_id"]) col_ids: List[int] = [] num_classes: List[int] = [] for _, val in items: col_ids.append(int(val["col_id"])) num_classes.append(int(val["num_classes"])) self.hidden_size = int(hidden_size) self.num_cols = len(num_classes) self.middle_size = middle_size self.homoscedastic = bool(homoscedastic) # Buffers for debugging / validation / optional padded output self.register_buffer("cat_col_ids", torch.tensor(col_ids, dtype=torch.long), persistent=True) # [M] self.register_buffer("num_classes", torch.tensor(num_classes, dtype=torch.long), persistent=True) # [M] # Build per-column heads heads = [] for c in num_classes: head = _make_head(self.hidden_size, c, middle_size, bias=bias) heads.append(head) self.heads = nn.ModuleList(heads) if self.homoscedastic: self.s_param = nn.Parameter(torch.zeros(self.num_cols)) self.s_head = None else: self.s_head = GroupedMLP( n_var=self.num_cols, n_in=self.hidden_size, n_out=1, middle_size=self.middle_size, ) self.s_param = None def init_weights(self, std: float = 0.02): for head in self.heads: for module in head.modules(): if isinstance(module, nn.Linear): nn.init.normal_(module.weight, std=std) if module.bias is not None: nn.init.zeros_(module.bias) if self.homoscedastic: nn.init.zeros_(self.s_param) else: self.s_head.init_weights(std=0.0) def _check_input(self, x_cat_tokens: torch.Tensor) -> Tuple[int, int, int]: if x_cat_tokens.dim() != 3: raise ValueError(f"x_cat_tokens must be [B,M,H], got {tuple(x_cat_tokens.shape)}") B, M, H = x_cat_tokens.shape if H != self.hidden_size: raise ValueError(f"hidden_size mismatch: got {H}, expected {self.hidden_size}") if M != self.num_cols: raise ValueError(f"categorical token count mismatch: got M={M}, expected {self.num_cols}") return B, M, H @torch.no_grad() def _build_valid_mask(self, device: torch.device) -> torch.Tensor: """ valid_mask[m, j] = True iff j < num_classes[m] """ M = self.num_cols cmax = int(self.num_classes.max().item()) ar = torch.arange(cmax, device=device).view(1, cmax).expand(M, cmax) nc = self.num_classes.view(M, 1).expand(M, cmax) return ar < nc def forward( self, x_cat_tokens: torch.Tensor, return_padded: bool = False, pad_value: Optional[float] = None, ) -> Union[ Tuple[List[torch.Tensor], torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor] ]: """ Args: x_cat_tokens: [B, M, H] B = batch size M = number of categorical columns H = hidden size (per-column token embedding dim) return_padded: False: return (logits_list, s) True: return (logits_padded, s, valid_mask) pad_value: Value used to fill invalid class positions in padded logits. Returns: Case 1 (return_padded=False): logits_list: List length M logits_list[m]: [B, C_m] s: [B, M] s[b, m] = log sigma^2 for sample b, column m Case 2 (return_padded=True): logits_padded: [B, M, Cmax] s: [B, M] valid_mask: [M, Cmax] """ # -------------------------------------------------------- # 1) Basic shape validation # -------------------------------------------------------- # Ensures x_cat_tokens is [B,M,H] and matches decoder config B, M, _ = self._check_input(x_cat_tokens) # -------------------------------------------------------- # 2) Per-column categorical logits # -------------------------------------------------------- # We still use per-column heads because each column # can have a different number of classes C_m. # # logits_list[m] shape: [B, C_m] logits_list: List[torch.Tensor] = [] for m in range(M): # x_cat_tokens[:, m, :] -> [B,H] # heads[m] maps H -> C_m logits_m = self.heads[m](x_cat_tokens[:, m, :]) logits_list.append(logits_m) # -------------------------------------------------------- # 3) Sample-wise & column-wise uncertainty (log-variance) # -------------------------------------------------------- # s_head processes all columns at once (grouped, no loop) # # Input: [B,M,H] # Output: [B,M] # # s[b,m] = log(sigma_{b,m}^2) if self.homoscedastic: s = self.s_param.unsqueeze(0).expand(B, -1) else: s = self.s_head(x_cat_tokens).squeeze(-1) # -------------------------------------------------------- # 4) If no padded output requested # -------------------------------------------------------- if not return_padded: # Return: # logits_list: List of length M # s: [B,M] return logits_list, s # -------------------------------------------------------- # 5) Build padded logits tensor # -------------------------------------------------------- # We unify different C_m into a common Cmax. # # logits_padded shape: [B,M,Cmax] cmax = int(self.num_classes.max().item()) if pad_value is None: pad_value = torch.finfo(x_cat_tokens.dtype).min logits_padded = torch.full( (B, M, cmax), pad_value, device=x_cat_tokens.device, dtype=x_cat_tokens.dtype, ) # Fill valid class positions per column for m in range(M): cm = logits_list[m].size(-1) # C_m logits_padded[:, m, :cm] = logits_list[m] # -------------------------------------------------------- # 6) Build validity mask # -------------------------------------------------------- # valid_mask[m,j] = True if j < C_m # = False otherwise # # Shape: [M, Cmax] valid_class_mask = self._build_valid_mask(device=x_cat_tokens.device) # -------------------------------------------------------- # 7) Return padded outputs # -------------------------------------------------------- return logits_padded, s, valid_class_mask # ============================================================ # DEMO # ============================================================ def _demo_main(): import argparse parser = argparse.ArgumentParser() parser.add_argument("--cat_vocab_json", type=str, default="data/cat_vocab.json") parser.add_argument("--hidden_size", type=int, default=768) parser.add_argument("--middle_size", type=int, default=None) parser.add_argument("--batch_size", type=int, default=4) parser.add_argument("--device", type=str, default=None) parser.add_argument("--dtype", type=str, default="float32", choices=["float16", "bfloat16", "float32"]) args = parser.parse_args() device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu")) dtype_map = { "float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32, } dtype = dtype_map[args.dtype] # -------------------------------------------------------- # Load vocab spec # -------------------------------------------------------- spec = load_json(args.cat_vocab_json) items = sorted(spec.items(), key=lambda x_: x_[1]["col_id"]) M = len(items) B = args.batch_size H = args.hidden_size num_classes = [int(s["num_classes"]) for _, s in items] print("===== Categorical Columns =====") for i, (name, s) in enumerate(items): print(f"{i:03d} {name:20s} classes={s['num_classes']}") print() # -------------------------------------------------------- # Build model # -------------------------------------------------------- model = CategoricalDecoder( hidden_size=args.hidden_size, cat_vocab_json=args.cat_vocab_json, middle_size=args.middle_size, ).to(device=device, dtype=dtype) total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f"Model parameters: {total_params:,} (trainable: {trainable_params:,})") print() # -------------------------------------------------------- # Fake input tokens # -------------------------------------------------------- x = torch.randn(B, M, H, device=device, dtype=dtype) print("Input tokens shape:", tuple(x.shape)) print() # -------------------------------------------------------- # Case 1: logits_list # -------------------------------------------------------- print("===== Forward: logits_list mode =====") with torch.no_grad(): logits_list, s = model(x, return_padded=False) for m, (name, spec_item) in enumerate(items): C = spec_item["num_classes"] print(f"{m:03d} {name:20s} logits:", tuple(logits_list[m].shape), f"(expected {(B, C)})") print("s shape:", tuple(s.shape)) print() # -------------------------------------------------------- # Case 2: padded logits # -------------------------------------------------------- print("===== Forward: padded mode =====") with torch.no_grad(): logits_padded, s2, valid_mask = model(x, return_padded=True) print("logits_padded:", tuple(logits_padded.shape)) print("s:", tuple(s2.shape)) print("valid_mask:", tuple(valid_mask.shape)) print() # -------------------------------------------------------- # Visualize valid mask # -------------------------------------------------------- print("===== Valid class mask (first 10 columns) =====") cols_to_show = min(10, M) for m in range(cols_to_show): cm = num_classes[m] valid = valid_mask[m].sum().item() print(f"col {m:02d} num_classes={cm} valid_mask_sum={valid}") print() # -------------------------------------------------------- # Check padded logits correctness # -------------------------------------------------------- print("===== Padded logits sanity check =====") for m in range(cols_to_show): cm = num_classes[m] valid_region = logits_padded[:, m, :cm] padded_region = logits_padded[:, m, cm:] print(f"col {m:02d} valid region shape:", tuple(valid_region.shape)) if padded_region.numel() > 0: print(f"col {m:02d} padded region mean:", padded_region.mean().item()) print() print("Demo finished successfully.") if __name__ == "__main__": _demo_main()