"""
Categorical decoder for tabular transformer.

Design (column-wise heads):
- Each categorical column corresponds to exactly one token.
- Each column has its own classifier head:
      hidden_size -> num_classes[col]
  Optionally with a small MLP:
      hidden_size -> middle_size -> num_classes[col]

No loss is included here (the caller will apply CrossEntropyLoss).
"""

from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from utils import load_json, GroupedMLP
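
# Caller-side sketch (illustrative; the loss itself lives outside this module).
# With `targets` of shape [B, M] holding class indices per column -- an assumed
# layout, not part of this file -- the plain (unweighted) objective would be:
#
#     import torch.nn.functional as F
#     logits_list, s = decoder(x_cat_tokens)
#     loss = sum(F.cross_entropy(logits_list[m], targets[:, m])
#                for m in range(len(logits_list)))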


def _make_head(
    hidden_size: int,
    num_classes: int,
    middle_size: Optional[int],
    bias: bool = True,
) -> nn.Module:
    """
    Build a lightweight per-column classifier head: a single linear layer when
    `middle_size` is None, otherwise a two-layer MLP with a GELU in between.
    """
    if middle_size is None:
        return nn.Linear(hidden_size, num_classes, bias=bias)

    return nn.Sequential(
        nn.Linear(hidden_size, middle_size, bias=bias),
        nn.GELU(),
        nn.Linear(middle_size, num_classes, bias=bias),
    )


class CategoricalDecoder(nn.Module):
    """
    Column-wise categorical decoder.

    Design:
    - Each categorical column corresponds to exactly one token.
    - Each column has its own classifier head:
          hidden_size -> num_classes[col]
      Optionally with a small MLP:
          hidden_size -> middle_size -> num_classes[col]
    - In addition, the decoder produces a per-column log-variance term `s`
      for heteroscedastic loss weighting. With `homoscedastic=True` (the
      default) `s` is a learned parameter shared across samples; otherwise
      it is predicted per sample from the column tokens.

    Input:
        x_cat_tokens: [B, M, H]
            B = batch size
            M = number of categorical columns (ordered by col_id)
            H = hidden size

    Outputs:

        Case 1 (return_padded=False):
            logits_list: List[Tensor] of length M
                logits_list[m]: [B, num_classes[m]]
            s: [B, M]
                Log-variance per sample and column:
                    s[b, m] = log sigma^2_{b,m}
                Intended for heteroscedastic loss weighting.

        Case 2 (return_padded=True):
            logits_padded: [B, M, Cmax]
                Logits padded to the maximum class count across columns.
            s: [B, M]
                Same uncertainty prediction as above.
            valid_mask: [M, Cmax]
                True for valid class indices for each column.
    """

    def __init__(
        self,
        hidden_size: int,
        cat_vocab_json: str,
        middle_size: Optional[int] = None,
        bias: bool = True,
        homoscedastic: bool = True,
    ):
        super().__init__()

        # Column spec: {name: {"col_id": ..., "num_classes": ...}}.
        spec = load_json(cat_vocab_json)
        items = sorted(spec.items(), key=lambda x: x[1]["col_id"])

        col_ids: List[int] = []
        num_classes: List[int] = []
        for _, val in items:
            col_ids.append(int(val["col_id"]))
            num_classes.append(int(val["num_classes"]))

        self.hidden_size = int(hidden_size)
        self.num_cols = len(num_classes)
        self.middle_size = middle_size
        self.homoscedastic = bool(homoscedastic)

        # Persistent buffers so checkpoints carry the column layout.
        self.register_buffer("cat_col_ids", torch.tensor(col_ids, dtype=torch.long), persistent=True)
        self.register_buffer("num_classes", torch.tensor(num_classes, dtype=torch.long), persistent=True)

        # One classifier head per categorical column.
        heads = []
        for c in num_classes:
            heads.append(_make_head(self.hidden_size, c, middle_size, bias=bias))
        self.heads = nn.ModuleList(heads)

        # Log-variance: a shared per-column parameter (homoscedastic), or a
        # grouped MLP that predicts it per sample from the column tokens.
        if self.homoscedastic:
            self.s_param = nn.Parameter(torch.zeros(self.num_cols))
            self.s_head = None
        else:
            self.s_head = GroupedMLP(
                n_var=self.num_cols,
                n_in=self.hidden_size,
                n_out=1,
                middle_size=self.middle_size,
            )
            self.s_param = None

    def init_weights(self, std: float = 0.02):
        for head in self.heads:
            for module in head.modules():
                if isinstance(module, nn.Linear):
                    nn.init.normal_(module.weight, std=std)
                    if module.bias is not None:
                        nn.init.zeros_(module.bias)

        if self.homoscedastic:
            nn.init.zeros_(self.s_param)
        else:
            # Zero init so the predicted log-variance starts at s = 0 (sigma = 1).
            self.s_head.init_weights(std=0.0)

    def _check_input(self, x_cat_tokens: torch.Tensor) -> Tuple[int, int, int]:
        if x_cat_tokens.dim() != 3:
            raise ValueError(f"x_cat_tokens must be [B, M, H], got {tuple(x_cat_tokens.shape)}")
        B, M, H = x_cat_tokens.shape
        if H != self.hidden_size:
            raise ValueError(f"hidden_size mismatch: got {H}, expected {self.hidden_size}")
        if M != self.num_cols:
            raise ValueError(f"categorical token count mismatch: got M={M}, expected {self.num_cols}")
        return B, M, H

    @torch.no_grad()
    def _build_valid_mask(self, device: torch.device) -> torch.Tensor:
        """
        valid_mask[m, j] = True iff j < num_classes[m]
        """
        M = self.num_cols
        cmax = int(self.num_classes.max().item())
        ar = torch.arange(cmax, device=device).view(1, cmax).expand(M, cmax)
        nc = self.num_classes.view(M, 1).expand(M, cmax)
        return ar < nc
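
    # Usage sketch (illustrative, not part of the original API): with
    # return_padded=True, invalid class slots can be masked out before a
    # softmax/argmax over classes, e.g.
    #
    #     logits_padded, s, valid_mask = decoder(tokens, return_padded=True)
    #     masked = logits_padded.masked_fill(~valid_mask.unsqueeze(0), float("-inf"))
    #     preds = masked.argmax(dim=-1)   # [B, M]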

    def forward(
        self,
        x_cat_tokens: torch.Tensor,
        return_padded: bool = False,
        pad_value: Optional[float] = None,
    ) -> Union[
        Tuple[List[torch.Tensor], torch.Tensor],
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    ]:
        """
        Args:
            x_cat_tokens: [B, M, H]
                B = batch size
                M = number of categorical columns
                H = hidden size (per-column token embedding dim)
            return_padded:
                False: return (logits_list, s)
                True:  return (logits_padded, s, valid_mask)
            pad_value:
                Value used to fill invalid class positions in padded logits.
                Defaults to the most negative finite value of the input dtype,
                so padded positions never win an argmax.

        Returns:
            Case 1 (return_padded=False):
                logits_list: List of length M
                    logits_list[m]: [B, C_m]
                s: [B, M]
                    s[b, m] = log sigma^2 for sample b, column m

            Case 2 (return_padded=True):
                logits_padded: [B, M, Cmax]
                s: [B, M]
                valid_mask: [M, Cmax]
        """
        B, M, _ = self._check_input(x_cat_tokens)

        # Run each column's head on its own token: [B, H] -> [B, C_m].
        logits_list: List[torch.Tensor] = []
        for m in range(M):
            logits_list.append(self.heads[m](x_cat_tokens[:, m, :]))

        # Log-variance: broadcast the shared per-column parameter, or predict
        # it per sample with the grouped MLP ([B, M, 1] -> [B, M]).
        if self.homoscedastic:
            s = self.s_param.unsqueeze(0).expand(B, -1)
        else:
            s = self.s_head(x_cat_tokens).squeeze(-1)

        if not return_padded:
            return logits_list, s

        # Pad the ragged per-column logits into a dense [B, M, Cmax] tensor.
        cmax = int(self.num_classes.max().item())
        if pad_value is None:
            pad_value = torch.finfo(x_cat_tokens.dtype).min
        logits_padded = torch.full(
            (B, M, cmax),
            pad_value,
            device=x_cat_tokens.device,
            dtype=x_cat_tokens.dtype,
        )
        for m in range(M):
            cm = logits_list[m].size(-1)
            logits_padded[:, m, :cm] = logits_list[m]

        # valid_mask[m, j] is True iff class j exists for column m.
        valid_class_mask = self._build_valid_mask(device=x_cat_tokens.device)

        return logits_padded, s, valid_class_mask
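

# Hedged sketch (not part of the original module): one common way to use the
# predicted log-variance `s` for heteroscedastic loss weighting is the
# Kendall-and-Gal-style objective
#     loss_m = exp(-s_m) * CE_m + s_m,
# averaged over samples and columns. The `targets` layout [B, M] with integer
# class indices per column is an assumption about the caller's data.
def _example_heteroscedastic_loss(
    logits_list: List[torch.Tensor],  # logits_list[m]: [B, C_m]
    s: torch.Tensor,                  # [B, M], s[b, m] = log sigma^2_{b, m}
    targets: torch.Tensor,            # [B, M] integer class indices (assumed)
) -> torch.Tensor:
    import torch.nn.functional as F

    # Per-sample cross-entropy for each column: list of [B] -> [B, M].
    ce = torch.stack(
        [
            F.cross_entropy(logits_m, targets[:, m], reduction="none")
            for m, logits_m in enumerate(logits_list)
        ],
        dim=1,
    )
    # exp(-s) down-weights columns predicted to be noisy; the +s term keeps
    # the model from inflating the variance to make the loss vanish.
    return (torch.exp(-s) * ce + s).mean()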


def _demo_main():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--cat_vocab_json", type=str, default="data/cat_vocab.json")
    parser.add_argument("--hidden_size", type=int, default=768)
    parser.add_argument("--middle_size", type=int, default=None)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--dtype", type=str, default="float32", choices=["float16", "bfloat16", "float32"])
    args = parser.parse_args()

    device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
    dtype_map = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    dtype = dtype_map[args.dtype]

    # Load the column spec directly so the demo can print expected shapes.
    spec = load_json(args.cat_vocab_json)
    items = sorted(spec.items(), key=lambda x_: x_[1]["col_id"])

    M = len(items)
    B = args.batch_size
    H = args.hidden_size

    num_classes = [int(v["num_classes"]) for _, v in items]

    print("===== Categorical Columns =====")
    for i, (name, v) in enumerate(items):
        print(f"{i:03d} {name:20s} classes={v['num_classes']}")
    print()

    model = CategoricalDecoder(
        hidden_size=args.hidden_size,
        cat_vocab_json=args.cat_vocab_json,
        middle_size=args.middle_size,
    ).to(device=device, dtype=dtype)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model parameters: {total_params:,} (trainable: {trainable_params:,})")
    print()

    # Random tokens stand in for the transformer's categorical-column outputs.
    x = torch.randn(B, M, H, device=device, dtype=dtype)
    print("Input tokens shape:", tuple(x.shape))
    print()

    print("===== Forward: logits_list mode =====")
    with torch.no_grad():
        logits_list, s = model(x, return_padded=False)

    for m, (name, spec_item) in enumerate(items):
        C = spec_item["num_classes"]
        print(f"{m:03d} {name:20s} logits:", tuple(logits_list[m].shape), f"(expected {(B, C)})")
    print("s shape:", tuple(s.shape))
    print()
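    # Hedged add-on: exercise the illustrative heteroscedastic loss above with
    # random targets (the [B, M] target layout is an assumption).
    targets = torch.stack(
        [torch.randint(0, c, (B,), device=device) for c in num_classes], dim=1
    )
    loss = _example_heteroscedastic_loss(logits_list, s, targets)
    print("example heteroscedastic loss:", float(loss))
    print()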
    print("===== Forward: padded mode =====")
    with torch.no_grad():
        logits_padded, s2, valid_mask = model(x, return_padded=True)

    print("logits_padded:", tuple(logits_padded.shape))
    print("s:", tuple(s2.shape))
    print("valid_mask:", tuple(valid_mask.shape))
    print()

    print("===== Valid class mask (first 10 columns) =====")
    cols_to_show = min(10, M)
    for m in range(cols_to_show):
        cm = num_classes[m]
        valid = valid_mask[m].sum().item()
        print(f"col {m:02d} num_classes={cm} valid_mask_sum={valid}")
    print()

    print("===== Padded logits sanity check =====")
    for m in range(cols_to_show):
        cm = num_classes[m]
        valid_region = logits_padded[:, m, :cm]
        padded_region = logits_padded[:, m, cm:]
        print(f"col {m:02d} valid region shape:", tuple(valid_region.shape))
        if padded_region.numel() > 0:
            print(f"col {m:02d} padded region mean:", padded_region.mean().item())
    print()

    print("Demo finished successfully.")


if __name__ == "__main__":
    _demo_main()