rtferraz committed
Commit 818a2e9 · verified · 1 Parent(s): 511f3aa

Add domain_tokenizer.py — DomainTokenizerBuilder (core assembler, HF integration)

src/domain_tokenizer/tokenizers/domain_tokenizer.py ADDED
@@ -0,0 +1,280 @@
+ """
+ Domain Tokenizer Builder — assembles per-field tokenizers into an
+ HF-compatible PreTrainedTokenizerFast.
+
+ This is the core of domainTokenizer: it takes a DomainSchema, builds
+ per-field tokenizers, fits data-dependent ones, and produces a single
+ HuggingFace tokenizer that can encode domain events as token ID sequences.
+
+ The output tokenizer is fully compatible with HF Trainer, push_to_hub,
+ from_pretrained, etc.
+
+ References:
+ - Nubank nuFormer: V = V_special (97 tokens) ∪ V_BPE; ~14 tokens/transaction
+ - ActionPiece: items as unordered feature sets -> tokenized sequences
+ """
+
+ import json
+ import os
+ from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
+
+ import numpy as np
+ from tokenizers import Tokenizer, decoders, models, pre_tokenizers, trainers
+ from transformers import PreTrainedTokenizerFast
+
+ from ..schema import DomainSchema, FieldType
+ from .field_tokenizers import (
+     BaseFieldTokenizer,
+     MagnitudeBucketTokenizer,
+     create_field_tokenizer,
+ )
+
+
+ # Control tokens -- always present
+ CONTROL_TOKENS = ["[PAD]", "[UNK]", "[BOS]", "[EOS]", "[MASK]", "[CLS]", "[SEP]"]
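+
+ # Vocabulary layout (illustrative; the domain-token names below are
+ # hypothetical -- real ones come from the schema's field tokenizers):
+ #     ids 0..6 : control tokens above
+ #     next ids : event separator + per-field tokens, e.g. [EVENT_SEP], [MCC_5411]
+ #     rest     : BPE subwords trained on TEXT fields (only if a corpus is given)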
+
+
+ class DomainTokenizerBuilder:
+     """Builds an HF-compatible tokenizer from a DomainSchema.
+
+     Workflow:
+         1. builder = DomainTokenizerBuilder(schema)
+         2. builder.fit(events)                      # fit magnitude bins etc.
+         3. hf_tok = builder.build(text_corpus)      # build HF tokenizer
+         4. tokens = builder.tokenize_event(event)   # tokenize a single event
+         5. ids = hf_tok(" ".join(tokens))           # convert to IDs
+
+     Or use the convenience method:
+         6. enc = builder.encode_sequence(events, hf_tok)  # full sequence -> IDs in one call
+
+     Example (finance):
+         >>> from domain_tokenizer.schemas.predefined import FINANCE_SCHEMA
+         >>> builder = DomainTokenizerBuilder(FINANCE_SCHEMA)
+         >>> builder.fit(training_events)
+         >>> hf_tokenizer = builder.build(text_corpus=descriptions)
+         >>> token_ids = builder.encode_sequence(user_transactions, hf_tokenizer, max_length=2048)
+     """
+
+     def __init__(self, schema: DomainSchema):
+         self.schema = schema
+         self.field_tokenizers: Dict[str, Optional[BaseFieldTokenizer]] = {}
+         self._is_fitted = False
+         self._build_field_tokenizers()
+
+     def _build_field_tokenizers(self):
+         """Instantiate a field tokenizer for each field in the schema."""
+         for spec in self.schema.fields:
+             self.field_tokenizers[spec.name] = create_field_tokenizer(spec)
+
+     def fit(self, events: Sequence[Dict[str, Any]]) -> "DomainTokenizerBuilder":
+         """Fit data-dependent tokenizers on training events.
+
+         Currently fits: NUMERICAL_CONTINUOUS fields (magnitude bucket bins).
+
+         Args:
+             events: Iterable of event dicts, e.g. [{"amount": 79.99, ...}, ...]
+
+         Returns:
+             self (for chaining)
+         """
+         for spec in self.schema.fields:
+             if spec.field_type == FieldType.NUMERICAL_CONTINUOUS:
+                 tok = self.field_tokenizers[spec.name]
+                 values = []
+                 for event in events:
+                     v = event.get(spec.name) if isinstance(event, dict) else getattr(event, spec.name, None)
+                     if v is not None:
+                         values.append(float(v))
+                 if values:
+                     tok.fit(np.array(values))
+                 else:
+                     raise ValueError(f"No values found for field '{spec.name}' during fitting")
+         self._is_fitted = True
+         return self
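+
+     # Only NUMERICAL_CONTINUOUS fields are data-dependent; schemas without
+     # them need no fit() call (is_fitted below already reports True).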
+
+     @property
+     def is_fitted(self) -> bool:
+         """Whether all data-dependent tokenizers have been fitted."""
+         if not self.schema.fittable_field_names:
+             return True
+         return self._is_fitted
+
+     def _collect_special_tokens(self) -> List[str]:
+         """Collect all special tokens: control + event separator + per-field domain tokens."""
+         tokens = list(CONTROL_TOKENS)
+         tokens.append(self.schema.event_separator)
+         for spec in self.schema.fields:
+             tok = self.field_tokenizers.get(spec.name)
+             if tok is not None and hasattr(tok, "vocab"):
+                 tokens.extend(tok.vocab)
+         seen = set()
+         unique = []
+         for t in tokens:
+             if t not in seen:
+                 seen.add(t)
+                 unique.append(t)
+         return unique
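+
+     # Illustrative result for a hypothetical schema with one categorical
+     # field ("type") and one bucketed amount field:
+     #     ["[PAD]", "[UNK]", "[BOS]", "[EOS]", "[MASK]", "[CLS]", "[SEP]",
+     #      "[EVENT_SEP]", "[TYPE_DEBIT]", "[TYPE_CREDIT]", "[AMT_BUCKET_00]", ...]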
+
+     def build(
+         self,
+         text_corpus: Optional[Iterator[str]] = None,
+         bpe_vocab_size: int = 8000,
+         min_frequency: int = 2,
+     ) -> PreTrainedTokenizerFast:
+         """Build a complete HuggingFace-compatible tokenizer.
+
+         1. Collects all domain special tokens from field tokenizers
+         2. Trains BPE on the text corpus (if the schema has text fields)
+         3. Merges everything into a single PreTrainedTokenizerFast
+
+         Args:
+             text_corpus: Iterator of text strings for BPE training.
+             bpe_vocab_size: Target BPE vocabulary size (including special tokens).
+             min_frequency: Minimum frequency for BPE merges.
+
+         Returns:
+             A PreTrainedTokenizerFast ready for use with HF Trainer.
+         """
+         for name in self.schema.fittable_field_names:
+             tok = self.field_tokenizers[name]
+             if isinstance(tok, MagnitudeBucketTokenizer) and not tok._is_fitted:
+                 raise RuntimeError(
+                     f"Field '{name}' requires fitting. Call builder.fit(events) first."
+                 )
+         all_special_tokens = self._collect_special_tokens()
+         if self.schema.has_text_fields and text_corpus is not None:
+             base_tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
+             base_tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
+             base_tokenizer.decoder = decoders.ByteLevel()
+             trainer_obj = trainers.BpeTrainer(
+                 vocab_size=bpe_vocab_size,
+                 special_tokens=all_special_tokens,
+                 min_frequency=min_frequency,
+                 show_progress=True,
+             )
+             # train_from_iterator accepts lists, tuples, and generators alike
+             base_tokenizer.train_from_iterator(text_corpus, trainer=trainer_obj)
+         else:
+             vocab = {token: i for i, token in enumerate(all_special_tokens)}
+             base_tokenizer = Tokenizer(models.BPE(vocab=vocab, merges=[], unk_token="[UNK]"))
+             base_tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
+             base_tokenizer.decoder = decoders.BPEDecoder()
+             # Register the bracketed tokens as special so the Whitespace
+             # pre-tokenizer cannot split them into "[", "NAME", "]"
+             base_tokenizer.add_special_tokens(all_special_tokens)
+         hf_tokenizer = PreTrainedTokenizerFast(
+             tokenizer_object=base_tokenizer,
+             bos_token="[BOS]",
+             eos_token="[EOS]",
+             pad_token="[PAD]",
+             unk_token="[UNK]",
+             mask_token="[MASK]",
+             cls_token="[CLS]",
+             sep_token="[SEP]",
+         )
+         return hf_tokenizer
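+
+     # Usage sketch ("descriptions" is a hypothetical list of free-text values
+     # drawn from the schema's TEXT fields):
+     #     hf_tok = builder.build(text_corpus=descriptions)
+     #     hf_tok.save_pretrained("my-domain-tokenizer")  # or hf_tok.push_to_hub(...)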
+
+     def tokenize_event(self, event: Union[Dict[str, Any], Any]) -> List[str]:
+         """Convert a single domain event into a list of token strings."""
+         tokens = []
+         for spec in self.schema.fields:
+             if isinstance(event, dict):
+                 value = event.get(spec.name)
+             else:
+                 value = getattr(event, spec.name, None)
+             if spec.field_type == FieldType.TEXT:
+                 if value is not None:
+                     tokens.append(str(value))
+                 continue
+             tok = self.field_tokenizers.get(spec.name)
+             if tok is None:
+                 continue
+             if value is None:
+                 tokens.append("[UNK]")
+                 continue
+             result = tok(value)
+             if isinstance(result, list):
+                 tokens.extend(result)
+             else:
+                 tokens.append(result)
+         return tokens
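+
+     # Illustrative (hypothetical fields and token names): an event like
+     #     {"amount": -42.0, "mcc": "5411", "description": "GROCERY STORE"}
+     # might yield
+     #     ["[SIGN_NEG]", "[AMT_BUCKET_07]", "[MCC_5411]", "GROCERY STORE"]
+     # TEXT fields pass through as raw strings and are split by BPE later.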
+
+     def tokenize_sequence(
+         self,
+         events: Sequence[Union[Dict[str, Any], Any]],
+         add_bos: bool = True,
+         add_eos: bool = True,
+     ) -> List[str]:
+         """Tokenize a full sequence of events into token strings."""
+         all_tokens = []
+         if add_bos:
+             all_tokens.append("[BOS]")
+         for i, event in enumerate(events):
+             if i > 0:
+                 all_tokens.append(self.schema.event_separator)
+             event_tokens = self.tokenize_event(event)
+             all_tokens.extend(event_tokens)
+         if add_eos:
+             all_tokens.append("[EOS]")
+         return all_tokens
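+
+     # Output shape: ["[BOS]", <event 1 tokens>, <separator>, <event 2 tokens>, ..., "[EOS]"]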
+
+     def encode_sequence(
+         self,
+         events: Sequence[Union[Dict[str, Any], Any]],
+         hf_tokenizer: PreTrainedTokenizerFast,
+         max_length: int = 2048,
+         add_bos: bool = True,
+         add_eos: bool = True,
+         return_tensors: Optional[str] = None,
+     ) -> Dict[str, Any]:
+         """Full pipeline: events -> token strings -> token IDs."""
+         token_strings = self.tokenize_sequence(events, add_bos=add_bos, add_eos=add_eos)
+         token_text = " ".join(token_strings)
+         encoding = hf_tokenizer(
+             token_text,
+             max_length=max_length,
+             truncation=True,
+             padding="max_length",
+             return_tensors=return_tensors,
+         )
+         return encoding
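+
+     # Illustrative end-to-end call (events/hf_tok as in the class docstring):
+     #     enc = builder.encode_sequence(events, hf_tok, max_length=512, return_tensors="pt")
+     #     enc["input_ids"].shape  # (1, 512): always truncated/padded to max_length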
+
+     def save(self, directory: str):
+         """Save the builder state (fitted bins, schema, etc.) to a directory."""
+         os.makedirs(directory, exist_ok=True)
+         state = {
+             "schema_name": self.schema.name,
+             "is_fitted": self._is_fitted,
+             "field_tokenizers": {},
+         }
+         for name, tok in self.field_tokenizers.items():
+             if tok is not None:
+                 state["field_tokenizers"][name] = tok.to_dict()
+         with open(os.path.join(directory, "domain_tokenizer_builder.json"), "w") as f:
+             json.dump(state, f, indent=2)
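+
+     # Writes <directory>/domain_tokenizer_builder.json with the schema name,
+     # the fitted flag, and each field tokenizer's to_dict() state (e.g. learned
+     # bucket edges); the HF tokenizer itself is saved separately via save_pretrained().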
+
+     def get_stats(self) -> Dict[str, Any]:
+         """Return statistics about the tokenizer configuration."""
+         return {
+             "schema_name": self.schema.name,
+             "total_fields": len(self.schema.fields),
+             "special_token_count": self.schema.special_token_count,
+             "fixed_tokens_per_event": self.schema.fixed_tokens_per_event,
+             "has_text_fields": self.schema.has_text_fields,
+             "is_fitted": self.is_fitted,
+             "field_details": {
+                 spec.name: {
+                     "type": spec.field_type.value,
+                     "vocab_tokens": spec.token_count,
+                     "tokens_per_event": spec.tokens_per_event,
+                 }
+                 for spec in self.schema.fields
+             },
+         }