rtferraz commited on
Commit
1a9dad0
·
verified ·
1 Parent(s): 0c1ca58

Add schema.py — DomainSchema, FieldSpec, FieldType definitions

Browse files
Files changed (1) hide show
  1. src/domain_tokenizer/schema.py +180 -0
src/domain_tokenizer/schema.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Domain Schema Definition for domainTokenizer.
3
+
4
+ A declarative format that describes the fields in a domain's event data.
5
+ Each field type maps to a specific tokenization strategy.
6
+
7
+ References:
8
+ - Nubank nuFormer (arXiv:2507.23267): amount sign(2) + amount bucket(21) + calendar(74) + BPE text
9
+ - ActionPiece (arXiv:2502.13581): unordered feature sets tokenized via BPE-like merging
10
+ - Banking TF (arXiv:2410.08243): date + amount + wording composite tokens
11
+ """
12
+
13
+ from dataclasses import dataclass, field
14
+ from enum import Enum
15
+ from typing import Dict, List, Optional, Any
16
+
17
+
18
+ class FieldType(Enum):
19
+ """Supported field types for domain tokenization.
20
+
21
+ Each type maps to a specific tokenization strategy:
22
+ SIGN → 2 tokens (positive/negative, credit/debit)
23
+ NUMERICAL_CONTINUOUS → quantile-based magnitude bins (default: 21, following Nubank)
24
+ NUMERICAL_DISCRETE → small fixed vocabulary for countable values (quantities, counts)
25
+ CATEGORICAL_FIXED → direct vocabulary mapping from a known set of categories
26
+ TEMPORAL → calendar decomposition into month/dow/dom/hour tokens
27
+ TEXT → BPE subword tokenization (for descriptions, names, free text)
28
+ """
29
+ SIGN = "sign"
30
+ NUMERICAL_CONTINUOUS = "numerical_continuous"
31
+ NUMERICAL_DISCRETE = "numerical_discrete"
32
+ CATEGORICAL_FIXED = "categorical_fixed"
33
+ TEMPORAL = "temporal"
34
+ TEXT = "text"
35
+
36
+
37
+ # Mapping from calendar field name → number of tokens it produces
38
+ CALENDAR_FIELD_SIZES: Dict[str, int] = {
39
+ "month": 12,
40
+ "dow": 7, # day of week
41
+ "dom": 31, # day of month
42
+ "hour": 24,
43
+ "quarter": 4,
44
+ "minute_bin": 4, # 15-min bins: 0-14, 15-29, 30-44, 45-59
45
+ }
46
+
47
+
48
+ @dataclass
49
+ class FieldSpec:
50
+ """Specification for a single field in a domain event.
51
+
52
+ Args:
53
+ name: Field name (must match the key in event dictionaries).
54
+ field_type: How this field should be tokenized.
55
+ prefix: Token prefix override. Defaults to uppercase field name.
56
+ n_bins: Number of quantile bins for NUMERICAL_CONTINUOUS (default: 21, Nubank).
57
+ categories: List of category values for CATEGORICAL_FIXED.
58
+ vocab_size: Explicit vocab size for NUMERICAL_DISCRETE (e.g., 11 for quantities 0-10).
59
+ calendar_fields: Which calendar components to extract for TEMPORAL fields.
60
+ max_value: Upper bound for NUMERICAL_DISCRETE (tokens: 0..max_value, max_value+).
61
+ """
62
+ name: str
63
+ field_type: FieldType
64
+ prefix: Optional[str] = None
65
+ n_bins: int = 21
66
+ categories: Optional[List[str]] = None
67
+ vocab_size: Optional[int] = None
68
+ calendar_fields: List[str] = field(default_factory=lambda: ["month", "dow", "dom", "hour"])
69
+ max_value: Optional[int] = None
70
+
71
+ def __post_init__(self):
72
+ if self.prefix is None:
73
+ self.prefix = self.name.upper()
74
+
75
+ # Validation
76
+ if self.field_type == FieldType.CATEGORICAL_FIXED and self.categories is None:
77
+ raise ValueError(f"Field '{self.name}': CATEGORICAL_FIXED requires 'categories' list")
78
+
79
+ if self.field_type == FieldType.NUMERICAL_DISCRETE and self.max_value is None:
80
+ raise ValueError(f"Field '{self.name}': NUMERICAL_DISCRETE requires 'max_value'")
81
+
82
+ @property
83
+ def token_count(self) -> int:
84
+ """Number of special tokens this field produces in the vocabulary."""
85
+ if self.field_type == FieldType.SIGN:
86
+ return 2
87
+ elif self.field_type == FieldType.NUMERICAL_CONTINUOUS:
88
+ return self.n_bins
89
+ elif self.field_type == FieldType.NUMERICAL_DISCRETE:
90
+ return self.max_value + 2 # 0..max_value + overflow bin
91
+ elif self.field_type == FieldType.CATEGORICAL_FIXED:
92
+ return len(self.categories) + 1 # +1 for unknown
93
+ elif self.field_type == FieldType.TEMPORAL:
94
+ return sum(CALENDAR_FIELD_SIZES.get(cf, 0) for cf in self.calendar_fields)
95
+ elif self.field_type == FieldType.TEXT:
96
+ return 0 # text tokens come from BPE, not special vocab
97
+ return 0
98
+
99
+ @property
100
+ def tokens_per_event(self) -> int:
101
+ """Number of tokens this field contributes per event (fixed part only)."""
102
+ if self.field_type == FieldType.SIGN:
103
+ return 1
104
+ elif self.field_type in (FieldType.NUMERICAL_CONTINUOUS, FieldType.NUMERICAL_DISCRETE, FieldType.CATEGORICAL_FIXED):
105
+ return 1
106
+ elif self.field_type == FieldType.TEMPORAL:
107
+ return len(self.calendar_fields)
108
+ elif self.field_type == FieldType.TEXT:
109
+ return 0 # variable length
110
+ return 0
111
+
112
+
113
+ @dataclass
114
+ class DomainSchema:
115
+ """Complete schema for a domain's event data.
116
+
117
+ A schema defines the ordered list of fields that make up each event
118
+ (transaction, purchase, clinical encounter, etc.). The field order
119
+ determines the token order within each event.
120
+
121
+ Args:
122
+ name: Human-readable domain name (e.g., "finance", "ecommerce").
123
+ fields: Ordered list of field specifications.
124
+ event_separator: Special token to separate events in a sequence.
125
+ description: Optional human-readable description.
126
+ """
127
+ name: str
128
+ fields: List[FieldSpec]
129
+ event_separator: str = "[SEP_EVENT]"
130
+ description: str = ""
131
+
132
+ @property
133
+ def special_token_count(self) -> int:
134
+ """Total number of domain-specific special tokens needed."""
135
+ # Base special tokens: PAD, UNK, BOS, EOS, MASK, CLS, SEP, event separator
136
+ base = 8
137
+ return base + sum(f.token_count for f in self.fields)
138
+
139
+ @property
140
+ def fixed_tokens_per_event(self) -> int:
141
+ """Number of fixed (non-text) tokens per event, including separator."""
142
+ return 1 + sum(f.tokens_per_event for f in self.fields) # +1 for event separator
143
+
144
+ @property
145
+ def has_text_fields(self) -> bool:
146
+ """Whether the schema includes any free-text fields."""
147
+ return any(f.field_type == FieldType.TEXT for f in self.fields)
148
+
149
+ @property
150
+ def text_field_names(self) -> List[str]:
151
+ """Names of all text fields in the schema."""
152
+ return [f.name for f in self.fields if f.field_type == FieldType.TEXT]
153
+
154
+ @property
155
+ def fittable_field_names(self) -> List[str]:
156
+ """Names of fields that require fitting on training data."""
157
+ return [f.name for f in self.fields if f.field_type == FieldType.NUMERICAL_CONTINUOUS]
158
+
159
+ def get_field(self, name: str) -> Optional[FieldSpec]:
160
+ """Look up a field by name."""
161
+ for f in self.fields:
162
+ if f.name == name:
163
+ return f
164
+ return None
165
+
166
+ def summary(self) -> str:
167
+ """Human-readable summary of the schema."""
168
+ lines = [f"DomainSchema: {self.name}"]
169
+ if self.description:
170
+ lines.append(f" {self.description}")
171
+ lines.append(f" Fields: {len(self.fields)}")
172
+ lines.append(f" Special tokens: {self.special_token_count}")
173
+ lines.append(f" Fixed tokens/event: {self.fixed_tokens_per_event}")
174
+ lines.append(f" Has text fields: {self.has_text_fields}")
175
+ lines.append(f" Requires fitting: {self.fittable_field_names}")
176
+ lines.append("")
177
+ for f in self.fields:
178
+ lines.append(f" [{f.field_type.value}] {f.name} → prefix={f.prefix}, "
179
+ f"tokens_in_vocab={f.token_count}, tokens_per_event={f.tokens_per_event}")
180
+ return "\n".join(lines)