# decode_numeric.py
# -*- coding: utf-8 -*-
"""
Numeric decoder module for tabular transformer.
Symmetric to embed_numeric.py (bucketed by n_in):
- For each bucket (same n_in), we decode tokens without a Python for-loop over columns.
- Uses a batched per-variable MLP with per-column parameters (NOT shared across V).
Input:
x_tokens: [B, total_numeric_tokens, H]
token order must match numeric_vocab.json:
groups by n_in ascending, within group by feature name,
and within each feature: n_in tokens.
Output:
    values_by_nin: Dict[int, Tensor]
        n_in -> x_hat [B, V, n_in]
    s_by_nin: Dict[int, Tensor]
        n_in -> s [B, V], where s = log(sigma^2) per variable
middle_size:
- None: 1-layer per-variable Linear
- int : 2-layer per-variable MLP (Linear -> GELU -> Linear)
"""
from typing import Dict, List, Optional
import torch
import torch.nn as nn
from utils import GroupedMLP, load_json
class NumericDecoder(nn.Module):
    """
    Decode numeric tokens back to numeric values, bucketed by n_in.

    Input:
        x_tokens: [B, total_numeric_tokens, H]
    Output:
        values_by_nin:
            n_in -> y_hat [B, V, n_in]
        s_by_nin:
            n_in -> s [B, V]
            where s = log(sigma^2), shared across the n_in dimensions
            of each variable, intended for heteroscedastic loss computation.

    Args:
        hidden_size: token width H.
        numeric_vocab_json: path to the numeric vocab spec. It must provide
            "groups" (each with "n_in" and "feature_names"),
            "total_numeric_tokens" and "group_token_offsets" (keyed by str(n_in)).
        middle_size: None -> 1-layer per-variable head; int -> 2-layer head.
        homoscedastic: if True, s is a learned per-variable parameter broadcast
            over the batch (input-independent); otherwise s is predicted per
            sample from the variable's tokens.

    Raises:
        ValueError: if the spec's token layout is inconsistent (see _validate_spec).
    """

    def __init__(
        self,
        hidden_size: int,
        numeric_vocab_json: str,
        middle_size: Optional[int] = None,
        homoscedastic: bool = True,
    ):
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.middle_size = None if middle_size is None else int(middle_size)
        self.homoscedastic = bool(homoscedastic)

        spec = load_json(numeric_vocab_json)
        self.groups: List[Dict] = list(spec["groups"])
        self.total_numeric_tokens = int(spec["total_numeric_tokens"])
        self.group_token_offsets: Dict[str, int] = dict(spec.get("group_token_offsets", {}))

        # Validate the token layout before allocating any parameters, so an
        # inconsistent spec fails fast instead of after building every head.
        self._validate_spec(self.groups, self.group_token_offsets, self.total_numeric_tokens)

        self.group_v_decoders = nn.ModuleList()
        self.group_s_decoders = nn.ModuleList()
        self.group_nins: List[int] = []
        self.group_Vs: List[int] = []
        for g in self.groups:
            n_in = int(g["n_in"])
            names = list(g["feature_names"])
            V = len(names)
            self.group_nins.append(n_in)
            self.group_Vs.append(V)
            # value decoder: [B, V, n_in*H] -> [B, V, n_in]
            self.group_v_decoders.append(
                GroupedMLP(
                    n_var=V,
                    n_in=n_in * self.hidden_size,
                    n_out=n_in,
                    middle_size=self.middle_size,
                )
            )
            # uncertainty decoder: [B, V, H] -> [B, V, 1] -> [B, V]
            if not self.homoscedastic:
                self.group_s_decoders.append(
                    GroupedMLP(
                        n_var=V,
                        n_in=self.hidden_size,
                        n_out=1,
                        middle_size=self.middle_size,
                    )
                )

        if self.homoscedastic:
            # One learned log-variance per variable; broadcast over B in forward().
            self.group_s_params = nn.ParameterList(
                [nn.Parameter(torch.zeros(V)) for V in self.group_Vs]
            )
        else:
            self.group_s_params = None

    @staticmethod
    def _validate_spec(
        groups: List[Dict],
        group_token_offsets: Dict[str, int],
        total_numeric_tokens: int,
    ) -> None:
        """
        Check that the groups tile [0, total_numeric_tokens) contiguously.

        Group i must start at the running sum of V_j * n_in_j over the groups
        before it, and the final sum must equal total_numeric_tokens.

        Raises:
            ValueError: on n_in < 1; on a repeated n_in (outputs and offsets are
                keyed by n_in, so duplicates would overwrite each other); on a
                missing or wrong group_token_offsets entry; or on a total token
                count mismatch.
        """
        seen = set()
        running = 0
        for g in groups:
            n_in = int(g["n_in"])
            V = len(g["feature_names"])
            if n_in < 1:
                raise ValueError(f"n_in must be >= 1, got n_in={n_in}")
            if n_in in seen:
                raise ValueError(f"Duplicate group for n_in={n_in}")
            seen.add(n_in)
            key = str(n_in)
            if key not in group_token_offsets:
                raise ValueError(f"Missing group_token_offsets entry for n_in={n_in}")
            if int(group_token_offsets[key]) != running:
                raise ValueError(
                    f"group_token_offsets[{key}]={group_token_offsets[key]} does not match expected {running}"
                )
            running += V * n_in
        if running != total_numeric_tokens:
            raise ValueError(
                f"total_numeric_tokens={total_numeric_tokens} does not match expected {running}"
            )

    def init_weights(self, std: float = 0.02):
        """
        Re-initialize the decoder heads.

        Value heads are initialized with ``std``. On the log-variance path,
        the homoscedastic per-variable parameters are set to zero (s = 0,
        i.e. sigma^2 = 1). The heteroscedastic heads are initialized with
        std=0.0.
        """
        for dec in self.group_v_decoders:
            dec.init_weights(std=std)
        if self.homoscedastic:
            for p in self.group_s_params:
                nn.init.zeros_(p)
        else:
            # NOTE(review): std=0.0 presumably zeroes every GroupedMLP weight.
            # For a 2-layer head (middle_size set), an all-zero init may leave
            # the hidden layer without gradient, so s could stay
            # input-independent. Confirm against GroupedMLP.init_weights.
            for dec in self.group_s_decoders:
                dec.init_weights(std=0.0)

    def forward(self, x_tokens: torch.Tensor):
        """
        Decode all numeric groups.

        Args:
            x_tokens: [B, total_numeric_tokens, hidden_size] token states,
                ordered as in the vocab spec.
        Returns:
            (values_by_nin, s_by_nin): dicts keyed by n_in holding
            [B, V, n_in] values and [B, V] log-variances.
        Raises:
            ValueError: if x_tokens is not 3-D, or if H or T do not match
                the spec.
        """
        if x_tokens.dim() != 3:
            raise ValueError(f"x_tokens must be [B,T,H], got {tuple(x_tokens.shape)}")
        B, T, H = x_tokens.shape
        if H != self.hidden_size:
            raise ValueError(f"hidden_size mismatch: got H={H}, expected {self.hidden_size}")
        if T != self.total_numeric_tokens:
            raise ValueError(f"token length mismatch: got T={T}, expected {self.total_numeric_tokens}")

        value_out: Dict[int, torch.Tensor] = {}
        s_out: Dict[int, torch.Tensor] = {}
        for gi, n_in in enumerate(self.group_nins):
            start = int(self.group_token_offsets[str(n_in)])
            V = self.group_Vs[gi]
            length = V * n_in
            # Within a group, each feature's n_in tokens are contiguous, so the
            # slice reshapes to [B, V, n_in, H].
            xg_tok4 = x_tokens[:, start:start + length, :].reshape(B, V, n_in, H)
            xg_flat = xg_tok4.reshape(B, V, n_in * H)  # [B, V, n_in*H]

            # values: [B, V, n_in]
            y = self.group_v_decoders[gi](xg_flat)

            # s = log sigma^2: [B, V]
            if self.homoscedastic:
                s = self.group_s_params[gi].unsqueeze(0).expand(B, -1)
            else:
                # Pool the variable's n_in tokens into one [B, V, H] summary.
                x_var = xg_tok4.mean(dim=2)
                s = self.group_s_decoders[gi](x_var).squeeze(-1)

            value_out[n_in] = y
            s_out[n_in] = s
        return value_out, s_out
# ============================================================
# DEMO
# ============================================================
def _demo_main():
    """
    Smoke-test NumericDecoder on random tokens and print the output shapes.

    Loads the vocab spec given by --numeric_vocab_json and builds the decoder.
    It then feeds a random [B, total_numeric_tokens, H] tensor through the
    decoder and prints the per-n_in shapes of the decoded values and
    log-variances. Pass --heteroscedastic to exercise the per-sample s-decoder
    path instead of the learned per-variable constants.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--numeric_vocab_json", type=str, default="data/numeric_vocab.json")
    parser.add_argument("--hidden_size", type=int, default=768)
    parser.add_argument("--middle_size", type=int, default=-1,
                        help="If <0 -> one-layer. If >=0 -> two-layer with this middle size.")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--dtype", type=str, default="float32", choices=["float16", "bfloat16", "float32"])
    parser.add_argument("--heteroscedastic", action="store_true",
                        help="Predict per-sample log-variance (homoscedastic=False).")
    args = parser.parse_args()

    device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
    dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
    dtype = dtype_map[args.dtype]

    # Directly load the existing numeric vocab spec.
    spec = load_json(args.numeric_vocab_json)
    print(f"Loaded numeric vocab spec from: {args.numeric_vocab_json}")
    print("Groups (n_in -> V):", {int(g['n_in']): len(g['feature_names']) for g in spec["groups"]})
    print("total_numeric_tokens:", spec["total_numeric_tokens"])
    print("group_token_offsets:", spec["group_token_offsets"])

    # Negative --middle_size selects the 1-layer head.
    middle_size = None if args.middle_size < 0 else int(args.middle_size)
    model = NumericDecoder(
        hidden_size=args.hidden_size,
        numeric_vocab_json=args.numeric_vocab_json,
        middle_size=middle_size,
        homoscedastic=not args.heteroscedastic,
    ).to(device=device, dtype=dtype)
    model.eval()

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters (NumericDecoder): {total_params:,} (trainable: {trainable_params:,})")

    B = args.batch_size
    T = int(spec["total_numeric_tokens"])
    H = args.hidden_size
    x_tokens = torch.randn(B, T, H, device=device, dtype=dtype)
    with torch.no_grad():
        values_by_nin, s_by_nin = model(x_tokens)

    print("Input tokens:", tuple(x_tokens.shape), x_tokens.dtype, x_tokens.device)
    # values_by_nin[n_in]: [B, V, n_in]
    print("Decoded values:", {k: tuple(v.shape) for k, v in values_by_nin.items()})
    # s_by_nin[n_in]: [B, V]
    print("Decoded s:", {k: tuple(s.shape) for k, s in s_by_nin.items()})


if __name__ == "__main__":
    _demo_main()
|