File size: 7,489 Bytes
"""
Domain Schema Definition for domainTokenizer.

A declarative format that describes the fields in a domain's event data.
Each field type maps to a specific tokenization strategy.

References:
- Nubank nuFormer (arXiv:2507.23267): amount sign(2) + amount bucket(21) + calendar(74) + BPE text
- ActionPiece (arXiv:2502.13581): unordered feature sets tokenized via BPE-like merging
- Banking TF (arXiv:2410.08243): date + amount + wording composite tokens
"""
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Any
class FieldType(Enum):
    """Kinds of event fields the domain tokenizer knows how to encode.

    The member chosen for a field selects its tokenization strategy:

    * ``SIGN`` - 2 tokens (positive/negative, credit/debit).
    * ``NUMERICAL_CONTINUOUS`` - quantile-based magnitude bins
      (default: 21, following Nubank).
    * ``NUMERICAL_DISCRETE`` - small fixed vocabulary for countable values
      (quantities, counts).
    * ``CATEGORICAL_FIXED`` - direct vocabulary mapping from a known set of
      categories.
    * ``TEMPORAL`` - calendar decomposition into month/dow/dom/hour tokens.
    * ``TEXT`` - BPE subword tokenization (descriptions, names, free text).
    """

    # Two-way polarity token.
    SIGN = "sign"

    # Numeric fields: binned magnitudes vs. a small countable range.
    NUMERICAL_CONTINUOUS = "numerical_continuous"
    NUMERICAL_DISCRETE = "numerical_discrete"

    # Closed set of known category labels.
    CATEGORICAL_FIXED = "categorical_fixed"

    # Timestamp split into calendar components.
    TEMPORAL = "temporal"

    # Free text handled by the BPE tokenizer.
    TEXT = "text"
# Vocabulary size contributed by each supported calendar component,
# keyed by the component name used in ``calendar_fields``.
CALENDAR_FIELD_SIZES: Dict[str, int] = {
    "month": 12,
    # day of week
    "dow": 7,
    # day of month
    "dom": 31,
    "hour": 24,
    "quarter": 4,
    # quarter-hour bins: 0-14, 15-29, 30-44, 45-59
    "minute_bin": 4,
}
@dataclass
class FieldSpec:
    """Specification for a single field in a domain event.

    Args:
        name: Field name (must match the key in event dictionaries).
        field_type: How this field should be tokenized.
        prefix: Token prefix override. Defaults to the upper-cased field name.
        n_bins: Number of quantile bins for NUMERICAL_CONTINUOUS (default: 21,
            following Nubank). Must be at least 1 for that field type.
        categories: List of category values for CATEGORICAL_FIXED.
        vocab_size: Explicit vocab size for NUMERICAL_DISCRETE (e.g., 11 for
            quantities 0-10). Stored only: no method of this class reads it,
            and ``token_count`` derives its size from ``max_value``.
        calendar_fields: Which calendar components to extract for TEMPORAL
            fields. Every name must be a key of ``CALENDAR_FIELD_SIZES``.
        max_value: Upper bound for NUMERICAL_DISCRETE (tokens: 0..max_value,
            max_value+). Must be non-negative.

    Raises:
        ValueError: If a CATEGORICAL_FIXED field has no ``categories``, a
            NUMERICAL_DISCRETE field has no (or a negative) ``max_value``,
            a NUMERICAL_CONTINUOUS field has ``n_bins`` below 1, or a
            TEMPORAL field names a calendar component that is not in
            ``CALENDAR_FIELD_SIZES``.
    """

    name: str
    field_type: FieldType
    prefix: Optional[str] = None
    n_bins: int = 21
    categories: Optional[List[str]] = None
    vocab_size: Optional[int] = None
    calendar_fields: List[str] = field(default_factory=lambda: ["month", "dow", "dom", "hour"])
    max_value: Optional[int] = None

    def __post_init__(self):
        """Fill in the default prefix and validate type-specific settings."""
        if self.prefix is None:
            self.prefix = self.name.upper()

        if self.field_type == FieldType.CATEGORICAL_FIXED and self.categories is None:
            raise ValueError(f"Field '{self.name}': CATEGORICAL_FIXED requires 'categories' list")

        if self.field_type == FieldType.NUMERICAL_DISCRETE:
            if self.max_value is None:
                raise ValueError(f"Field '{self.name}': NUMERICAL_DISCRETE requires 'max_value'")
            # A negative bound would make token_count zero or negative.
            if self.max_value < 0:
                raise ValueError(
                    f"Field '{self.name}': NUMERICAL_DISCRETE requires 'max_value' >= 0, "
                    f"got {self.max_value}"
                )

        # Zero or negative bins would give an empty vocabulary for a field
        # that still emits one token per event.
        if self.field_type == FieldType.NUMERICAL_CONTINUOUS and self.n_bins < 1:
            raise ValueError(
                f"Field '{self.name}': NUMERICAL_CONTINUOUS requires 'n_bins' >= 1, "
                f"got {self.n_bins}"
            )

        if self.field_type == FieldType.TEMPORAL:
            # An unknown component would add 0 to token_count but 1 to
            # tokens_per_event, silently undersizing the vocabulary.
            unknown = [cf for cf in self.calendar_fields if cf not in CALENDAR_FIELD_SIZES]
            if unknown:
                raise ValueError(
                    f"Field '{self.name}': unknown calendar_fields {unknown}; "
                    f"supported: {sorted(CALENDAR_FIELD_SIZES)}"
                )

    @property
    def token_count(self) -> int:
        """Number of special tokens this field produces in the vocabulary."""
        if self.field_type == FieldType.SIGN:
            return 2
        if self.field_type == FieldType.NUMERICAL_CONTINUOUS:
            return self.n_bins
        if self.field_type == FieldType.NUMERICAL_DISCRETE:
            return self.max_value + 2  # 0..max_value + overflow bin
        if self.field_type == FieldType.CATEGORICAL_FIXED:
            return len(self.categories) + 1  # +1 for unknown
        if self.field_type == FieldType.TEMPORAL:
            # Names are validated at construction; .get keeps this property
            # non-raising if calendar_fields is mutated afterwards.
            return sum(CALENDAR_FIELD_SIZES.get(cf, 0) for cf in self.calendar_fields)
        # TEXT tokens come from BPE, not the special vocabulary.
        return 0

    @property
    def tokens_per_event(self) -> int:
        """Number of tokens this field contributes per event (fixed part only)."""
        if self.field_type in (
            FieldType.SIGN,
            FieldType.NUMERICAL_CONTINUOUS,
            FieldType.NUMERICAL_DISCRETE,
            FieldType.CATEGORICAL_FIXED,
        ):
            return 1
        if self.field_type == FieldType.TEMPORAL:
            return len(self.calendar_fields)
        # TEXT is variable length, so it has no fixed contribution.
        return 0
@dataclass
class DomainSchema:
    """Ordered description of the fields that make up one domain event.

    An event may be a transaction, a purchase, a clinical encounter, etc.
    Tokens inside an event follow the order of ``fields``.

    Args:
        name: Human-readable domain name (e.g., "finance", "ecommerce").
        fields: Ordered list of field specifications.
        event_separator: Special token to separate events in a sequence.
        description: Optional human-readable description.
    """

    name: str
    fields: List[FieldSpec]
    event_separator: str = "[SEP_EVENT]"
    description: str = ""

    @property
    def special_token_count(self) -> int:
        """Total number of domain-specific special tokens needed."""
        # PAD, UNK, BOS, EOS, MASK, CLS, SEP and the event separator.
        n_base_tokens = 8
        per_field_counts = [spec.token_count for spec in self.fields]
        return n_base_tokens + sum(per_field_counts)

    @property
    def fixed_tokens_per_event(self) -> int:
        """Number of fixed (non-text) tokens per event, including separator."""
        field_total = sum(spec.tokens_per_event for spec in self.fields)
        # One extra slot for the event separator token.
        return field_total + 1

    @property
    def has_text_fields(self) -> bool:
        """Whether the schema includes any free-text fields."""
        for spec in self.fields:
            if spec.field_type == FieldType.TEXT:
                return True
        return False

    @property
    def text_field_names(self) -> List[str]:
        """Names of all text fields in the schema."""
        names = []
        for spec in self.fields:
            if spec.field_type == FieldType.TEXT:
                names.append(spec.name)
        return names

    @property
    def fittable_field_names(self) -> List[str]:
        """Names of fields that require fitting on training data."""
        names = []
        for spec in self.fields:
            if spec.field_type == FieldType.NUMERICAL_CONTINUOUS:
                names.append(spec.name)
        return names

    def get_field(self, name: str) -> Optional[FieldSpec]:
        """Return the first field called ``name``, or ``None`` if absent."""
        return next((spec for spec in self.fields if spec.name == name), None)

    def summary(self) -> str:
        """Human-readable summary of the schema."""
        report = [f"DomainSchema: {self.name}"]
        if self.description:
            report.append(f" {self.description}")

        # Schema-level totals, then a blank line before the per-field rows.
        report.extend(
            [
                f" Fields: {len(self.fields)}",
                f" Special tokens: {self.special_token_count}",
                f" Fixed tokens/event: {self.fixed_tokens_per_event}",
                f" Has text fields: {self.has_text_fields}",
                f" Requires fitting: {self.fittable_field_names}",
                "",
            ]
        )

        report.extend(
            f" [{spec.field_type.value}] {spec.name} → prefix={spec.prefix}, "
            f"tokens_in_vocab={spec.token_count}, tokens_per_event={spec.tokens_per_event}"
            for spec in self.fields
        )
        return "\n".join(report)
|