rtferraz's picture
Add schema.py — DomainSchema, FieldSpec, FieldType definitions
1a9dad0 verified
raw
history blame
7.49 kB
"""
Domain Schema Definition for domainTokenizer.
A declarative format that describes the fields in a domain's event data.
Each field type maps to a specific tokenization strategy.
References:
- Nubank nuFormer (arXiv:2507.23267): amount sign(2) + amount bucket(21) + calendar(74) + BPE text
- ActionPiece (arXiv:2502.13581): unordered feature sets tokenized via BPE-like merging
- Banking TF (arXiv:2410.08243): date + amount + wording composite tokens
"""
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Any
class FieldType(Enum):
    """Supported field types for domain tokenization.

    Each member's value is the lowercase string form of its name. Each type
    maps to a specific tokenization strategy:

        SIGN                 → 2 tokens (positive/negative, credit/debit)
        NUMERICAL_CONTINUOUS → quantile-based magnitude bins (default: 21, following Nubank)
        NUMERICAL_DISCRETE   → small fixed vocabulary for countable values (quantities, counts)
        CATEGORICAL_FIXED    → direct vocabulary mapping from a known set of categories
        TEMPORAL             → calendar decomposition into month/dow/dom/hour tokens
        TEXT                 → BPE subword tokenization (for descriptions, names, free text)

    The per-type vocabulary and per-event token counts are computed by
    ``FieldSpec.token_count`` and ``FieldSpec.tokens_per_event``.
    """

    SIGN = "sign"
    NUMERICAL_CONTINUOUS = "numerical_continuous"
    NUMERICAL_DISCRETE = "numerical_discrete"
    CATEGORICAL_FIXED = "categorical_fixed"
    TEMPORAL = "temporal"
    TEXT = "text"
# Vocabulary size contributed by each supported calendar component of a
# TEMPORAL field. A TEMPORAL field's vocabulary is the sum of the sizes of
# the components it extracts (month + dow + dom + hour = 74 by default).
CALENDAR_FIELD_SIZES: Dict[str, int] = dict(
    month=12,
    dow=7,  # day of week
    dom=31,  # day of month
    hour=24,
    quarter=4,
    minute_bin=4,  # quarter-hour bins: 0-14, 15-29, 30-44, 45-59
)
@dataclass
class FieldSpec:
    """Specification for a single field in a domain event.

    Args:
        name: Field name (must match the key in event dictionaries).
        field_type: How this field should be tokenized. Either a ``FieldType``
            member or its string value (e.g. ``"sign"``); strings are coerced
            to the corresponding member.
        prefix: Token prefix override. Defaults to uppercase field name.
        n_bins: Number of quantile bins for NUMERICAL_CONTINUOUS (default: 21, Nubank).
        categories: List of category values for CATEGORICAL_FIXED.
        vocab_size: Explicit vocab size for NUMERICAL_DISCRETE (e.g., 11 for quantities 0-10).
            NOTE(review): ``token_count`` derives the size from ``max_value`` and
            never reads this; confirm whether other code consumes it.
        calendar_fields: Which calendar components to extract for TEMPORAL fields.
            Every entry must be a key of ``CALENDAR_FIELD_SIZES``.
        max_value: Upper bound for NUMERICAL_DISCRETE (tokens: 0..max_value, max_value+).

    Raises:
        ValueError: If ``field_type`` is not a valid ``FieldType`` value, or if
            the parameters required by the chosen type are missing or invalid.
    """

    name: str
    field_type: FieldType
    prefix: Optional[str] = None
    n_bins: int = 21
    categories: Optional[List[str]] = None
    vocab_size: Optional[int] = None
    calendar_fields: List[str] = field(default_factory=lambda: ["month", "dow", "dom", "hour"])
    max_value: Optional[int] = None

    def __post_init__(self):
        # Accept plain strings (e.g. loaded from a config file). FieldType(member)
        # returns the member unchanged; an unknown value raises ValueError instead
        # of silently matching no branch in token_count/tokens_per_event.
        self.field_type = FieldType(self.field_type)
        if self.prefix is None:
            self.prefix = self.name.upper()

        # Validation
        if self.field_type == FieldType.CATEGORICAL_FIXED and self.categories is None:
            raise ValueError(f"Field '{self.name}': CATEGORICAL_FIXED requires 'categories' list")
        if self.field_type == FieldType.NUMERICAL_DISCRETE:
            if self.max_value is None:
                raise ValueError(f"Field '{self.name}': NUMERICAL_DISCRETE requires 'max_value'")
            if self.max_value < 0:
                raise ValueError(
                    f"Field '{self.name}': NUMERICAL_DISCRETE requires 'max_value' >= 0, "
                    f"got {self.max_value}"
                )
        if self.field_type == FieldType.NUMERICAL_CONTINUOUS and self.n_bins < 1:
            raise ValueError(
                f"Field '{self.name}': NUMERICAL_CONTINUOUS requires 'n_bins' >= 1, got {self.n_bins}"
            )
        if self.field_type == FieldType.TEMPORAL:
            # An unknown component would add a token per event (tokens_per_event)
            # but contribute no vocabulary (token_count), so reject it up front.
            unknown = [cf for cf in self.calendar_fields if cf not in CALENDAR_FIELD_SIZES]
            if unknown:
                raise ValueError(
                    f"Field '{self.name}': unknown calendar_fields {unknown}; "
                    f"supported: {sorted(CALENDAR_FIELD_SIZES)}"
                )

    @property
    def token_count(self) -> int:
        """Number of special tokens this field produces in the vocabulary."""
        if self.field_type == FieldType.SIGN:
            return 2
        if self.field_type == FieldType.NUMERICAL_CONTINUOUS:
            return self.n_bins
        if self.field_type == FieldType.NUMERICAL_DISCRETE:
            return self.max_value + 2  # 0..max_value + overflow bin
        if self.field_type == FieldType.CATEGORICAL_FIXED:
            return len(self.categories) + 1  # +1 for unknown
        if self.field_type == FieldType.TEMPORAL:
            return sum(CALENDAR_FIELD_SIZES[cf] for cf in self.calendar_fields)
        # TEXT: tokens come from BPE, not the special vocab.
        return 0

    @property
    def tokens_per_event(self) -> int:
        """Number of tokens this field contributes per event (fixed part only)."""
        if self.field_type == FieldType.TEMPORAL:
            return len(self.calendar_fields)  # one token per calendar component
        if self.field_type == FieldType.TEXT:
            return 0  # variable length, produced by BPE
        # SIGN, NUMERICAL_CONTINUOUS, NUMERICAL_DISCRETE, CATEGORICAL_FIXED
        return 1
@dataclass
class DomainSchema:
    """Complete schema for a domain's event data.

    A schema defines the ordered list of fields that make up each event
    (transaction, purchase, clinical encounter, etc.). The field order
    determines the token order within each event.

    Args:
        name: Human-readable domain name (e.g., "finance", "ecommerce").
        fields: Ordered list of field specifications. Field names must be unique.
        event_separator: Special token to separate events in a sequence.
        description: Optional human-readable description.

    Raises:
        ValueError: If two fields share the same name.
    """

    name: str
    fields: List[FieldSpec]
    event_separator: str = "[SEP_EVENT]"
    description: str = ""

    def __post_init__(self):
        # Each field reads the event-dict key equal to its name, and get_field()
        # can only return the first match, so duplicate names are ambiguous.
        seen = set()
        duplicates = []
        for spec in self.fields:
            if spec.name in seen and spec.name not in duplicates:
                duplicates.append(spec.name)
            seen.add(spec.name)
        if duplicates:
            raise ValueError(f"DomainSchema '{self.name}': duplicate field names {duplicates}")

    @property
    def special_token_count(self) -> int:
        """Total number of domain-specific special tokens needed."""
        # Base special tokens: PAD, UNK, BOS, EOS, MASK, CLS, SEP, event separator
        base = 8
        return base + sum(f.token_count for f in self.fields)

    @property
    def fixed_tokens_per_event(self) -> int:
        """Number of fixed (non-text) tokens per event, including separator."""
        return 1 + sum(f.tokens_per_event for f in self.fields)  # +1 for event separator

    @property
    def has_text_fields(self) -> bool:
        """Whether the schema includes any free-text fields."""
        return any(f.field_type == FieldType.TEXT for f in self.fields)

    @property
    def text_field_names(self) -> List[str]:
        """Names of all text fields in the schema, in schema order."""
        return [f.name for f in self.fields if f.field_type == FieldType.TEXT]

    @property
    def fittable_field_names(self) -> List[str]:
        """Names of fields that require fitting on training data (quantile bins)."""
        return [f.name for f in self.fields if f.field_type == FieldType.NUMERICAL_CONTINUOUS]

    def get_field(self, name: str) -> Optional[FieldSpec]:
        """Look up a field by name; return ``None`` if no field has that name."""
        return next((f for f in self.fields if f.name == name), None)

    def summary(self) -> str:
        """Human-readable, multi-line summary of the schema and each of its fields."""
        lines = [f"DomainSchema: {self.name}"]
        if self.description:
            lines.append(f" {self.description}")
        lines.append(f" Fields: {len(self.fields)}")
        lines.append(f" Special tokens: {self.special_token_count}")
        lines.append(f" Fixed tokens/event: {self.fixed_tokens_per_event}")
        lines.append(f" Has text fields: {self.has_text_fields}")
        lines.append(f" Requires fitting: {self.fittable_field_names}")
        lines.append("")
        for f in self.fields:
            # The arrow was previously emitted as mojibake ("β†’"); restore "→".
            lines.append(f" [{f.field_type.value}] {f.name} → prefix={f.prefix}, "
                         f"tokens_in_vocab={f.token_count}, tokens_per_event={f.tokens_per_event}")
        return "\n".join(lines)