# Page header captured with this source: "File size: 7,489 Bytes", "1a9dad0".
"""
Domain Schema Definition for domainTokenizer.

A declarative format that describes the fields in a domain's event data.
Each field type maps to a specific tokenization strategy.

References:
  - Nubank nuFormer (arXiv:2507.23267): amount sign(2) + amount bucket(21) + calendar(74) + BPE text
  - ActionPiece (arXiv:2502.13581): unordered feature sets tokenized via BPE-like merging
  - Banking TF (arXiv:2410.08243): date + amount + wording composite tokens
"""

from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Any


class FieldType(Enum):
    """Supported field types for domain tokenization.

    Each member selects a tokenization strategy for a field. The member's
    string value is the identifier printed by ``DomainSchema.summary()``.

    Members:
      SIGN                 → 2 tokens (positive/negative, credit/debit)
      NUMERICAL_CONTINUOUS → quantile-based magnitude bins (default: 21, following Nubank)
      NUMERICAL_DISCRETE   → small fixed vocabulary for countable values (quantities, counts);
                             sized from ``max_value`` plus an overflow token
      CATEGORICAL_FIXED    → direct vocabulary mapping from a known set of categories,
                             plus one token for unknown values
      TEMPORAL             → calendar decomposition into component tokens
                             (month/dow/dom/hour by default)
      TEXT                 → BPE subword tokenization (for descriptions, names, free text);
                             contributes no special tokens of its own
    """
    SIGN = "sign"
    NUMERICAL_CONTINUOUS = "numerical_continuous"
    NUMERICAL_DISCRETE = "numerical_discrete"
    CATEGORICAL_FIXED = "categorical_fixed"
    TEMPORAL = "temporal"
    TEXT = "text"


# Mapping from calendar component name → number of distinct vocabulary tokens
# that component contributes. The keys are the names accepted in
# ``FieldSpec.calendar_fields``; ``FieldSpec.token_count`` sums these sizes
# for TEMPORAL fields.
CALENDAR_FIELD_SIZES: Dict[str, int] = {
    "month": 12,
    "dow": 7,       # day of week
    "dom": 31,      # day of month
    "hour": 24,
    "quarter": 4,
    "minute_bin": 4, # 15-min bins: 0-14, 15-29, 30-44, 45-59
}


@dataclass
class FieldSpec:
    """Specification for a single field in a domain event.

    Args:
        name: Field name (must match the key in event dictionaries).
        field_type: How this field should be tokenized.
        prefix: Token prefix override. Defaults to the upper-cased field name.
        n_bins: Number of quantile bins for NUMERICAL_CONTINUOUS (default: 21,
            Nubank). Must be at least 1 for that field type.
        categories: List of category values for CATEGORICAL_FIXED (required
            for that field type).
        vocab_size: Explicit vocab size for NUMERICAL_DISCRETE (e.g., 11 for
            quantities 0-10). Stored as given; ``token_count`` does not read
            it and derives the size from ``max_value`` instead.
        calendar_fields: Which calendar components to extract for TEMPORAL
            fields. Every entry must be a key of ``CALENDAR_FIELD_SIZES``.
        max_value: Upper bound for NUMERICAL_DISCRETE (tokens: 0..max_value,
            max_value+). Required and non-negative for that field type.

    Raises:
        ValueError: If the arguments required by ``field_type`` are missing
            or out of range, or if a TEMPORAL field lists a calendar
            component that ``CALENDAR_FIELD_SIZES`` does not define.
    """
    name: str
    field_type: FieldType
    prefix: Optional[str] = None
    n_bins: int = 21
    categories: Optional[List[str]] = None
    vocab_size: Optional[int] = None
    calendar_fields: List[str] = field(default_factory=lambda: ["month", "dow", "dom", "hour"])
    max_value: Optional[int] = None

    def __post_init__(self) -> None:
        """Fill in the default prefix and validate type-specific arguments."""
        if self.prefix is None:
            self.prefix = self.name.upper()

        # Validation: each check only applies to the field type that actually
        # reads the attribute, so unrelated defaults never trigger an error.
        if self.field_type == FieldType.CATEGORICAL_FIXED and self.categories is None:
            raise ValueError(f"Field '{self.name}': CATEGORICAL_FIXED requires 'categories' list")

        if self.field_type == FieldType.NUMERICAL_DISCRETE:
            if self.max_value is None:
                raise ValueError(f"Field '{self.name}': NUMERICAL_DISCRETE requires 'max_value'")
            # A negative bound would make token_count zero or negative.
            if self.max_value < 0:
                raise ValueError(
                    f"Field '{self.name}': NUMERICAL_DISCRETE requires 'max_value' >= 0, "
                    f"got {self.max_value}"
                )

        if self.field_type == FieldType.NUMERICAL_CONTINUOUS and self.n_bins < 1:
            raise ValueError(
                f"Field '{self.name}': NUMERICAL_CONTINUOUS requires 'n_bins' >= 1, "
                f"got {self.n_bins}"
            )

        if self.field_type == FieldType.TEMPORAL:
            # An unknown component would add 0 vocabulary tokens while still
            # counting as 1 token per event, so reject it up front.
            unknown = [cf for cf in self.calendar_fields if cf not in CALENDAR_FIELD_SIZES]
            if unknown:
                raise ValueError(
                    f"Field '{self.name}': unknown calendar_fields {unknown}; "
                    f"supported: {sorted(CALENDAR_FIELD_SIZES)}"
                )

    @property
    def token_count(self) -> int:
        """Number of special tokens this field produces in the vocabulary."""
        if self.field_type == FieldType.SIGN:
            return 2
        if self.field_type == FieldType.NUMERICAL_CONTINUOUS:
            return self.n_bins
        if self.field_type == FieldType.NUMERICAL_DISCRETE:
            return self.max_value + 2  # 0..max_value + overflow bin
        if self.field_type == FieldType.CATEGORICAL_FIXED:
            return len(self.categories) + 1  # +1 for unknown
        if self.field_type == FieldType.TEMPORAL:
            # Direct indexing: names were validated in __post_init__, so a
            # KeyError here means calendar_fields was mutated to an
            # unsupported value after construction.
            return sum(CALENDAR_FIELD_SIZES[cf] for cf in self.calendar_fields)
        # TEXT tokens come from BPE, not the special vocab; any unrecognized
        # field type likewise contributes nothing.
        return 0

    @property
    def tokens_per_event(self) -> int:
        """Number of tokens this field contributes per event (fixed part only)."""
        single_token_types = (
            FieldType.SIGN,
            FieldType.NUMERICAL_CONTINUOUS,
            FieldType.NUMERICAL_DISCRETE,
            FieldType.CATEGORICAL_FIXED,
        )
        if self.field_type in single_token_types:
            return 1
        if self.field_type == FieldType.TEMPORAL:
            # One token per extracted calendar component.
            return len(self.calendar_fields)
        # TEXT is variable length, so it has no fixed contribution.
        return 0


@dataclass
class DomainSchema:
    """Ordered collection of field specifications describing a domain's events.

    An event is one record of the domain (transaction, purchase, clinical
    encounter, etc.). The position of each entry in ``fields`` determines
    where that field's tokens appear within an event.

    Args:
        name: Human-readable domain name (e.g., "finance", "ecommerce").
        fields: Field specifications, in token order.
        event_separator: Special token placed between events in a sequence.
        description: Optional human-readable description of the domain.
    """
    name: str
    fields: List[FieldSpec]
    event_separator: str = "[SEP_EVENT]"
    description: str = ""

    @property
    def special_token_count(self) -> int:
        """Size of the special-token vocabulary this schema needs."""
        # Eight shared tokens: PAD, UNK, BOS, EOS, MASK, CLS, SEP, event separator.
        shared_tokens = 8
        per_field_tokens = sum(spec.token_count for spec in self.fields)
        return shared_tokens + per_field_tokens

    @property
    def fixed_tokens_per_event(self) -> int:
        """Count of fixed (non-text) tokens in one event, separator included."""
        field_tokens = sum(spec.tokens_per_event for spec in self.fields)
        return field_tokens + 1  # the extra one is the event separator

    @property
    def has_text_fields(self) -> bool:
        """True when at least one field is a free-text field."""
        for spec in self.fields:
            if spec.field_type == FieldType.TEXT:
                return True
        return False

    @property
    def text_field_names(self) -> List[str]:
        """Names of the schema's text fields, in schema order."""
        names = []
        for spec in self.fields:
            if spec.field_type == FieldType.TEXT:
                names.append(spec.name)
        return names

    @property
    def fittable_field_names(self) -> List[str]:
        """Names of the fields that must be fitted on training data."""
        names = []
        for spec in self.fields:
            if spec.field_type == FieldType.NUMERICAL_CONTINUOUS:
                names.append(spec.name)
        return names

    def get_field(self, name: str) -> Optional[FieldSpec]:
        """Return the first field called ``name``, or None if there is none."""
        return next((spec for spec in self.fields if spec.name == name), None)

    def summary(self) -> str:
        """Render the schema as a multi-line, human-readable report."""
        header = f"DomainSchema: {self.name}"
        # The description line is emitted only when a description was given.
        description_rows = [f"  {self.description}"] if self.description else []
        overview_rows = [
            f"  Fields: {len(self.fields)}",
            f"  Special tokens: {self.special_token_count}",
            f"  Fixed tokens/event: {self.fixed_tokens_per_event}",
            f"  Has text fields: {self.has_text_fields}",
            f"  Requires fitting: {self.fittable_field_names}",
            "",  # blank line between the overview and the per-field rows
        ]
        field_rows = [
            (
                f"  [{spec.field_type.value}] {spec.name} → prefix={spec.prefix}, "
                f"tokens_in_vocab={spec.token_count}, tokens_per_event={spec.tokens_per_event}"
            )
            for spec in self.fields
        ]
        return "\n".join([header, *description_rows, *overview_rows, *field_rows])