# embed_categorical.py # -*- coding: utf-8 -*- """ Categorical embedding module for tabular transformer. Design: - Each categorical column = 1 token - Value embedding: ONE global lookup table using (offset + local_id) - ID embedding: ONE categorical column-ID embedding table - Explicit col_id stored in cat_vocab.json (no implicit ordering assumptions) Outputs: local_ids [B,M] -> tokens [B,M,H] """ from dataclasses import dataclass from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn from utils import load_json, save_json SPECIAL_MASK = "__MASK__" # ============================================================ # Meta → categorical column list # ============================================================ def get_categorical_feature_names_from_meta(tabular_meta: Dict) -> List[str]: """ Deterministic ordering: alphabetical by feature name. """ cols = [] for k, v in tabular_meta.items(): if v.get("dataclass") == "categorical" and not v.get("is_array_valued", False): cols.append(k) return sorted(cols) # ============================================================ # Vocab spec # ============================================================ @dataclass class CatColSpec: name: str col_id: int offset: int num_classes: int mask_local_id: int label2id: Dict[str, int] def build_cat_vocab_spec_from_meta( tabular_meta: Dict, categorical_feature_names: List[str], label_order: str = "alpha", ) -> Dict[str, CatColSpec]: vocab: Dict[str, CatColSpec] = {} offset = 0 for j, col in enumerate(categorical_feature_names): info = tabular_meta[col] class_stats = info.get("class_stats", {}) or {} # deterministic label order if label_order == "alpha": labels = sorted(class_stats.keys()) elif label_order == "freq_desc": labels = sorted(class_stats.keys(), key=lambda k: (-class_stats[k], k)) else: raise ValueError("label_order must be alpha or freq_desc") label2id = {lab: i for i, lab in enumerate(labels)} mask_local_id = len(labels) label2id[SPECIAL_MASK] = mask_local_id spec = CatColSpec( name=col, col_id=j, # EXPLICIT categorical column id offset=offset, num_classes=mask_local_id + 1, mask_local_id=mask_local_id, label2id=label2id, ) vocab[col] = spec offset += spec.num_classes return vocab def save_cat_vocab_json(vocab: Dict[str, CatColSpec], path: str) -> None: out = {} for col, spec in vocab.items(): out[col] = { "col_id": spec.col_id, "offset": spec.offset, "num_classes": spec.num_classes, "mask_local_id": spec.mask_local_id, "global_id_start": spec.offset, "global_id_end": spec.offset + spec.num_classes - 1, "label2id": spec.label2id, } save_json(out, path) # ============================================================ # Embedding modules # ============================================================ class CategoricalValueEmbedding(nn.Module): """ Global value embedding using offsets. """ def __init__(self, hidden_size: int, cat_vocab_json: str): super().__init__() spec = load_json(cat_vocab_json) # sort by col_id to ensure consistent tensor layout items = sorted(spec.items(), key=lambda x: x[1]["col_id"]) offsets = [] num_classes = [] col_ids = [] total_vocab = 0 for name, s in items: offsets.append(int(s["offset"])) num_classes.append(int(s["num_classes"])) col_ids.append(int(s["col_id"])) total_vocab = max(total_vocab, s["offset"] + s["num_classes"]) self.hidden_size = int(hidden_size) self.total_vocab_size = int(total_vocab) # Merge all classes to avoid many small nn.Embedding modules self.emb = nn.Embedding(self.total_vocab_size, self.hidden_size) self.register_buffer("offsets", torch.tensor(offsets, dtype=torch.long), persistent=True) self.register_buffer("num_classes", torch.tensor(num_classes, dtype=torch.long), persistent=True) self.register_buffer("col_ids", torch.tensor(col_ids, dtype=torch.long), persistent=True) def init_weights(self, std=0.02): nn.init.normal_(self.emb.weight, std=std) def forward(self, local_ids: torch.LongTensor) -> torch.Tensor: """ local_ids: [B,M] returns: [B,M,H] """ if local_ids.dim() != 2: raise ValueError("local_ids must be [B,M]") B, M = local_ids.shape if M != self.offsets.numel(): raise ValueError("Column count mismatch") if torch.any(local_ids < 0): raise ValueError("Negative local_id") nc = self.num_classes.view(1, M).expand(B, M) if torch.any(local_ids >= nc): raise ValueError("local_ids out of range") gid = self.offsets.view(1, M) + local_ids return self.emb(gid) class CategoricalIdEmbedding(nn.Module): """ Explicit categorical column ID embedding. """ def __init__(self, hidden_size: int, cat_vocab_json: str): super().__init__() spec = load_json(cat_vocab_json) items = sorted(spec.items(), key=lambda x: x[1]["col_id"]) col_ids = [s["col_id"] for _, s in items] max_col_id = max(col_ids) self.emb = nn.Embedding(max_col_id + 1, hidden_size) self.register_buffer( "cat_col_ids", torch.tensor(col_ids, dtype=torch.long), persistent=True, ) self.hidden_size = hidden_size def init_weights(self, std=0.02): nn.init.normal_(self.emb.weight, std=std) def forward(self, batch_size: int) -> torch.Tensor: """ returns [B,M,H] """ id_vec = self.emb(self.cat_col_ids) # [M,H] return id_vec.view(1, -1, self.hidden_size).expand(batch_size, -1, -1) class CategoricalEmbedding(nn.Module): """ token = value_embedding + categorical_id_embedding """ def __init__(self, hidden_size: int, cat_vocab_json: str): super().__init__() self.value_emb = CategoricalValueEmbedding(hidden_size, cat_vocab_json) self.id_emb = CategoricalIdEmbedding(hidden_size, cat_vocab_json) def init_weights(self, std=0.02): self.value_emb.init_weights(std=std) self.id_emb.init_weights(std=std) def forward( self, local_ids: torch.LongTensor, # [B, M] valid_positions: Optional[torch.Tensor] = None, # Bool [B,M] (True=valid) or indices [K,2] ) -> Tuple[torch.Tensor, torch.Tensor]: """ Returns: tokens: [B, M, H] token_mask: [B, M] (1=valid, 0=invalid) """ if local_ids.dim() != 2: raise ValueError(f"local_ids must be [B,M], got {tuple(local_ids.shape)}") B, M = local_ids.shape tokens = self.value_emb(local_ids) + self.id_emb(B) # [B,M,H] # Default: all tokens are valid valid = torch.ones((B, M), dtype=torch.bool, device=local_ids.device) if valid_positions is not None: if valid_positions.dtype == torch.bool: if valid_positions.shape != (B, M): raise ValueError( f"valid_positions (bool) must be [B,M]=({B}, {M}), got {tuple(valid_positions.shape)}") valid = valid_positions.to(device=local_ids.device) else: # Optional: support index pairs [K,2] where each row is (b_idx, m_idx) for valid positions if valid_positions.dim() != 2 or valid_positions.size(1) != 2: raise ValueError("valid_positions (indices) must be [K,2] with (batch_idx, col_idx)") valid = torch.zeros((B, M), dtype=torch.bool, device=local_ids.device) b_idx = valid_positions[:, 0].to(device=local_ids.device, dtype=torch.long) m_idx = valid_positions[:, 1].to(device=local_ids.device, dtype=torch.long) valid[b_idx, m_idx] = True # Token mask: 1=valid, 0=invalid token_mask = valid.to(dtype=torch.long) # [B,M] # This is WRONG: we should allow __MASK__ to attend other columns # # Invalid tokens must not contribute # invalid = ~valid # if invalid.any(): # tokens = tokens.masked_fill(invalid.unsqueeze(-1), 0.0) return tokens, token_mask # ============================================================ # DEMO # ============================================================ def _demo_main(): import argparse parser = argparse.ArgumentParser() parser.add_argument("--tabular_meta", type=str, default="data/tabular_meta.json") parser.add_argument("--cat_vocab_json", type=str, default="data/cat_vocab.json") parser.add_argument("--hidden_size", type=int, default=768) parser.add_argument("--batch_size", type=int, default=4) args = parser.parse_args() tabular_meta = load_json(args.tabular_meta) cat_names = get_categorical_feature_names_from_meta(tabular_meta) print(f"Found {len(cat_names)} categorical columns") vocab = build_cat_vocab_spec_from_meta(tabular_meta, cat_names) save_cat_vocab_json(vocab, args.cat_vocab_json) print(f"Saved vocab to {args.cat_vocab_json}") model = CategoricalEmbedding( hidden_size=args.hidden_size, cat_vocab_json=args.cat_vocab_json, ) total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f"Total parameters (CategoricalEmbedding): {total_params:,} (trainable: {trainable_params:,})") B = args.batch_size M = len(cat_names) local_ids = torch.zeros((B, M), dtype=torch.long) with torch.no_grad(): out, mask = model(local_ids) print("local_ids:", tuple(local_ids.shape)) print("output:", tuple(out.shape)) # [B,M,H] print("mask:", tuple(mask.shape)) # [B,M] if __name__ == "__main__": _demo_main()