rtferraz committed on
Commit 25a1093 · verified · 1 Parent(s): e30a14d

Add ADR-001: Implementation framework decision with detailed roadmap

docs/adr/ADR-001-implementation-framework.md ADDED
@@ -0,0 +1,978 @@
1
+ # ADR-001: Implementation Framework for domainTokenizer
2
+
3
+ > **Status:** Accepted
4
+ > **Date:** April 29, 2026
5
+ > **Decision:** PyTorch + HuggingFace Transformers as primary framework, with JAX/Flax NNX as future scaling path
6
+ > **Deciders:** domainTokenizer core team
7
+
8
+ ---
9
+
10
+ ## Table of Contents
11
+
12
+ 1. [Context](#1-context)
13
+ 2. [Goal](#2-goal)
14
+ 3. [Options Evaluated](#3-options-evaluated)
15
+ 4. [Decision](#4-decision)
16
+ 5. [Trade-offs and Justification](#5-trade-offs-and-justification)
17
+ 6. [Consequences](#6-consequences)
18
+ 7. [Implementation Roadmap](#7-implementation-roadmap)
19
+ 8. [Appendix A: Framework Usage Across Reference Papers](#appendix-a-framework-usage-across-reference-papers)
20
+ 9. [Appendix B: Head-to-Head Comparison Matrix](#appendix-b-head-to-head-comparison-matrix)
21
+ 10. [Appendix C: Key Code Patterns](#appendix-c-key-code-patterns)
22
+
23
+ ---
24
+
25
+ ## 1. Context
26
+
27
+ ### What We're Building
28
+
29
+ domainTokenizer is a library for building **small Transformer models (24M–330M parameters)** that process **domain-specific tokens** (financial transactions, e-commerce events, healthcare records) instead of natural language text. The architecture follows the validated pattern from Nubank's nuFormer ([arXiv: 2507.23267](https://arxiv.org/abs/2507.23267)):
30
+
31
+ ```
32
+ Domain Events → Custom Tokenizer → GPT-style Transformer → Foundation Model → Downstream Tasks
33
+ ```
34
+
35
+ ### The Implementation Question
36
+
37
+ A video from Google for Developers presents the **Keras 3 + JAX/Flax NNX integration** as a potential framework, citing:
38
+ - **Explicit state management** (Flax NNX): useful for tracking sequential transaction state
+ - **Custom training loops**: Keras structure + JAX/Optax for domain-specific gradient control
+ - **JIT compilation** (`@nnx.jit`): high-performance processing of millions of transactions
+ - **Paradigm mixing**: Keras layers for standard components + NNX for custom sequential encoders
42
+
43
+ The question: **Is Keras + JAX/Flax NNX the right framework for domainTokenizer, or is there a better choice?**
44
+
45
+ ### Constraints
46
+
47
+ 1. **Custom tokenizer required:** We need a tokenizer that maps structured fields (amounts, dates, categories) to special tokens, not a standard text tokenizer
+ 2. **Small models:** 24M–330M parameters, not 70B+; framework overhead matters less than developer velocity
49
+ 3. **Production deployment:** Models must be servable with low latency for real-time applications (fraud detection, recommendations)
50
+ 4. **GPU hardware:** Development on A100/A10G GPUs, not TPUs (standard cloud environment)
51
+ 5. **Team context:** ML engineers familiar with Python, PyTorch, and the HuggingFace ecosystem
52
+ 6. **Iteration speed:** Need to prototype quickly across multiple domains (finance, e-commerce, healthcare)
53
+
54
+ ---
55
+
56
+ ## 2. Goal
57
+
58
+ Choose an implementation framework that:
59
+
60
+ 1. **Minimizes time from research to working prototype** (weeks, not months)
61
+ 2. **Supports custom domain tokenizers** as first-class citizens
62
+ 3. **Integrates with the HuggingFace Hub** for model sharing, versioning, and community
63
+ 4. **Enables production deployment** via standard serving infrastructure (ONNX, TGI, vLLM, etc.)
64
+ 5. **Scales to 330M parameters** on 4–8 GPUs without heroic engineering
65
+ 6. **Does not preclude future migration** to JAX/TPU if we need to scale beyond 1B parameters
66
+
67
+ ---
68
+
69
+ ## 3. Options Evaluated
70
+
71
+ ### Option A: PyTorch + HuggingFace Transformers
72
+
73
+ The dominant ecosystem for custom NLP/sequential models. Provides `PreTrainedModel`, `PreTrainedTokenizerFast`, `Trainer`, `push_to_hub`, ONNX export, and integration with TRL, PEFT, Accelerate, DeepSpeed.
74
+
75
+ ### Option B: Keras 3 + JAX Backend + Flax NNX
76
+
77
+ Google's multi-backend framework. Keras provides high-level APIs; JAX provides XLA compilation and functional transforms; Flax NNX provides PyTorch-like stateful modules on top of JAX.
78
+
79
+ ### Option C: Pure JAX + Flax NNX + Optax
80
+
81
+ Skip Keras entirely. Use Flax NNX for model definition, Optax for optimization, Orbax for checkpointing, and Grain/tf.data for data loading. Google's MaxText framework follows this pattern.
82
+
83
+ ### Option D: PyTorch + Custom (no HuggingFace)
84
+
85
+ Use PyTorch directly without the HuggingFace abstraction layer. Full control but no ecosystem integration.
86
+
87
+ ---
88
+
89
+ ## 4. Decision
90
+
91
+ ### Primary: PyTorch + HuggingFace Transformers (Option A)
92
+
93
+ ### Future scaling path: JAX/Flax NNX (Option C), if and when we need TPU training at >1B parameters
94
+
95
+ ---
96
+
97
+ ## 5. Trade-offs and Justification
98
+
99
+ ### 5.1 What the Reference Papers Actually Use
100
+
101
+ We audited the frameworks used by every paper in the domainTokenizer research corpus. The result is overwhelming:
102
+
103
+ | Paper | Framework | Confidence |
104
+ |-------|-----------|------------|
105
+ | **nuFormer** (Nubank) | PyTorch + HF Transformers (inferred) | ~90% |
106
+ | **TIGER** (Google) | JAX + T5X (official); PyTorch (community reimpl) | 100% |
107
+ | **ActionPiece** (Google DeepMind) | **PyTorch + HF Transformers** (stated verbatim in paper) | 100% |
108
+ | **RecFormer** (UCSD/Amazon) | **PyTorch + HF Transformers (Longformer)** (stated verbatim) | 100% |
109
+ | **Banking Transaction Flow** | **PyTorch** (stated verbatim in appendix) | 100% |
110
+ | **PLR Embeddings** (Yandex) | **PyTorch** + scikit-learn + Optuna | 100% |
111
+
112
+ **5 of 6 papers use PyTorch.** The sole JAX user (TIGER) was a Google-internal project using T5X, and even its most popular community reimplementation (781 GitHub stars) is in PyTorch.
113
+
114
+ Even **Google DeepMind's own ActionPiece**, the paper most relevant to our domain tokenization approach, uses PyTorch + HuggingFace. This is the strongest signal possible.
115
+
116
+ ### 5.2 Custom Tokenizer Story
117
+
118
+ This is the **decisive factor**. domainTokenizer's core innovation is the tokenizer itself. The framework must provide first-class support for custom token vocabularies.
119
+
120
+ **PyTorch + HuggingFace:**
121
+ - Train custom BPE tokenizer via `tokenizers` library (Rust-backed, fast)
122
+ - Wrap in `PreTrainedTokenizerFast` for full Trainer compatibility
+ - Add domain special tokens via `add_special_tokens()`, then resize the model embeddings to match
124
+ - Push tokenizer to Hub: `tokenizer.push_to_hub("org/my-tokenizer")`
125
+ - Load anywhere: `AutoTokenizer.from_pretrained("org/my-tokenizer")`
126
+ - **KL3M** ([arXiv: 2503.17247](https://arxiv.org/abs/2503.17247)), the gold standard for financial domain tokenizers, is built entirely on this stack
127
+
128
+ **Keras + JAX/Flax NNX:**
129
+ - No equivalent to `PreTrainedTokenizerFast`
130
+ - No Hub-integrated tokenizer format
131
+ - Must build custom tokenizer from scratch with no ecosystem support
132
+ - No standard serialization/deserialization for domain vocabularies
133
+
134
+ **Verdict:** PyTorch/HF has a **complete, production-tested** custom tokenizer pipeline. JAX/Keras has **nothing**; you'd build everything from scratch.
135
+
136
+ ### 5.3 Production Deployment
137
+
138
+ | Path | PyTorch | JAX/Keras |
139
+ |------|---------|-----------|
140
+ | ONNX export | `torch.onnx.export()`, one line | Requires TF backend intermediate or experimental `jax.export` |
141
+ | TensorRT | ONNX β†’ TRT (standard) | Multi-hop, fragile |
142
+ | TGI (HuggingFace inference) | First-class | Not supported |
143
+ | vLLM | First-class | Not supported |
144
+ | Triton Inference Server | Direct ONNX/TorchScript | Via ONNX (workaround) |
145
+ | BentoML | Supported | Supported |
146
+ | Model Hub sharing | `push_to_hub()` → `from_pretrained()` | Works but fragmented (`.msgpack` weights, no Trainer compat) |
147
+
148
+ **Verdict:** PyTorch has **direct, tested paths** to every major serving framework. JAX requires **multiple intermediate conversions**, each introducing failure points.
149
+
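+ As a concreteness check on the "one line" claim, a minimal export sketch with a stand-in module (the module, file name, and shapes are illustrative; the real input would be the trained domain Transformer):
+
+ ```python
+ import torch
+ import torch.nn as nn
+
+ class TinyLM(nn.Module):
+     """Stand-in for a trained domain model: embedding -> linear head over the vocab."""
+     def __init__(self, vocab_size=1024, hidden=64):
+         super().__init__()
+         self.embed = nn.Embedding(vocab_size, hidden)
+         self.head = nn.Linear(hidden, vocab_size)
+
+     def forward(self, input_ids):
+         return self.head(self.embed(input_ids))
+
+ model = TinyLM().eval()
+ dummy = torch.randint(0, 1024, (1, 128))  # (batch, seq_len)
+
+ torch.onnx.export(
+     model, (dummy,), "domain_model.onnx",
+     input_names=["input_ids"], output_names=["logits"],
+     dynamic_axes={"input_ids": {0: "batch", 1: "seq"}, "logits": {0: "batch", 1: "seq"}},
+     opset_version=17,
+ )
+ ```
+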
150
+ ### 5.4 Training Speed
151
+
152
+ At our scale (24M–330M parameters on 4–8 A100s):
153
+
154
+ | Scenario | PyTorch | JAX |
155
+ |----------|---------|-----|
156
+ | Steady-state training throughput | **Comparable** (`torch.compile`) | **Comparable** (XLA JIT) |
157
+ | Variable-length sequences | **Native** (dynamic shapes) | **Problematic** (recompiles on new shapes; must pad to buckets) |
+ | Multi-GPU (FSDP) | `accelerate` + FSDP2 (mature) | `pmap`/`shard_map` (works but harder to configure) |
+ | First-run compilation | Instant (eager mode) | 5–20 s JIT compilation overhead |
160
+ | Debugging | Standard Python debugger | `print` debugging; cryptic XLA errors |
161
+
162
+ **Verdict:** At 330M parameters, training speed is a **wash**. JAX's advantages (XLA kernel fusion, TPU native) only matter at 10B+ parameters on 256+ accelerators. At our scale, **developer velocity dominates throughput**.
163
+
164
+ ### 5.5 The JAX Advantage: When It Would Win
165
+
166
+ JAX/Flax NNX would be the right choice **if**:
167
+
168
+ 1. **Training exclusively on Google TPUs:** JAX is the native TPU compiler; PyTorch/XLA is a port with overhead
+ 2. **Models >1B parameters:** XLA's whole-program optimization shines at scale
+ 3. **Fixed-shape workloads:** images, fixed-length token sequences (no variable-length padding issues)
+ 4. **Need functional transforms:** `vmap` (per-sample gradients), `pmap` (data parallelism), `grad` (higher-order derivatives)
+ 5. **Google Cloud infrastructure:** Vertex AI, TPU VMs, GCS integration
173
+
174
+ For domainTokenizer's current scope (24M–330M, GPU, variable-length sequences, fast iteration), **none of these conditions apply**.
175
+
176
+ ### 5.6 The Keras + JAX Mixing Argument
177
+
178
+ The Google for Developers video argues for mixing Keras layers (high-level) with NNX modules (custom, high-performance). In theory, this lets you:
179
+ - Use Keras for standard Transformer layers
180
+ - Use NNX for custom sequential transaction encoders
181
+ - Get JIT compilation on the NNX parts
182
+
183
+ **In practice, this creates problems:**
184
+ 1. **Two mental models:** Keras (layer-oriented, `fit/compile`) vs. NNX (functional, explicit state); context switching slows development
+ 2. **Limited interop documentation:** Keras ↔ NNX examples are thin; edge cases are poorly documented
+ 3. **No HF ecosystem integration:** you lose Trainer, push_to_hub, PEFT, TRL, Accelerate, the entire ecosystem Nubank and ActionPiece rely on
+ 4. **Debugging complexity:** errors at the Keras ↔ NNX boundary are hard to diagnose
188
+
189
+ **Better approach with PyTorch:** Use `torch.compile()` on performance-critical modules to get JIT compilation benefits without leaving the PyTorch ecosystem. Write custom `nn.Module` subclasses for domain-specific components. This gives you the same "standard parts + custom parts" architecture without framework mixing.
190
+
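+ A minimal sketch of that pattern (the `TransactionEncoder` module below is hypothetical, purely to show where `torch.compile()` is applied):
+
+ ```python
+ import torch
+ import torch.nn as nn
+
+ class TransactionEncoder(nn.Module):
+     """Hypothetical custom sequential encoder for per-event feature vectors."""
+     def __init__(self, dim=256):
+         super().__init__()
+         self.proj = nn.Linear(dim, dim)
+         self.norm = nn.LayerNorm(dim)
+
+     def forward(self, x):  # x: (batch, events, dim)
+         return self.norm(torch.relu(self.proj(x)))
+
+ encoder = TransactionEncoder()
+
+ # JIT-compile only the hot module; the rest of the model stays in eager mode.
+ encoder = torch.compile(encoder)
+
+ out = encoder(torch.randn(8, 512, 256))  # first call triggers compilation
+ ```
+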
191
+ ---
192
+
193
+ ## 6. Consequences
194
+
195
+ ### What We Gain
196
+
197
+ 1. **Immediate access to the entire HuggingFace ecosystem:** Trainer, Accelerate, PEFT (LoRA), TRL, Evaluate, push_to_hub, from_pretrained, ONNX export, TGI serving
198
+ 2. **Copy-paste from reference implementations:** ActionPiece, RecFormer, Banking TF, and PLR embeddings are all PyTorch, so we can directly reuse their code
199
+ 3. **KL3M tokenizer as starting point:** The best financial domain tokenizer already exists in PyTorch/HF format at `alea-institute/kl3m-004-128k-cased`
200
+ 4. **Standard production deployment:** ONNX → TensorRT → Triton, or direct TGI/vLLM serving
201
+ 5. **Community and hiring:** PyTorch is the dominant ML framework; finding contributors and documentation is easy
202
+ 6. **`torch.compile()` for performance:** When we need JIT compilation on hot paths, `torch.compile()` provides 10–30% speedups without leaving the ecosystem
203
+
204
+ ### What We Accept
205
+
206
+ 1. **No native TPU support:** If we later need to train on Google TPUs, we'll need PyTorch/XLA (slower than native JAX) or migrate the model code
207
+ 2. **Weaker functional transforms:** per-sample gradients via `vmap` require `torch.func` (the successor to the experimental `functorch`; see the sketch after this list), which is less battle-tested than JAX's transforms. If we need advanced gradient manipulation for meta-learning or Nested Learning (HOPE-style), JAX would be better
208
+ 3. **Potential future migration cost:** If we scale beyond 1B parameters and move to TPUs, we'll need to rewrite model code in Flax NNX. This is mitigated by keeping model definitions clean and modular
209
+
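+ For reference, a toy sketch of per-sample gradients via `torch.func` (illustrative model and data, following the standard per-sample-gradient recipe):
+
+ ```python
+ import torch
+ import torch.nn as nn
+ from torch.func import functional_call, grad, vmap
+
+ model = nn.Linear(16, 2)  # toy stand-in model
+ params = {k: v.detach() for k, v in model.named_parameters()}
+
+ def loss_fn(params, x, y):
+     logits = functional_call(model, params, (x.unsqueeze(0),))
+     return nn.functional.cross_entropy(logits, y.unsqueeze(0))
+
+ x = torch.randn(32, 16)         # batch of 32 samples
+ y = torch.randint(0, 2, (32,))
+
+ # grad() differentiates w.r.t. params; vmap() maps that over the batch dimension.
+ per_sample_grads = vmap(grad(loss_fn), in_dims=(None, 0, 0))(params, x, y)
+ # per_sample_grads["weight"].shape == (32, 2, 16)
+ ```
+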
210
+ ### Migration Strategy (If Needed Later)
211
+
212
+ If domainTokenizer grows to >1B parameters and we need TPU training:
213
+
214
+ 1. **Tokenizer layer stays in Python/HF:** Tokenizer is framework-agnostic; it produces integer sequences regardless of whether the model is PyTorch or JAX
+ 2. **Model architecture translates 1:1:** PyTorch `nn.Module` → Flax NNX `nnx.Module` mapping is straightforward for standard Transformer components (see the NNX sketch below)
+ 3. **Training loop changes:** PyTorch Trainer → custom Flax NNX training loop with Optax
217
+ 4. **Reference:** Google's MaxText (`github.com/google/maxtext`) provides production-grade JAX Transformer patterns we can follow
218
+
219
+ **Estimated migration effort:** 2–4 weeks for a clean, well-separated codebase.
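+
+ To make the claimed 1:1 mapping concrete, a minimal Flax NNX sketch of the MLP sub-block (illustrative only, not part of the current codebase):
+
+ ```python
+ import jax
+ import jax.numpy as jnp
+ from flax import nnx
+
+ class DomainMLP(nnx.Module):
+     """NNX counterpart of the PyTorch MLP sub-block (illustrative only)."""
+     def __init__(self, hidden: int, intermediate: int, *, rngs: nnx.Rngs):
+         self.fc1 = nnx.Linear(hidden, intermediate, rngs=rngs)
+         self.fc2 = nnx.Linear(intermediate, hidden, rngs=rngs)
+
+     def __call__(self, x):
+         return self.fc2(jax.nn.gelu(self.fc1(x)))
+
+ mlp = DomainMLP(256, 1024, rngs=nnx.Rngs(0))
+ y = mlp(jnp.ones((2, 16, 256)))  # same call shape as the PyTorch module
+ ```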
220
+
221
+ ---
222
+
223
+ ## 7. Implementation Roadmap
224
+
225
+ ### Phase 2A: Core Tokenizer Library (Weeks 1–3)
226
+
227
+ #### Step 1: Domain Schema Definition
228
+
229
+ Create a declarative schema format that describes the fields in a domain's event data:
230
+
231
+ ```python
232
+ # src/tokenizers/schema.py
233
+
234
+ from dataclasses import dataclass, field
235
+ from enum import Enum
236
+ from typing import List, Optional
237
+
238
+ class FieldType(Enum):
239
+ NUMERICAL_CONTINUOUS = "numerical_continuous" # prices, amounts β†’ magnitude bins
240
+ NUMERICAL_DISCRETE = "numerical_discrete" # quantities β†’ small fixed vocab
241
+ CATEGORICAL_FIXED = "categorical_fixed" # categories, days of week β†’ direct mapping
242
+ CATEGORICAL_ENTITY = "categorical_entity" # products, merchants β†’ Semantic IDs (RQ-VAE)
243
+ TEMPORAL = "temporal" # timestamps β†’ calendar decomposition
244
+ TEXT = "text" # descriptions β†’ BPE subwords
245
+ SIGN = "sign" # credit/debit β†’ 2 tokens
246
+
247
+ @dataclass
248
+ class FieldSpec:
249
+ name: str
250
+ field_type: FieldType
251
+ vocab_size: Optional[int] = None # for fixed categorical
252
+ n_bins: int = 21 # for numerical (Nubank uses 21)
253
+ calendar_fields: List[str] = field( # for temporal
254
+ default_factory=lambda: ["month", "dow", "dom", "hour"]
255
+ )
256
+
257
+ @dataclass
258
+ class DomainSchema:
259
+ name: str # e.g., "ecommerce", "finance"
260
+ fields: List[FieldSpec] # ordered list of fields per event
261
+
262
+ @property
263
+ def special_token_count(self) -> int:
264
+ """Total domain-specific special tokens needed."""
265
+ count = 0
266
+ for f in self.fields:
267
+ if f.field_type == FieldType.SIGN:
268
+ count += 2
269
+             elif f.field_type == FieldType.NUMERICAL_CONTINUOUS:
+                 count += f.n_bins
+             elif f.field_type == FieldType.NUMERICAL_DISCRETE:
+                 count += f.vocab_size or 0
+             elif f.field_type == FieldType.CATEGORICAL_FIXED:
+                 count += f.vocab_size or 0
273
+ elif f.field_type == FieldType.TEMPORAL:
274
+ count += sum({
275
+ "month": 12, "dow": 7, "dom": 31, "hour": 24,
276
+ "quarter": 4, "year": 10
277
+ }.get(cf, 0) for cf in f.calendar_fields)
278
+ return count
279
+
280
+ # Example: Nubank-style financial schema
281
+ FINANCE_SCHEMA = DomainSchema(
282
+ name="finance",
283
+ fields=[
284
+ FieldSpec("amount_sign", FieldType.SIGN),
285
+ FieldSpec("amount", FieldType.NUMERICAL_CONTINUOUS, n_bins=21),
286
+ FieldSpec("timestamp", FieldType.TEMPORAL,
287
+ calendar_fields=["month", "dow", "dom", "hour"]),
288
+ FieldSpec("description", FieldType.TEXT),
289
+ ]
290
+ )
291
+
292
+ # Example: E-commerce schema
293
+ ECOMMERCE_SCHEMA = DomainSchema(
294
+ name="ecommerce",
295
+ fields=[
296
+ FieldSpec("event_type", FieldType.CATEGORICAL_FIXED, vocab_size=5),
297
+ FieldSpec("price", FieldType.NUMERICAL_CONTINUOUS, n_bins=21),
298
+ FieldSpec("quantity", FieldType.NUMERICAL_DISCRETE, vocab_size=11),
299
+ FieldSpec("category_l1", FieldType.CATEGORICAL_FIXED, vocab_size=30),
300
+ FieldSpec("category_l2", FieldType.CATEGORICAL_FIXED, vocab_size=200),
301
+ FieldSpec("timestamp", FieldType.TEMPORAL,
302
+ calendar_fields=["month", "dow", "dom", "hour"]),
303
+ FieldSpec("product_title", FieldType.TEXT),
304
+ ]
305
+ )
306
+ ```
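+
+ A quick sanity check of the special-token budget each schema implies (counts follow the rule above, with `NUMERICAL_DISCRETE` counted by its `vocab_size`):
+
+ ```python
+ # finance:   2 (sign) + 21 (amount bins) + 12+7+31+24 (calendar) = 97
+ # ecommerce: 5 + 21 + 11 + 30 + 200 + 74 (calendar)              = 341
+ print(FINANCE_SCHEMA.special_token_count)
+ print(ECOMMERCE_SCHEMA.special_token_count)
+ ```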
307
+
308
+ #### Step 2: Per-Field Tokenizers
309
+
310
+ Implement each field type tokenizer as a standalone module:
311
+
312
+ ```python
313
+ # src/tokenizers/field_tokenizers.py
314
+
315
+ import numpy as np
316
+ from typing import List
317
+
318
+ class SignTokenizer:
319
+ """Tokenizes sign of a numerical value (credit/debit, inflow/outflow)."""
320
+
321
+ def __init__(self, prefix: str = "SIGN"):
322
+ self.tokens = [f"[{prefix}_POS]", f"[{prefix}_NEG]"]
323
+
324
+ def __call__(self, value: float) -> str:
325
+ return self.tokens[0] if value >= 0 else self.tokens[1]
326
+
327
+ @property
328
+ def vocab(self) -> List[str]:
329
+ return self.tokens
330
+
331
+
332
+ class MagnitudeBucketTokenizer:
333
+ """Quantizes continuous values into bins (Nubank-style).
334
+
335
+ Uses quantile-based binning on the training distribution.
336
+ Follows the Relative Magnitude Tokenization principle from TP-BERTa.
337
+ """
338
+
339
+ def __init__(self, n_bins: int = 21, prefix: str = "AMT"):
340
+ self.n_bins = n_bins
341
+ self.prefix = prefix
342
+ self.bin_edges = None # fitted from data
343
+
344
+ def fit(self, values: np.ndarray):
345
+ """Compute bin edges from training data using quantiles."""
346
+ # Use absolute values for magnitude binning
347
+ abs_vals = np.abs(values[~np.isnan(values)])
348
+ quantiles = np.linspace(0, 100, self.n_bins + 1)
349
+ self.bin_edges = np.percentile(abs_vals, quantiles)
350
+ return self
351
+
352
+ def __call__(self, value: float) -> str:
353
+ if self.bin_edges is None:
354
+ raise ValueError("Tokenizer not fitted. Call .fit() first.")
355
+ bin_idx = np.searchsorted(self.bin_edges[1:-1], abs(value))
356
+ return f"[{self.prefix}_{bin_idx:02d}]"
357
+
358
+ @property
359
+ def vocab(self) -> List[str]:
360
+ return [f"[{self.prefix}_{i:02d}]" for i in range(self.n_bins)]
361
+
362
+
363
+ class CalendarTokenizer:
364
+ """Decomposes timestamps into calendar components (Nubank-style)."""
365
+
366
+ FIELD_VOCABS = {
367
+ "month": ([f"[MON_{i:02d}]" for i in range(1, 13)], lambda dt: dt.month - 1),
368
+ "dow": ([f"[DOW_{i}]" for i in range(7)], lambda dt: dt.weekday()),
369
+ "dom": ([f"[DOM_{i:02d}]" for i in range(1, 32)], lambda dt: dt.day - 1),
370
+ "hour": ([f"[HOUR_{i:02d}]" for i in range(24)], lambda dt: dt.hour),
371
+ }
372
+
373
+ def __init__(self, fields: List[str] = None):
374
+ self.fields = fields or ["month", "dow", "dom", "hour"]
375
+
376
+ def __call__(self, timestamp) -> List[str]:
377
+ tokens = []
378
+ for field_name in self.fields:
379
+ vocab, extractor = self.FIELD_VOCABS[field_name]
380
+ idx = extractor(timestamp)
381
+ tokens.append(vocab[idx])
382
+ return tokens
383
+
384
+ @property
385
+ def vocab(self) -> List[str]:
386
+ all_tokens = []
387
+ for field_name in self.fields:
388
+ all_tokens.extend(self.FIELD_VOCABS[field_name][0])
389
+ return all_tokens
390
+
391
+
392
+ class CategoricalTokenizer:
393
+ """Maps categorical values to fixed vocabulary tokens."""
394
+
395
+ def __init__(self, categories: List[str], prefix: str = "CAT"):
396
+ self.prefix = prefix
397
+ self.token_map = {cat: f"[{prefix}_{i:03d}]" for i, cat in enumerate(categories)}
398
+ self.unk_token = f"[{prefix}_UNK]"
399
+
400
+ def __call__(self, value: str) -> str:
401
+ return self.token_map.get(value, self.unk_token)
402
+
403
+ @property
404
+ def vocab(self) -> List[str]:
405
+ return list(self.token_map.values()) + [self.unk_token]
406
+ ```
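+
+ A short usage sketch (synthetic amounts; the bucket token depends on the fitted quantiles):
+
+ ```python
+ import numpy as np
+ from datetime import datetime
+
+ amount_tok = MagnitudeBucketTokenizer(n_bins=21, prefix="AMT").fit(
+     np.random.lognormal(mean=3.0, sigma=1.5, size=100_000)  # synthetic training amounts
+ )
+ cal_tok = CalendarTokenizer(["month", "dow", "dom", "hour"])
+
+ print(SignTokenizer("AMT")(-42.90))           # "[AMT_NEG]"
+ print(amount_tok(42.90))                      # e.g. "[AMT_07]", depends on the fit
+ print(cal_tok(datetime(2025, 3, 14, 9, 30)))  # ["[MON_03]", "[DOW_4]", "[DOM_14]", "[HOUR_09]"]
+ ```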
407
+
408
+ #### Step 3: Composite Domain Tokenizer
409
+
410
+ Assemble per-field tokenizers into a complete domain tokenizer, wrapped as `PreTrainedTokenizerFast`:
411
+
412
+ ```python
413
+ # src/tokenizers/domain_tokenizer.py
414
+
415
+ from tokenizers import Tokenizer, models, trainers, pre_tokenizers
416
+ from transformers import PreTrainedTokenizerFast
417
+
418
+ class DomainTokenizerBuilder:
419
+ """Builds a HuggingFace-compatible tokenizer from a DomainSchema."""
420
+
421
+ def __init__(self, schema: DomainSchema):
422
+ self.schema = schema
423
+ self.field_tokenizers = {} # name β†’ field tokenizer instance
424
+ self._build_field_tokenizers()
425
+
426
+ def _build_field_tokenizers(self):
427
+ for field_spec in self.schema.fields:
428
+ if field_spec.field_type == FieldType.SIGN:
429
+ self.field_tokenizers[field_spec.name] = SignTokenizer(field_spec.name.upper())
430
+ elif field_spec.field_type == FieldType.NUMERICAL_CONTINUOUS:
431
+ self.field_tokenizers[field_spec.name] = MagnitudeBucketTokenizer(
432
+ n_bins=field_spec.n_bins, prefix=field_spec.name.upper()
433
+ )
434
+ elif field_spec.field_type == FieldType.TEMPORAL:
435
+ self.field_tokenizers[field_spec.name] = CalendarTokenizer(field_spec.calendar_fields)
436
+ # ... other types
437
+
438
+ def fit(self, data):
439
+ """Fit data-dependent tokenizers (magnitude bins, etc.)."""
440
+ for field_spec in self.schema.fields:
441
+ if field_spec.field_type == FieldType.NUMERICAL_CONTINUOUS:
442
+ values = [getattr(event, field_spec.name) for event in data]
443
+ self.field_tokenizers[field_spec.name].fit(np.array(values))
444
+ return self
445
+
446
+ def build_hf_tokenizer(self, text_corpus=None, bpe_vocab_size=8000) -> PreTrainedTokenizerFast:
447
+ """Build a complete HuggingFace tokenizer.
448
+
449
+ 1. Collect all domain special tokens
450
+ 2. Train BPE on text fields (if any)
451
+ 3. Merge into a single PreTrainedTokenizerFast
452
+ """
453
+ # Collect all special tokens from field tokenizers
454
+ all_special_tokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "[BOS]", "[EOS]"]
455
+ for name, tok in self.field_tokenizers.items():
456
+ if hasattr(tok, 'vocab'):
457
+ all_special_tokens.extend(tok.vocab)
458
+
459
+ # Train BPE on text fields
460
+ bpe_tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
461
+ bpe_tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
462
+ trainer = trainers.BpeTrainer(
463
+ vocab_size=bpe_vocab_size,
464
+ special_tokens=all_special_tokens,
465
+ min_frequency=2,
466
+ )
467
+         if text_corpus:
+             bpe_tokenizer.train_from_iterator(text_corpus, trainer=trainer)
+         else:
+             # No text fields to learn BPE merges from: still register the domain
+             # special tokens so they receive stable ids.
+             bpe_tokenizer.add_special_tokens(all_special_tokens)
469
+
470
+ # Wrap as HuggingFace tokenizer
471
+ hf_tokenizer = PreTrainedTokenizerFast(
472
+ tokenizer_object=bpe_tokenizer,
473
+ bos_token="[BOS]",
474
+ eos_token="[EOS]",
475
+ pad_token="[PAD]",
476
+ unk_token="[UNK]",
477
+ mask_token="[MASK]",
478
+ )
479
+
480
+ return hf_tokenizer
481
+
482
+ def tokenize_event(self, event) -> List[str]:
483
+ """Convert a single domain event into a list of token strings."""
484
+ tokens = []
485
+ for field_spec in self.schema.fields:
486
+ value = getattr(event, field_spec.name, None)
487
+ if value is None:
488
+ tokens.append("[UNK]")
489
+ continue
490
+ tok = self.field_tokenizers[field_spec.name]
491
+ result = tok(value)
492
+ if isinstance(result, list):
493
+ tokens.extend(result)
494
+ else:
495
+ tokens.append(result)
496
+ return tokens
497
+ ```
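+
+ An end-to-end usage sketch (the `training_events` list and the Hub repo id are placeholders):
+
+ ```python
+ builder = DomainTokenizerBuilder(FINANCE_SCHEMA).fit(training_events)
+ hf_tok = builder.build_hf_tokenizer(
+     text_corpus=(e.description for e in training_events),
+     bpe_vocab_size=8000,
+ )
+
+ tokens = builder.tokenize_event(training_events[0])  # list of token strings
+ ids = hf_tok.convert_tokens_to_ids(tokens)           # integer ids for the model
+
+ hf_tok.push_to_hub("org/finance-domain-tokenizer")   # hypothetical repo id
+ ```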
498
+
499
+ ### Phase 2B: Model Architecture (Weeks 3–5)
500
+
501
+ #### Step 4: GPT-style Causal Transformer (NoPE)
502
+
503
+ Implement as a HuggingFace-compatible `PreTrainedModel`:
504
+
505
+ ```python
506
+ # src/models/configuration_domain_transformer.py
507
+
508
+ from transformers import PretrainedConfig
509
+
510
+ class DomainTransformerConfig(PretrainedConfig):
511
+ model_type = "domain_transformer"
512
+
513
+ def __init__(
514
+ self,
515
+ vocab_size: int = 32000,
516
+ hidden_size: int = 256, # 256 = 24M params, 1024 = 330M (Nubank sizes)
517
+ num_hidden_layers: int = 24, # Nubank uses 24 for both sizes
518
+ num_attention_heads: int = 16, # Nubank uses 16 for both sizes
519
+ intermediate_size: int = None, # defaults to 4 * hidden_size
520
+ max_position_embeddings: int = 2048,
521
+ dropout: float = 0.1,
522
+ use_positional_encoding: bool = False, # NoPE by default!
523
+ **kwargs
524
+ ):
525
+ self.vocab_size = vocab_size
526
+ self.hidden_size = hidden_size
527
+ self.num_hidden_layers = num_hidden_layers
528
+ self.num_attention_heads = num_attention_heads
529
+ self.intermediate_size = intermediate_size or 4 * hidden_size
530
+ self.max_position_embeddings = max_position_embeddings
531
+ self.dropout = dropout
532
+ self.use_positional_encoding = use_positional_encoding
533
+ super().__init__(**kwargs)
534
+ ```
535
+
536
+ ```python
537
+ # src/models/modeling_domain_transformer.py
538
+
539
+ import torch
540
+ import torch.nn as nn
541
+ from transformers import PreTrainedModel
542
+ from transformers.modeling_outputs import CausalLMOutputWithPast
543
+
544
+ class DomainTransformerBlock(nn.Module):
545
+ def __init__(self, config):
546
+ super().__init__()
547
+ self.ln1 = nn.LayerNorm(config.hidden_size)
548
+ self.attn = nn.MultiheadAttention(
549
+ config.hidden_size, config.num_attention_heads,
550
+ dropout=config.dropout, batch_first=True
551
+ )
552
+ self.ln2 = nn.LayerNorm(config.hidden_size)
553
+ self.mlp = nn.Sequential(
554
+ nn.Linear(config.hidden_size, config.intermediate_size),
555
+ nn.GELU(),
556
+ nn.Linear(config.intermediate_size, config.hidden_size),
557
+ nn.Dropout(config.dropout),
558
+ )
559
+
560
+     def forward(self, x, key_padding_mask=None):
+         # Pre-norm architecture with an explicit causal mask.
+         # key_padding_mask: (batch, seq) bool, True where a position is padding.
+         seq_len = x.size(1)
+         causal_mask = torch.triu(
+             torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1
+         )
+         h = self.ln1(x)
+         h, _ = self.attn(h, h, h, attn_mask=causal_mask,
+                          key_padding_mask=key_padding_mask, need_weights=False)
+         x = x + h
+         x = x + self.mlp(self.ln2(x))
+         return x
567
+
568
+
569
+ class DomainTransformerForCausalLM(PreTrainedModel):
570
+ config_class = DomainTransformerConfig
571
+
572
+ def __init__(self, config):
573
+ super().__init__(config)
574
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
575
+
576
+ # NoPE: no positional encoding by default (Kazemnejad et al. 2023)
577
+ if config.use_positional_encoding:
578
+ self.embed_positions = nn.Embedding(
579
+ config.max_position_embeddings, config.hidden_size
580
+ )
581
+ else:
582
+ self.embed_positions = None
583
+
584
+ self.drop = nn.Dropout(config.dropout)
585
+ self.blocks = nn.ModuleList([
586
+ DomainTransformerBlock(config)
587
+ for _ in range(config.num_hidden_layers)
588
+ ])
589
+ self.ln_f = nn.LayerNorm(config.hidden_size)
590
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
591
+
592
+ # Weight tying (standard for small models)
593
+ self.lm_head.weight = self.embed_tokens.weight
594
+
595
+ self.post_init()
596
+
597
+     def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
+         x = self.embed_tokens(input_ids)
+
+         if self.embed_positions is not None:
+             positions = torch.arange(input_ids.size(1), device=input_ids.device)
+             x = x + self.embed_positions(positions)
+
+         x = self.drop(x)
+
+         # HF convention: attention_mask is 1 for real tokens, 0 for padding.
+         # nn.MultiheadAttention expects the inverse as a boolean key_padding_mask.
+         key_padding_mask = (attention_mask == 0) if attention_mask is not None else None
+
+         for block in self.blocks:
+             x = block(x, key_padding_mask=key_padding_mask)
+
+         x = self.ln_f(x)
+         logits = self.lm_head(x)
+
+         loss = None
+         if labels is not None:
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             loss = nn.functional.cross_entropy(
+                 shift_logits.view(-1, shift_logits.size(-1)),
+                 shift_labels.view(-1),
+                 ignore_index=-100,
+             )
+
+         # Expose the final hidden states so downstream heads (e.g., joint fusion)
+         # can consume representations instead of vocabulary logits.
+         return CausalLMOutputWithPast(loss=loss, logits=logits, hidden_states=(x,))
623
+
624
+ # Register with AutoClass for Hub compatibility
625
+ DomainTransformerConfig.register_for_auto_class()
626
+ DomainTransformerForCausalLM.register_for_auto_class("AutoModelForCausalLM")
627
+ ```
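+
+ A quick instantiation check of the small configuration (the vocab size is illustrative; the exact count depends on the final tokenizer vocabulary):
+
+ ```python
+ config_24m = DomainTransformerConfig(vocab_size=32000, hidden_size=256,
+                                      num_hidden_layers=24, num_attention_heads=16)
+ model_24m = DomainTransformerForCausalLM(config_24m)
+
+ n_params = sum(p.numel() for p in model_24m.parameters())
+ print(f"{n_params / 1e6:.1f}M parameters")  # ~27M with this vocab; a smaller domain vocab (~16k) lands near 24M
+ ```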
628
+
629
+ #### Step 5: PLR Numerical Embeddings (for Joint Fusion)
630
+
631
+ Port from Yandex's implementation:
632
+
633
+ ```python
634
+ # src/models/plr_embeddings.py
635
+
636
+ import torch
637
+ import torch.nn as nn
638
+ import math
639
+
640
+ class PeriodicLinearReLU(nn.Module):
641
+ """PLR numerical embeddings (Gorishniy et al. 2022).
642
+
643
+     Maps scalar x → [sin(2π·w·x + b), cos(2π·w·x + b)] → Linear → ReLU
644
+ Frequencies w and phases b are LEARNED parameters.
645
+ """
646
+
647
+ def __init__(self, n_features: int, n_frequencies: int = 64, embedding_dim: int = 64):
648
+ super().__init__()
649
+ self.n_features = n_features
650
+ self.n_frequencies = n_frequencies
651
+
652
+ # Learnable frequencies and phases (per feature)
653
+ self.frequencies = nn.Parameter(
654
+ torch.randn(n_features, n_frequencies) * 0.01
655
+ )
656
+ self.phases = nn.Parameter(
657
+ torch.zeros(n_features, n_frequencies)
658
+ )
659
+
660
+ # Linear projection: 2*n_frequencies β†’ embedding_dim
661
+ self.linear = nn.Linear(2 * n_frequencies, embedding_dim)
662
+
663
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
664
+ """
665
+ Args:
666
+ x: (batch, n_features) β€” raw scalar feature values
667
+ Returns:
668
+ (batch, n_features, embedding_dim)
669
+ """
670
+ # x: (B, F) β†’ (B, F, 1)
671
+ x = x.unsqueeze(-1)
672
+
673
+ # Periodic encoding: (B, F, n_freq)
674
+ angles = 2 * math.pi * self.frequencies.unsqueeze(0) * x + self.phases.unsqueeze(0)
675
+ periodic = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1) # (B, F, 2*n_freq)
676
+
677
+ # Linear + ReLU: (B, F, embedding_dim)
678
+ return torch.relu(self.linear(periodic))
679
+ ```
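+
+ A quick shape check (illustrative sizes):
+
+ ```python
+ import torch
+
+ plr = PeriodicLinearReLU(n_features=5, n_frequencies=64, embedding_dim=64)
+ x = torch.randn(8, 5)   # 8 samples, 5 raw tabular features
+ print(plr(x).shape)     # torch.Size([8, 5, 64])
+ ```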
680
+
681
+ ### Phase 2C: Pre-training (Weeks 5–7)
682
+
683
+ #### Step 6: Data Pipeline
684
+
685
+ ```python
686
+ # src/training/data_pipeline.py
687
+
688
+ from torch.utils.data import Dataset
689
+ from typing import List
690
+
691
+ class DomainSequenceDataset(Dataset):
692
+ """Converts user event sequences into token sequences for CLM training."""
693
+
694
+ def __init__(self, user_sequences, tokenizer_builder, hf_tokenizer, max_length=2048):
695
+ self.user_sequences = user_sequences
696
+ self.tokenizer_builder = tokenizer_builder
697
+ self.hf_tokenizer = hf_tokenizer
698
+ self.max_length = max_length
699
+
700
+ def __len__(self):
701
+ return len(self.user_sequences)
702
+
703
+ def __getitem__(self, idx):
704
+ events = self.user_sequences[idx]
705
+
706
+ # Tokenize each event into token strings
707
+ token_strings = []
708
+ for event in events:
709
+ event_tokens = self.tokenizer_builder.tokenize_event(event)
710
+ token_strings.extend(event_tokens)
711
+
712
+ # Convert token strings to IDs via HF tokenizer
713
+ encoding = self.hf_tokenizer(
714
+ " ".join(token_strings),
715
+ max_length=self.max_length,
716
+ truncation=True,
717
+ padding="max_length",
718
+ return_tensors="pt",
719
+ )
720
+
721
+         input_ids = encoding["input_ids"].squeeze(0)
+         attention_mask = encoding["attention_mask"].squeeze(0)
+
+         # CLM: labels are the inputs themselves; the model shifts them by one internally.
+         # Mask padding positions so they do not contribute to the loss.
+         labels = input_ids.clone()
+         labels[attention_mask == 0] = -100
+
+         return {
+             "input_ids": input_ids,
+             "labels": labels,
+             "attention_mask": attention_mask,
+         }
728
+ ```
729
+
730
+ #### Step 7: Pre-training with HuggingFace Trainer
731
+
732
+ ```python
733
+ # src/training/pretrain.py
734
+
735
+ from transformers import Trainer, TrainingArguments
736
+
737
+ def pretrain_domain_model(
738
+ model,
739
+ train_dataset,
740
+ eval_dataset=None,
741
+ output_dir="./checkpoints",
742
+ hub_model_id="org/domain-model-24m",
743
+ num_epochs=3,
744
+ batch_size=64,
745
+ learning_rate=3e-4,
746
+ context_length=2048,
747
+ ):
748
+ training_args = TrainingArguments(
749
+ output_dir=output_dir,
750
+ num_train_epochs=num_epochs,
751
+ per_device_train_batch_size=batch_size,
752
+ gradient_accumulation_steps=4,
753
+ learning_rate=learning_rate,
754
+ lr_scheduler_type="cosine",
755
+ warmup_ratio=0.05,
756
+ weight_decay=0.01,
757
+ logging_strategy="steps",
758
+ logging_steps=100,
759
+ logging_first_step=True,
760
+ disable_tqdm=True, # plain text logging for cloud
761
+ eval_strategy="steps" if eval_dataset else "no",
762
+ eval_steps=500,
763
+ save_strategy="steps",
764
+ save_steps=1000,
765
+ save_total_limit=3,
766
+ push_to_hub=True,
767
+ hub_model_id=hub_model_id,
768
+ bf16=True,
769
+ dataloader_num_workers=4,
770
+ report_to="trackio",
771
+ )
772
+
773
+ trainer = Trainer(
774
+ model=model,
775
+ args=training_args,
776
+ train_dataset=train_dataset,
777
+ eval_dataset=eval_dataset,
778
+ )
779
+
780
+ trainer.train()
781
+ trainer.push_to_hub()
782
+ ```
783
+
784
+ ### Phase 2D: Joint Fusion Fine-tuning (Weeks 7–9)
785
+
786
+ #### Step 8: nuFormer-style Joint Fusion
787
+
788
+ ```python
789
+ # src/models/joint_fusion.py
790
+
791
+ import torch
792
+ import torch.nn as nn
793
+
794
+ class DCNv2CrossLayer(nn.Module):
795
+ """Single cross layer from DCN V2 (Wang et al. 2021)."""
796
+
797
+ def __init__(self, dim):
798
+ super().__init__()
799
+ self.weight = nn.Linear(dim, dim, bias=True)
800
+
801
+ def forward(self, x0, x):
802
+ return x0 * self.weight(x) + x # element-wise product with anchor
803
+
804
+
805
+ class JointFusionModel(nn.Module):
806
+ """nuFormer-style: Transaction Transformer + DCNv2(PLR) β†’ Joint Prediction.
807
+
808
+ Architecture:
809
+ Transaction Sequence β†’ Pre-trained DomainTransformer β†’ user_embedding
810
+ Tabular Features β†’ PLR β†’ DCNv2 β†’ tab_embedding
811
+ Concatenate β†’ MLP Head β†’ prediction
812
+ """
813
+
814
+ def __init__(self, transformer_model, n_tabular_features, n_classes=1,
815
+ plr_frequencies=64, dcn_layers=3, hidden_dim=256):
816
+ super().__init__()
817
+
818
+ self.transformer = transformer_model # pre-trained, unfrozen for fine-tuning
819
+ transformer_dim = transformer_model.config.hidden_size
820
+
821
+ # Tabular branch: PLR β†’ DCNv2
822
+ self.plr = PeriodicLinearReLU(n_tabular_features, plr_frequencies, hidden_dim)
823
+ tab_input_dim = n_tabular_features * hidden_dim
824
+
825
+ self.dcn_layers = nn.ModuleList([
826
+ DCNv2CrossLayer(tab_input_dim) for _ in range(dcn_layers)
827
+ ])
828
+ self.dcn_deep = nn.Sequential(
829
+ nn.Linear(tab_input_dim, hidden_dim),
830
+ nn.ReLU(),
831
+ nn.Linear(hidden_dim, hidden_dim),
832
+ nn.ReLU(),
833
+ )
834
+
835
+ # Joint head
836
+ self.head = nn.Sequential(
837
+ nn.Linear(transformer_dim + hidden_dim, hidden_dim),
838
+ nn.ReLU(),
839
+ nn.Dropout(0.1),
840
+ nn.Linear(hidden_dim, n_classes),
841
+ )
842
+
843
+ def forward(self, input_ids, attention_mask, tabular_features, labels=None):
844
+         # Transaction branch: hidden state of the last non-padding token.
+         # (Raw LM logits live in vocab space; the fusion head expects hidden_size features.)
+         transformer_output = self.transformer(input_ids, attention_mask=attention_mask)
+         hidden = transformer_output.hidden_states[-1]                     # (B, T, hidden_size)
+         last_idx = attention_mask.sum(dim=1).long() - 1                   # last real token per sample
+         user_embedding = hidden[torch.arange(hidden.size(0)), last_idx]   # (B, hidden_size)
847
+
848
+ # Tabular branch: PLR β†’ flatten β†’ DCNv2
849
+ tab_embedded = self.plr(tabular_features) # (B, F, D)
850
+ tab_flat = tab_embedded.view(tab_embedded.size(0), -1) # (B, F*D)
851
+
852
+ x0 = tab_flat
853
+ x = tab_flat
854
+ for cross_layer in self.dcn_layers:
855
+ x = cross_layer(x0, x)
856
+ tab_output = self.dcn_deep(x) # (B, hidden_dim)
857
+
858
+ # Joint fusion
859
+ combined = torch.cat([user_embedding, tab_output], dim=-1)
860
+ logits = self.head(combined)
861
+
862
+ loss = None
863
+ if labels is not None:
864
+ loss = nn.functional.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
865
+
866
+ return {"loss": loss, "logits": logits}
867
+ ```
868
+
869
+ ### Phase 3: Domain Demos (Weeks 9–12)
870
+
871
+ | Week | Deliverable | Hardware |
872
+ |------|------------|----------|
873
+ | 9–10 | **Finance demo:** Transaction tokenizer + 24M model pre-trained on synthetic/public financial data + fraud detection fine-tuning | a10g-large |
874
+ | 10–11 | **E-commerce demo:** Event tokenizer + 24M model pre-trained on Amazon review sequences + next-purchase prediction | a10g-large |
875
+ | 11–12 | **Evaluation & benchmarking:** Compare domain tokenizer vs. text serialization vs. LightGBM baselines on each domain | a10g-large |
876
+
877
+ ### Phase 4: Scale & Optimize (Weeks 12+)
878
+
879
+ | Task | Details |
880
+ |------|---------|
881
+ | Scale to 330M params | Increase `hidden_size` to 1024, train on a100-large |
882
+ | `torch.compile()` | Apply to attention and MLP blocks for 10–30% speedup |
883
+ | ONNX export | `torch.onnx.export()` for production serving |
884
+ | Context window experiments | Ablate 512/1024/2048/4096 context lengths |
885
+ | Data source ablation | Test impact of different event types (Nubank found adding low-signal sources hurts) |
886
+ | ActionPiece vocabulary | Implement BPE-like cross-field merging on top of per-field tokens |
887
+
888
+ ---
889
+
890
+ ## Appendix A: Framework Usage Across Reference Papers
891
+
892
+ | Paper | ArXiv | Framework | Verbatim Evidence |
893
+ |-------|-------|-----------|-------------------|
894
+ | nuFormer (Nubank) | 2507.23267 | PyTorch + HF (inferred) | All dependencies are PyTorch-based |
895
+ | TIGER (Google) | 2305.05065 | JAX + T5X | "We use the open-sourced T5X framework" |
896
+ | ActionPiece (DeepMind) | 2502.13581 | PyTorch + HF | "HuggingFace Transformers and PyTorch" (Appendix H) |
897
+ | RecFormer | 2305.13731 | PyTorch + HF Longformer | "Longformer implemented by Huggingface" (§3.1.4) |
898
+ | Banking TF | 2410.08243 | PyTorch | "Pytorch backend is used" (Appendix B) |
899
+ | PLR Embeddings (Yandex) | 2203.05556 | PyTorch | Repository: pure PyTorch + scikit-learn |
900
+ | KL3M Tokenizers | 2503.17247 | HF `tokenizers` + PyTorch | "tokenizers" BPE for HF compatibility |
901
+
902
+ ---
903
+
904
+ ## Appendix B: Head-to-Head Comparison Matrix
905
+
906
+ | Criterion | PyTorch + HF | JAX/Flax NNX | Keras 3 + JAX |
+ |-----------|-------------|-------------|---------------|
+ | Custom domain tokenizer | ✅ `PreTrainedTokenizerFast` | ❌ Build from scratch | ❌ Build from scratch |
+ | HF Trainer integration | ✅ Native | ❌ Not compatible | ❌ Not compatible |
+ | Hub push/pull | ✅ `push_to_hub()` | ⚠️ Works, fragmented | ⚠️ Limited |
+ | PEFT/LoRA | ✅ Drop-in | ❌ Manual | ❌ Manual |
+ | ONNX export | ✅ One-line | ❌ Multi-hop | ⚠️ TF backend required |
+ | TGI/vLLM serving | ✅ First-class | ❌ Not supported | ❌ Not supported |
+ | TPU training | ⚠️ PyTorch/XLA (overhead) | ✅ Native | ✅ Native |
+ | JIT compilation | ✅ `torch.compile()` | ✅ `@nnx.jit` | ✅ XLA via JAX |
+ | Dynamic shapes (NLP) | ✅ Native | ❌ Recompiles | ❌ Recompiles |
+ | Debugging | ✅ Eager mode, std debugger | ⚠️ Challenging | ⚠️ Challenging |
+ | Reference implementations | 5/6 papers | 1/6 papers | 0/6 papers |
+ | Community/hiring pool | 🟢 Large | 🟡 Small | 🟡 Small |
920
+
921
+ ---
922
+
923
+ ## Appendix C: Key Code Patterns
924
+
925
+ ### Adding Domain Special Tokens to an Existing Tokenizer
926
+
927
+ ```python
928
+ from transformers import AutoTokenizer
929
+
930
+ tokenizer = AutoTokenizer.from_pretrained("gpt2") # or any base tokenizer
931
+
932
+ # Add all domain special tokens
933
+ special_tokens = {
934
+ "additional_special_tokens": [
935
+ # Amount tokens (Nubank-style)
936
+ "[AMT_POS]", "[AMT_NEG]",
937
+ *[f"[AMT_{i:02d}]" for i in range(21)],
938
+ # Calendar tokens
939
+ *[f"[MON_{i:02d}]" for i in range(1, 13)],
940
+ *[f"[DOW_{i}]" for i in range(7)],
941
+ *[f"[DOM_{i:02d}]" for i in range(1, 32)],
942
+ *[f"[HOUR_{i:02d}]" for i in range(24)],
943
+ ]
944
+ }
945
+ num_added = tokenizer.add_special_tokens(special_tokens)
946
+ print(f"Added {num_added} domain tokens. Vocab size: {len(tokenizer)}")
947
+
948
+ # CRITICAL: resize the model's input embeddings to match the enlarged vocabulary
+ # (assumes `model` was loaded beforehand, e.g. AutoModelForCausalLM.from_pretrained("gpt2"))
+ model.resize_token_embeddings(len(tokenizer))
950
+ ```
951
+
952
+ ### Registering a Custom Model for Hub Deployment
953
+
954
+ ```python
955
+ # In your package's __init__.py or a registration script:
956
+ from transformers import AutoConfig, AutoModelForCausalLM
957
+
958
+ from .configuration_domain_transformer import DomainTransformerConfig
959
+ from .modeling_domain_transformer import DomainTransformerForCausalLM
960
+
961
+ # Register so AutoClass can find your model
962
+ AutoConfig.register("domain_transformer", DomainTransformerConfig)
963
+ AutoModelForCausalLM.register(DomainTransformerConfig, DomainTransformerForCausalLM)
964
+
965
+ # Enable push_to_hub with custom code
966
+ DomainTransformerConfig.register_for_auto_class()
967
+ DomainTransformerForCausalLM.register_for_auto_class("AutoModelForCausalLM")
968
+
969
+ # Push: uploads configuration.py, modeling.py, config.json, model.safetensors
970
+ model.push_to_hub("org/domain-transformer-24m")
971
+
972
+ # Load anywhere:
973
+ model = AutoModelForCausalLM.from_pretrained("org/domain-transformer-24m", trust_remote_code=True)
974
+ ```
975
+
976
+ ---
977
+
978
+ *This ADR is a living document and will be updated as implementation progresses.*