| """ |
| Domain Schema Definition for domainTokenizer. |
| |
| A declarative format that describes the fields in a domain's event data. |
| Each field type maps to a specific tokenization strategy. |
| |
| References: |
| - Nubank nuFormer (arXiv:2507.23267): amount sign(2) + amount bucket(21) + calendar(74) + BPE text |
| - ActionPiece (arXiv:2502.13581): unordered feature sets tokenized via BPE-like merging |
| - Banking TF (arXiv:2410.08243): date + amount + wording composite tokens |
| """ |
|
|
| from dataclasses import dataclass, field |
| from enum import Enum |
| from typing import Dict, List, Optional, Any |
|
|
|
|
class FieldType(Enum):
    """Kinds of field a domain schema can declare.

    The member chosen for a field selects its tokenization strategy:
        SIGN → 2 tokens (positive/negative, credit/debit)
        NUMERICAL_CONTINUOUS → quantile-based magnitude bins (default: 21, following Nubank)
        NUMERICAL_DISCRETE → small fixed vocabulary for countable values (quantities, counts)
        CATEGORICAL_FIXED → direct vocabulary mapping from a known set of categories
        TEMPORAL → calendar decomposition into month/dow/dom/hour tokens
        TEXT → BPE subword tokenization (for descriptions, names, free text)

    Each member's value is the lowercase string form of its name.
    """

    # Numeric fields.
    SIGN = "sign"
    NUMERICAL_CONTINUOUS = "numerical_continuous"
    NUMERICAL_DISCRETE = "numerical_discrete"

    # Non-numeric fields.
    CATEGORICAL_FIXED = "categorical_fixed"
    TEMPORAL = "temporal"
    TEXT = "text"
|
|
|
|
| |
# Number of distinct values (and therefore vocabulary tokens) for each calendar
# component a TEMPORAL field may extract. FieldSpec.token_count sums these
# sizes over the components a field lists.
CALENDAR_FIELD_SIZES: Dict[str, int] = {
    "month": 12,  # month of year
    "dow": 7,  # day of week
    "dom": 31,  # day of month
    "hour": 24,  # hour of day
    "quarter": 4,  # presumably quarter of the year - TODO confirm
    "minute_bin": 4,  # presumably 15-minute bins within an hour - TODO confirm
}
|
|
|
|
@dataclass
class FieldSpec:
    """Specification for a single field in a domain event.

    Args:
        name: Field name (must match the key in event dictionaries).
        field_type: How this field should be tokenized.
        prefix: Token prefix override. Defaults to the uppercased field name.
        n_bins: Number of quantile bins for NUMERICAL_CONTINUOUS (default: 21,
            following Nubank). Must be at least 1 for that field type.
        categories: List of category values for CATEGORICAL_FIXED.
        vocab_size: Explicit vocab size for NUMERICAL_DISCRETE (e.g., 11 for
            quantities 0-10). It is stored on the spec but is not consulted by
            ``token_count``, which derives the size from ``max_value``.
        calendar_fields: Which calendar components to extract for TEMPORAL
            fields. For a TEMPORAL field every entry must be a key of
            ``CALENDAR_FIELD_SIZES``.
        max_value: Upper bound for NUMERICAL_DISCRETE (tokens: 0..max_value,
            max_value+). Must be non-negative.

    Raises:
        ValueError: If a CATEGORICAL_FIXED field has no ``categories``, a
            NUMERICAL_DISCRETE field has no (or a negative) ``max_value``, a
            NUMERICAL_CONTINUOUS field has ``n_bins < 1``, or a TEMPORAL field
            lists a calendar component that has no known vocabulary size.
    """

    name: str
    field_type: FieldType
    prefix: Optional[str] = None
    n_bins: int = 21
    categories: Optional[List[str]] = None
    vocab_size: Optional[int] = None
    calendar_fields: List[str] = field(default_factory=lambda: ["month", "dow", "dom", "hour"])
    max_value: Optional[int] = None

    def __post_init__(self) -> None:
        """Fill in the default prefix and validate type-specific settings."""
        if self.prefix is None:
            self.prefix = self.name.upper()

        # Only the settings relevant to this field's own type are validated, so
        # unrelated defaults never cause a spec to be rejected.
        kind = self.field_type
        if kind == FieldType.CATEGORICAL_FIXED:
            if self.categories is None:
                raise ValueError(f"Field '{self.name}': CATEGORICAL_FIXED requires 'categories' list")
        elif kind == FieldType.NUMERICAL_DISCRETE:
            if self.max_value is None:
                raise ValueError(f"Field '{self.name}': NUMERICAL_DISCRETE requires 'max_value'")
            if self.max_value < 0:
                # A negative bound would make token_count zero or negative.
                raise ValueError(
                    f"Field '{self.name}': NUMERICAL_DISCRETE requires 'max_value' >= 0, "
                    f"got {self.max_value}"
                )
        elif kind == FieldType.NUMERICAL_CONTINUOUS:
            if self.n_bins < 1:
                raise ValueError(
                    f"Field '{self.name}': NUMERICAL_CONTINUOUS requires 'n_bins' >= 1, "
                    f"got {self.n_bins}"
                )
        elif kind == FieldType.TEMPORAL:
            # An unknown component would emit a token per event
            # (tokens_per_event) while reserving no vocabulary slot
            # (token_count), so reject it up front.
            unknown = [cf for cf in self.calendar_fields if cf not in CALENDAR_FIELD_SIZES]
            if unknown:
                raise ValueError(
                    f"Field '{self.name}': unknown calendar_fields {unknown}; "
                    f"expected a subset of {sorted(CALENDAR_FIELD_SIZES)}"
                )

    @property
    def token_count(self) -> int:
        """Number of special tokens this field produces in the vocabulary."""
        kind = self.field_type
        if kind == FieldType.SIGN:
            return 2
        if kind == FieldType.NUMERICAL_CONTINUOUS:
            return self.n_bins
        if kind == FieldType.NUMERICAL_DISCRETE:
            # Values 0..max_value inclusive, plus the overflow token "max_value+".
            return self.max_value + 2
        if kind == FieldType.CATEGORICAL_FIXED:
            # One slot beyond the listed categories (presumably an
            # unknown/other token - confirm against the tokenizer).
            return len(self.categories) + 1
        if kind == FieldType.TEMPORAL:
            # .get(..., 0) keeps this property non-raising even if
            # calendar_fields is mutated after construction.
            return sum(CALENDAR_FIELD_SIZES.get(cf, 0) for cf in self.calendar_fields)
        # TEXT (and any other type) adds no special tokens here.
        return 0

    @property
    def tokens_per_event(self) -> int:
        """Number of tokens this field contributes per event (fixed part only)."""
        kind = self.field_type
        if kind == FieldType.TEMPORAL:
            # One token per extracted calendar component.
            return len(self.calendar_fields)
        if kind in (
            FieldType.SIGN,
            FieldType.NUMERICAL_CONTINUOUS,
            FieldType.NUMERICAL_DISCRETE,
            FieldType.CATEGORICAL_FIXED,
        ):
            return 1
        # TEXT has no fixed per-event token count.
        return 0
|
|
|
|
@dataclass
class DomainSchema:
    """Complete schema for a domain's event data.

    A schema defines the ordered list of fields that make up each event
    (transaction, purchase, clinical encounter, etc.). The field order
    determines the token order within each event.

    Args:
        name: Human-readable domain name (e.g., "finance", "ecommerce").
        fields: Ordered list of field specifications.
        event_separator: Special token to separate events in a sequence.
        description: Optional human-readable description.
    """

    name: str
    fields: List[FieldSpec]
    event_separator: str = "[SEP_EVENT]"
    description: str = ""

    @property
    def special_token_count(self) -> int:
        """Total number of domain-specific special tokens needed."""
        # NOTE(review): fixed base of 8 tokens on top of the per-field counts;
        # which tokens those 8 are is not visible in this file - confirm.
        base = 8
        return base + sum(spec.token_count for spec in self.fields)

    @property
    def fixed_tokens_per_event(self) -> int:
        """Number of fixed (non-text) tokens per event, including separator."""
        # The leading 1 accounts for the event separator.
        return 1 + sum(spec.tokens_per_event for spec in self.fields)

    @property
    def has_text_fields(self) -> bool:
        """Whether the schema includes any free-text fields."""
        return any(spec.field_type == FieldType.TEXT for spec in self.fields)

    @property
    def text_field_names(self) -> List[str]:
        """Names of all text fields in the schema, in schema order."""
        return [spec.name for spec in self.fields if spec.field_type == FieldType.TEXT]

    @property
    def fittable_field_names(self) -> List[str]:
        """Names of fields that require fitting on training data.

        Only NUMERICAL_CONTINUOUS fields are reported.
        """
        return [
            spec.name
            for spec in self.fields
            if spec.field_type == FieldType.NUMERICAL_CONTINUOUS
        ]

    def get_field(self, name: str) -> Optional[FieldSpec]:
        """Look up a field by name.

        Returns:
            The first field whose ``name`` matches, or ``None`` if there is none.
        """
        return next((spec for spec in self.fields if spec.name == name), None)

    def summary(self) -> str:
        """Human-readable, newline-joined summary of the schema.

        The output is a header with schema-level totals, a blank line, then
        one line per field.
        """
        lines = [f"DomainSchema: {self.name}"]
        if self.description:
            lines.append(f" {self.description}")
        lines.extend(
            [
                f" Fields: {len(self.fields)}",
                f" Special tokens: {self.special_token_count}",
                f" Fixed tokens/event: {self.fixed_tokens_per_event}",
                f" Has text fields: {self.has_text_fields}",
                f" Requires fitting: {self.fittable_field_names}",
                "",
            ]
        )
        # The separator after the field name is an arrow; it had been
        # mis-decoded into a stray Greek letter.
        lines.extend(
            f" [{spec.field_type.value}] {spec.name} → prefix={spec.prefix}, "
            f"tokens_in_vocab={spec.token_count}, tokens_per_event={spec.tokens_per_event}"
            for spec in self.fields
        )
        return "\n".join(lines)
|
|