| """ |
| Categorical embedding module for tabular transformer. |
| |
| Design: |
| - Each categorical column = 1 token |
| - Value embedding: ONE global lookup table using (offset + local_id) |
| - ID embedding: ONE categorical column-ID embedding table |
| - Explicit col_id stored in cat_vocab.json (no implicit ordering assumptions) |
| |
| Outputs: |
| local_ids [B,M] -> tokens [B,M,H] |
| """ |

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn

from utils import load_json, save_json

SPECIAL_MASK = "__MASK__"

def get_categorical_feature_names_from_meta(tabular_meta: Dict) -> List[str]:
    """
    Return the scalar (non-array-valued) categorical feature names from the
    tabular metadata, in a deterministic order: alphabetical by feature name.
    """
    cols = []
    for k, v in tabular_meta.items():
        if v.get("dataclass") == "categorical" and not v.get("is_array_valued", False):
            cols.append(k)
    return sorted(cols)
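
# Example (illustrative only; the field names match the checks above, but the
# concrete columns and counts are made up):
#
#   tabular_meta = {
#       "gender": {"dataclass": "categorical", "class_stats": {"F": 10, "M": 12}},
#       "age":    {"dataclass": "numerical"},
#       "tags":   {"dataclass": "categorical", "is_array_valued": True},
#   }
#   get_categorical_feature_names_from_meta(tabular_meta)  # -> ["gender"]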


@dataclass
class CatColSpec:
    name: str                  # column name
    col_id: int                # explicit categorical column index
    offset: int                # start of this column's slice in the global vocab
    num_classes: int           # real labels + 1 (per-column MASK entry)
    mask_local_id: int         # local id reserved for SPECIAL_MASK (== len(labels))
    label2id: Dict[str, int]   # label -> local id (includes SPECIAL_MASK)

def build_cat_vocab_spec_from_meta(
    tabular_meta: Dict,
    categorical_feature_names: List[str],
    label_order: str = "alpha",
) -> Dict[str, CatColSpec]:
    vocab: Dict[str, CatColSpec] = {}

    offset = 0
    for j, col in enumerate(categorical_feature_names):
        info = tabular_meta[col]
        class_stats = info.get("class_stats", {}) or {}

        # Deterministic label ordering within the column.
        if label_order == "alpha":
            labels = sorted(class_stats.keys())
        elif label_order == "freq_desc":
            labels = sorted(class_stats.keys(), key=lambda k: (-class_stats[k], k))
        else:
            raise ValueError(f"label_order must be 'alpha' or 'freq_desc', got {label_order!r}")

        label2id = {lab: i for i, lab in enumerate(labels)}

        # Reserve one extra local id per column for the MASK token.
        mask_local_id = len(labels)
        label2id[SPECIAL_MASK] = mask_local_id

        spec = CatColSpec(
            name=col,
            col_id=j,
            offset=offset,
            num_classes=mask_local_id + 1,
            mask_local_id=mask_local_id,
            label2id=label2id,
        )
        vocab[col] = spec

        offset += spec.num_classes

    return vocab
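
# Offset layout sketch (hypothetical columns and counts): with two columns
# "color" (3 labels) and "size" (2 labels), the global vocabulary is laid out as
#
#   color: offset=0, num_classes=4  -> global ids 0..3  (3 labels + MASK)
#   size:  offset=4, num_classes=3  -> global ids 4..6  (2 labels + MASK)
#
# so a local_id maps to a global id via `offset + local_id`, and the total
# embedding table size is the sum of per-column num_classes (7 here).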


def save_cat_vocab_json(vocab: Dict[str, CatColSpec], path: str) -> None:
    out = {}

    for col, spec in vocab.items():
        out[col] = {
            "col_id": spec.col_id,
            "offset": spec.offset,
            "num_classes": spec.num_classes,
            "mask_local_id": spec.mask_local_id,
            "global_id_start": spec.offset,
            "global_id_end": spec.offset + spec.num_classes - 1,
            "label2id": spec.label2id,
        }

    save_json(out, path)
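
# A single cat_vocab.json entry then looks roughly like this (values come from
# the hypothetical "color" column sketched above, not from any real dataset):
#
#   "color": {
#       "col_id": 0,
#       "offset": 0,
#       "num_classes": 4,
#       "mask_local_id": 3,
#       "global_id_start": 0,
#       "global_id_end": 3,
#       "label2id": {"blue": 0, "green": 1, "red": 2, "__MASK__": 3}
#   }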


class CategoricalValueEmbedding(nn.Module):
    """
    Global value embedding: every categorical value of every column lives in one
    shared nn.Embedding, addressed by (per-column offset + local_id).
    """

    def __init__(self, hidden_size: int, cat_vocab_json: str):
        super().__init__()

        spec = load_json(cat_vocab_json)

        # Order columns by their explicit col_id, not by dict insertion order.
        items = sorted(spec.items(), key=lambda x: x[1]["col_id"])

        offsets = []
        num_classes = []
        col_ids = []

        total_vocab = 0

        for name, s in items:
            offsets.append(int(s["offset"]))
            num_classes.append(int(s["num_classes"]))
            col_ids.append(int(s["col_id"]))
            total_vocab = max(total_vocab, int(s["offset"]) + int(s["num_classes"]))

        self.hidden_size = int(hidden_size)
        self.total_vocab_size = int(total_vocab)

        self.emb = nn.Embedding(self.total_vocab_size, self.hidden_size)

        self.register_buffer("offsets", torch.tensor(offsets, dtype=torch.long), persistent=True)
        self.register_buffer("num_classes", torch.tensor(num_classes, dtype=torch.long), persistent=True)
        self.register_buffer("col_ids", torch.tensor(col_ids, dtype=torch.long), persistent=True)

    def init_weights(self, std: float = 0.02):
        nn.init.normal_(self.emb.weight, std=std)

    def forward(self, local_ids: torch.LongTensor) -> torch.Tensor:
        """
        local_ids: [B, M] per-column local ids
        returns:   [B, M, H]
        """
        if local_ids.dim() != 2:
            raise ValueError(f"local_ids must be [B, M], got {tuple(local_ids.shape)}")

        B, M = local_ids.shape

        if M != self.offsets.numel():
            raise ValueError(f"Column count mismatch: got M={M}, expected {self.offsets.numel()}")

        if torch.any(local_ids < 0):
            raise ValueError("Negative local_id")

        # Each column only accepts local ids in [0, num_classes).
        nc = self.num_classes.view(1, M).expand(B, M)
        if torch.any(local_ids >= nc):
            raise ValueError("local_ids out of range for at least one column")

        # Shift local ids into the shared global id space and look them up.
        gid = self.offsets.view(1, M) + local_ids
        return self.emb(gid)
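
# Minimal usage sketch (the vocab path is an assumption; see _demo_main below
# for how cat_vocab.json is produced):
#
#   value_emb = CategoricalValueEmbedding(hidden_size=64, cat_vocab_json="data/cat_vocab.json")
#   local_ids = torch.zeros((2, value_emb.offsets.numel()), dtype=torch.long)  # [B, M]
#   tokens = value_emb(local_ids)  # [B, M, 64]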


class CategoricalIdEmbedding(nn.Module):
    """
    Explicit categorical column-ID embedding: one learned vector per categorical
    column, indexed by the col_id stored in cat_vocab.json.
    """

    def __init__(self, hidden_size: int, cat_vocab_json: str):
        super().__init__()

        spec = load_json(cat_vocab_json)
        items = sorted(spec.items(), key=lambda x: x[1]["col_id"])

        col_ids = [int(s["col_id"]) for _, s in items]
        max_col_id = max(col_ids)

        self.emb = nn.Embedding(max_col_id + 1, hidden_size)

        self.register_buffer(
            "cat_col_ids",
            torch.tensor(col_ids, dtype=torch.long),
            persistent=True,
        )

        self.hidden_size = hidden_size

    def init_weights(self, std: float = 0.02):
        nn.init.normal_(self.emb.weight, std=std)

    def forward(self, batch_size: int) -> torch.Tensor:
        """
        returns: [B, M, H], the same column-ID vectors broadcast over the batch
        """
        id_vec = self.emb(self.cat_col_ids)  # [M, H]
        return id_vec.view(1, -1, self.hidden_size).expand(batch_size, -1, -1)
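
# Shape sketch (the vocab path is an assumption, M = number of categorical columns):
#
#   id_emb = CategoricalIdEmbedding(hidden_size=64, cat_vocab_json="data/cat_vocab.json")
#   id_emb(batch_size=4).shape  # -> (4, M, 64); every row repeats the same column-ID vectors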


class CategoricalEmbedding(nn.Module):
    """
    token = value_embedding + categorical_id_embedding
    """

    def __init__(self, hidden_size: int, cat_vocab_json: str):
        super().__init__()

        self.value_emb = CategoricalValueEmbedding(hidden_size, cat_vocab_json)
        self.id_emb = CategoricalIdEmbedding(hidden_size, cat_vocab_json)

    def init_weights(self, std: float = 0.02):
        self.value_emb.init_weights(std=std)
        self.id_emb.init_weights(std=std)

    def forward(
        self,
        local_ids: torch.LongTensor,
        valid_positions: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        local_ids:       [B, M]
        valid_positions: optional; either a [B, M] bool mask or a [K, 2] tensor of
                         (batch_idx, col_idx) index pairs marking the valid cells.

        Returns:
            tokens:     [B, M, H]
            token_mask: [B, M] (1 = valid, 0 = invalid)
        """
        if local_ids.dim() != 2:
            raise ValueError(f"local_ids must be [B,M], got {tuple(local_ids.shape)}")
        B, M = local_ids.shape

        tokens = self.value_emb(local_ids) + self.id_emb(B)

        # By default, every position is treated as valid.
        valid = torch.ones((B, M), dtype=torch.bool, device=local_ids.device)

        if valid_positions is not None:
            if valid_positions.dtype == torch.bool:
                # Dense [B, M] boolean mask.
                if valid_positions.shape != (B, M):
                    raise ValueError(
                        f"valid_positions (bool) must be [B,M]=({B}, {M}), got {tuple(valid_positions.shape)}")
                valid = valid_positions.to(device=local_ids.device)
            else:
                # Sparse [K, 2] list of (batch_idx, col_idx) pairs.
                if valid_positions.dim() != 2 or valid_positions.size(1) != 2:
                    raise ValueError("valid_positions (indices) must be [K,2] with (batch_idx, col_idx)")
                valid = torch.zeros((B, M), dtype=torch.bool, device=local_ids.device)
                b_idx = valid_positions[:, 0].to(device=local_ids.device, dtype=torch.long)
                m_idx = valid_positions[:, 1].to(device=local_ids.device, dtype=torch.long)
                valid[b_idx, m_idx] = True

        # Return the mask as 0/1 longs for downstream attention masking.
        token_mask = valid.to(dtype=torch.long)

        return tokens, token_mask
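
# Sketch of the two valid_positions forms (shapes only; the vocab path is an
# assumption as above):
#
#   model = CategoricalEmbedding(hidden_size=64, cat_vocab_json="data/cat_vocab.json")
#   B, M = 2, model.id_emb.cat_col_ids.numel()
#   local_ids = torch.zeros((B, M), dtype=torch.long)
#
#   # dense boolean mask
#   tokens, mask = model(local_ids, valid_positions=torch.ones(B, M, dtype=torch.bool))
#
#   # sparse (batch_idx, col_idx) pairs: only cell (0, 0) is valid
#   tokens, mask = model(local_ids, valid_positions=torch.tensor([[0, 0]]))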


def _demo_main():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--tabular_meta", type=str, default="data/tabular_meta.json")
    parser.add_argument("--cat_vocab_json", type=str, default="data/cat_vocab.json")
    parser.add_argument("--hidden_size", type=int, default=768)
    parser.add_argument("--batch_size", type=int, default=4)
    args = parser.parse_args()

    tabular_meta = load_json(args.tabular_meta)

    cat_names = get_categorical_feature_names_from_meta(tabular_meta)
    print(f"Found {len(cat_names)} categorical columns")

    vocab = build_cat_vocab_spec_from_meta(tabular_meta, cat_names)
    save_cat_vocab_json(vocab, args.cat_vocab_json)
    print(f"Saved vocab to {args.cat_vocab_json}")

    model = CategoricalEmbedding(
        hidden_size=args.hidden_size,
        cat_vocab_json=args.cat_vocab_json,
    )
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters (CategoricalEmbedding): {total_params:,} (trainable: {trainable_params:,})")

    B = args.batch_size
    M = len(cat_names)

    local_ids = torch.zeros((B, M), dtype=torch.long)

    with torch.no_grad():
        out, mask = model(local_ids)

    print("local_ids:", tuple(local_ids.shape))
    print("output:", tuple(out.shape))
    print("mask:", tuple(mask.shape))


if __name__ == "__main__":
    _demo_main()
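
# Example invocation (the module file name and data paths are assumptions based
# on the argparse defaults above):
#
#   python categorical_embedding.py \
#       --tabular_meta data/tabular_meta.json \
#       --cat_vocab_json data/cat_vocab.json \
#       --hidden_size 256 --batch_size 4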