| # ADR-001: Implementation Framework for domainTokenizer |
|
|
| > **Status:** Accepted |
| > **Date:** April 29, 2026 |
| > **Decision:** PyTorch + HuggingFace Transformers as primary framework, with JAX/Flax NNX as future scaling path |
| > **Deciders:** domainTokenizer core team |
|
|
| --- |
|
|
| ## Table of Contents |
|
|
| 1. [Context](#1-context) |
| 2. [Goal](#2-goal) |
| 3. [Options Evaluated](#3-options-evaluated) |
| 4. [Decision](#4-decision) |
| 5. [Trade-offs and Justification](#5-trade-offs-and-justification) |
| 6. [Consequences](#6-consequences) |
| 7. [Implementation Roadmap](#7-implementation-roadmap) |
| 8. [Appendix A: Framework Usage Across Reference Papers](#appendix-a-framework-usage-across-reference-papers) |
| 9. [Appendix B: Head-to-Head Comparison Matrix](#appendix-b-head-to-head-comparison-matrix) |
| 10. [Appendix C: Key Code Patterns](#appendix-c-key-code-patterns) |
|
|
| --- |
|
|
| ## 1. Context |
|
|
| ### What We're Building |
|
|
domainTokenizer is a library for building **small Transformer models (24M–330M parameters)** that process **domain-specific tokens** – financial transactions, e-commerce events, healthcare records – instead of natural language text. The architecture follows the validated pattern from Nubank's nuFormer ([arXiv: 2507.23267](https://arxiv.org/abs/2507.23267)):
|
|
| ``` |
Domain Events → Custom Tokenizer → GPT-style Transformer → Foundation Model → Downstream Tasks
| ``` |
|
|
| ### The Implementation Question |
|
|
| A video from Google for Developers presents the **Keras 3 + JAX/Flax NNX integration** as a potential framework, citing: |
- **Explicit state management** (Flax NNX) – useful for tracking sequential transaction state
- **Custom training loops** – Keras structure + JAX/Optax for domain-specific gradient control
- **JIT compilation** (`@nnx.jit`) – high-performance processing of millions of transactions
- **Paradigm mixing** – Keras layers for standard components + NNX for custom sequential encoders
|
|
| The question: **Is Keras + JAX/Flax NNX the right framework for domainTokenizer, or is there a better choice?** |
|
|
| ### Constraints |
|
|
1. **Custom tokenizer required:** We need a tokenizer that maps structured fields (amounts, dates, categories) to special tokens – not a standard text tokenizer
2. **Small models:** 24M–330M parameters, not 70B+ – framework overhead matters less than developer velocity
| 3. **Production deployment:** Models must be servable with low latency for real-time applications (fraud detection, recommendations) |
| 4. **GPU hardware:** Development on A100/A10G GPUs, not TPUs (standard cloud environment) |
| 5. **Team context:** ML engineers familiar with Python, PyTorch, and the HuggingFace ecosystem |
| 6. **Iteration speed:** Need to prototype quickly across multiple domains (finance, e-commerce, healthcare) |
|
|
| --- |
|
|
| ## 2. Goal |
|
|
| Choose an implementation framework that: |
|
|
1. **Minimizes time from research to working prototype** – weeks, not months
| 2. **Supports custom domain tokenizers** as first-class citizens |
| 3. **Integrates with the HuggingFace Hub** for model sharing, versioning, and community |
| 4. **Enables production deployment** via standard serving infrastructure (ONNX, TGI, vLLM, etc.) |
5. **Scales to 330M parameters** on 4–8 GPUs without heroic engineering
| 6. **Does not preclude future migration** to JAX/TPU if we need to scale beyond 1B parameters |
|
|
| --- |
|
|
| ## 3. Options Evaluated |
|
|
| ### Option A: PyTorch + HuggingFace Transformers |
|
|
| The dominant ecosystem for custom NLP/sequential models. Provides `PreTrainedModel`, `PreTrainedTokenizerFast`, `Trainer`, `push_to_hub`, ONNX export, and integration with TRL, PEFT, Accelerate, DeepSpeed. |
|
|
| ### Option B: Keras 3 + JAX Backend + Flax NNX |
|
|
| Google's multi-backend framework. Keras provides high-level APIs; JAX provides XLA compilation and functional transforms; Flax NNX provides PyTorch-like stateful modules on top of JAX. |
|
|
| ### Option C: Pure JAX + Flax NNX + Optax |
|
|
| Skip Keras entirely. Use Flax NNX for model definition, Optax for optimization, Orbax for checkpointing, and Grain/tf.data for data loading. Google's MaxText framework follows this pattern. |
|
|
| ### Option D: PyTorch + Custom (no HuggingFace) |
|
|
| Use PyTorch directly without the HuggingFace abstraction layer. Full control but no ecosystem integration. |
|
|
| --- |
|
|
| ## 4. Decision |
|
|
| ### Primary: PyTorch + HuggingFace Transformers (Option A) |
|
|
### Future scaling path: JAX/Flax NNX (Option C) – if and when we need TPU training at >1B parameters
|
|
| --- |
|
|
| ## 5. Trade-offs and Justification |
|
|
| ### 5.1 What the Reference Papers Actually Use |
|
|
| We audited the frameworks used by every paper in the domainTokenizer research corpus. The result is overwhelming: |
|
|
| | Paper | Framework | Confidence | |
| |-------|-----------|------------| |
| | **nuFormer** (Nubank) | PyTorch + HF Transformers (inferred) | ~90% | |
| | **TIGER** (Google) | JAX + T5X (official); PyTorch (community reimpl) | 100% | |
| | **ActionPiece** (Google DeepMind) | **PyTorch + HF Transformers** (stated verbatim in paper) | 100% | |
| | **RecFormer** (UCSD/Amazon) | **PyTorch + HF Transformers (Longformer)** (stated verbatim) | 100% | |
| | **Banking Transaction Flow** | **PyTorch** (stated verbatim in appendix) | 100% | |
| | **PLR Embeddings** (Yandex) | **PyTorch** + scikit-learn + Optuna | 100% | |
|
|
**5 of 6 papers use PyTorch.** The sole JAX user (TIGER) was a Google-internal project using T5X, and even its most popular community reimplementation (781 GitHub stars) is in PyTorch.
|
|
Even **Google DeepMind's own ActionPiece** – the paper most relevant to our domain tokenization approach – uses PyTorch + HuggingFace. This is the strongest signal possible.
|
|
| ### 5.2 Custom Tokenizer Story |
|
|
| This is the **decisive factor**. domainTokenizer's core innovation is the tokenizer itself. The framework must provide first-class support for custom token vocabularies. |
|
|
| **PyTorch + HuggingFace:** |
| - Train custom BPE tokenizer via `tokenizers` library (Rust-backed, fast) |
- Wrap in `PreTrainedTokenizerFast` – full Trainer compatibility
- Add domain special tokens via `add_special_tokens()`, then resize model embeddings with `resize_token_embeddings()`
| - Push tokenizer to Hub: `tokenizer.push_to_hub("org/my-tokenizer")` |
| - Load anywhere: `AutoTokenizer.from_pretrained("org/my-tokenizer")` |
- **KL3M** ([arXiv: 2503.17247](https://arxiv.org/abs/2503.17247)) – the gold standard for financial domain tokenizers – is built entirely on this stack
|
|
| **Keras + JAX/Flax NNX:** |
| - No equivalent to `PreTrainedTokenizerFast` |
| - No Hub-integrated tokenizer format |
| - Must build custom tokenizer from scratch with no ecosystem support |
| - No standard serialization/deserialization for domain vocabularies |
|
|
**Verdict:** PyTorch/HF has a **complete, production-tested** custom tokenizer pipeline. JAX/Keras has **no equivalent** – you'd build everything from scratch.
|
|
| ### 5.3 Production Deployment |
|
|
| Path | PyTorch | JAX/Keras |
|------|---------|-----------|
| ONNX export | `torch.onnx.export()` – one line | Requires TF backend intermediate or experimental `jax.export` |
| TensorRT | ONNX → TRT (standard) | Multi-hop, fragile |
| TGI (HuggingFace inference) | First-class | Not supported |
| vLLM | First-class | Not supported |
| Triton Inference Server | Direct ONNX/TorchScript | Via ONNX (workaround) |
| BentoML | Supported | Supported |
| Model Hub sharing | `push_to_hub()` → `from_pretrained()` | Works but fragmented (`.msgpack` weights, no Trainer compat) |
|
|
| **Verdict:** PyTorch has **direct, tested paths** to every major serving framework. JAX requires **multiple intermediate conversions**, each introducing failure points. |
|
|
| ### 5.4 Training Speed |
|
|
At our scale (24M–330M parameters on 4–8 A100s):
|
|
| Scenario | PyTorch | JAX |
|----------|---------|-----|
| Steady-state training throughput | **Comparable** (`torch.compile`) | **Comparable** (XLA JIT) |
| Variable-length sequences | **Native** – dynamic shapes | **Problematic** – recompiles on new shapes; must pad to buckets |
| Multi-GPU (FSDP) | `accelerate` + FSDP2 – mature | `pmap`/`shard_map` – works but harder to configure |
| First-run compilation | Instant (eager mode) | 5–20s JIT compilation overhead |
| Debugging | Standard Python debugger | `print` debugging; cryptic XLA errors |
|
|
| **Verdict:** At 330M parameters, training speed is a **wash**. JAX's advantages (XLA kernel fusion, TPU native) only matter at 10B+ parameters on 256+ accelerators. At our scale, **developer velocity dominates throughput**. |
|
|
| ### 5.5 The JAX Advantage: When It Would Win |
|
|
| JAX/Flax NNX would be the right choice **if**: |
|
|
1. **Training exclusively on Google TPUs** – JAX is the native TPU compiler; PyTorch/XLA is a port with overhead
2. **Models >1B parameters** – XLA's whole-program optimization shines at scale
3. **Fixed-shape workloads** – images, fixed-length token sequences (no variable-length padding issues)
4. **Need functional transforms** – `vmap` (per-sample gradients), `pmap` (data parallelism), `grad` (higher-order derivatives)
5. **Google Cloud infrastructure** – Vertex AI, TPU VMs, GCS integration
|
|
For domainTokenizer's current scope (24M–330M, GPU, variable-length sequences, fast iteration), **none of these conditions apply**.
|
|
| ### 5.6 The Keras + JAX Mixing Argument |
|
|
| The Google for Developers video argues for mixing Keras layers (high-level) with NNX modules (custom, high-performance). In theory, this lets you: |
| - Use Keras for standard Transformer layers |
| - Use NNX for custom sequential transaction encoders |
| - Get JIT compilation on the NNX parts |
|
|
| **In practice, this creates problems:** |
1. **Two mental models:** Keras (layer-oriented, `fit/compile`) vs. NNX (functional, explicit state) – context switching slows development
2. **Limited interop documentation:** Keras↔NNX examples are thin; edge cases are poorly documented
3. **No HF ecosystem integration:** You lose Trainer, push_to_hub, PEFT, TRL, Accelerate – the entire ecosystem Nubank and ActionPiece rely on
4. **Debugging complexity:** Errors in the Keras–NNX boundary are hard to diagnose
|
|
| **Better approach with PyTorch:** Use `torch.compile()` on performance-critical modules to get JIT compilation benefits without leaving the PyTorch ecosystem. Write custom `nn.Module` subclasses for domain-specific components. This gives you the same "standard parts + custom parts" architecture without framework mixing. |
|
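As a rough illustration of that pattern (the module and names here are hypothetical, not part of the library):

```python
import torch
import torch.nn as nn

class TransactionEncoder(nn.Module):
    """Hypothetical custom component for a domain-specific hot path."""

    def __init__(self, hidden_size: int = 256):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))

encoder = TransactionEncoder()
encoder = torch.compile(encoder)  # JIT-compile only the hot path; the rest of the model stays eager
```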
|
| --- |
|
|
| ## 6. Consequences |
|
|
| ### What We Gain |
|
|
| 1. **Immediate access to the entire HuggingFace ecosystem:** Trainer, Accelerate, PEFT (LoRA), TRL, Evaluate, push_to_hub, from_pretrained, ONNX export, TGI serving |
2. **Copy-paste from reference implementations:** ActionPiece, RecFormer, Banking TF, and PLR embeddings are all PyTorch – we can directly reuse their code
| 3. **KL3M tokenizer as starting point:** The best financial domain tokenizer already exists in PyTorch/HF format at `alea-institute/kl3m-004-128k-cased` |
4. **Standard production deployment:** ONNX → TensorRT → Triton, or direct TGI/vLLM serving
| 5. **Community and hiring:** PyTorch is the dominant ML framework; finding contributors and documentation is easy |
6. **`torch.compile()` for performance:** When we need JIT compilation on hot paths, `torch.compile()` provides 10–30% speedups without leaving the ecosystem
| |
| ### What We Accept |
| |
| 1. **No native TPU support:** If we later need to train on Google TPUs, we'll need PyTorch/XLA (slower than native JAX) or migrate the model code |
| 2. **No functional transforms:** `vmap` (per-sample gradients) isn't available without `functorch` (experimental). If we need advanced gradient manipulation for meta-learning or Nested Learning (HOPE-style), JAX would be better |
| 3. **Potential future migration cost:** If we scale beyond 1B parameters and move to TPUs, we'll need to rewrite model code in Flax NNX. This is mitigated by keeping model definitions clean and modular |
| |
| ### Migration Strategy (If Needed Later) |
| |
| If domainTokenizer grows to >1B parameters and we need TPU training: |
| |
1. **Tokenizer layer stays in Python/HF:** Tokenizer is framework-agnostic – it produces integer sequences regardless of whether the model is PyTorch or JAX
2. **Model architecture translates 1:1:** PyTorch `nn.Module` → Flax NNX `nnx.Module` mapping is straightforward for standard Transformer components (see the sketch after this list)
3. **Training loop changes:** PyTorch Trainer → custom Flax NNX training loop with Optax
| 4. **Reference:** Google's MaxText (`github.com/google/maxtext`) provides production-grade JAX Transformer patterns we can follow |
| |
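To make point 2 concrete, a minimal sketch of the correspondence (assuming Flax NNX's current `nnx.Module`/`nnx.Linear` API; not part of the current codebase):

```python
import torch.nn as nn
from flax import nnx

class TorchBlock(nn.Module):          # today: PyTorch
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)

class NNXBlock(nnx.Module):           # later, if TPU training is needed: Flax NNX
    def __init__(self, dim: int, rngs: nnx.Rngs):
        self.proj = nnx.Linear(dim, dim, rngs=rngs)

    def __call__(self, x):
        return self.proj(x)
```
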
**Estimated migration effort:** 2–4 weeks for a clean, well-separated codebase.
| |
| --- |
| |
| ## 7. Implementation Roadmap |
| |
### Phase 2A: Core Tokenizer Library (Weeks 1–3)
| |
| #### Step 1: Domain Schema Definition |
| |
| Create a declarative schema format that describes the fields in a domain's event data: |
| |
| ```python |
| # src/tokenizers/schema.py |
| |
| from dataclasses import dataclass, field |
| from enum import Enum |
| from typing import List, Optional |
| |
| class FieldType(Enum): |
    NUMERICAL_CONTINUOUS = "numerical_continuous"  # prices, amounts → magnitude bins
    NUMERICAL_DISCRETE = "numerical_discrete"      # quantities → small fixed vocab
    CATEGORICAL_FIXED = "categorical_fixed"        # categories, days of week → direct mapping
    CATEGORICAL_ENTITY = "categorical_entity"      # products, merchants → Semantic IDs (RQ-VAE)
    TEMPORAL = "temporal"                          # timestamps → calendar decomposition
    TEXT = "text"                                  # descriptions → BPE subwords
    SIGN = "sign"                                  # credit/debit → 2 tokens
| |
| @dataclass |
| class FieldSpec: |
| name: str |
| field_type: FieldType |
| vocab_size: Optional[int] = None # for fixed categorical |
| n_bins: int = 21 # for numerical (Nubank uses 21) |
| calendar_fields: List[str] = field( # for temporal |
| default_factory=lambda: ["month", "dow", "dom", "hour"] |
| ) |
| |
| @dataclass |
| class DomainSchema: |
| name: str # e.g., "ecommerce", "finance" |
| fields: List[FieldSpec] # ordered list of fields per event |
| |
| @property |
| def special_token_count(self) -> int: |
| """Total domain-specific special tokens needed.""" |
| count = 0 |
| for f in self.fields: |
| if f.field_type == FieldType.SIGN: |
| count += 2 |
| elif f.field_type == FieldType.NUMERICAL_CONTINUOUS: |
| count += f.n_bins |
            elif f.field_type in (FieldType.CATEGORICAL_FIXED, FieldType.NUMERICAL_DISCRETE):
                count += f.vocab_size
| elif f.field_type == FieldType.TEMPORAL: |
| count += sum({ |
| "month": 12, "dow": 7, "dom": 31, "hour": 24, |
| "quarter": 4, "year": 10 |
| }.get(cf, 0) for cf in f.calendar_fields) |
| return count |
| |
| # Example: Nubank-style financial schema |
| FINANCE_SCHEMA = DomainSchema( |
| name="finance", |
| fields=[ |
| FieldSpec("amount_sign", FieldType.SIGN), |
| FieldSpec("amount", FieldType.NUMERICAL_CONTINUOUS, n_bins=21), |
| FieldSpec("timestamp", FieldType.TEMPORAL, |
| calendar_fields=["month", "dow", "dom", "hour"]), |
| FieldSpec("description", FieldType.TEXT), |
| ] |
| ) |
| |
| # Example: E-commerce schema |
| ECOMMERCE_SCHEMA = DomainSchema( |
| name="ecommerce", |
| fields=[ |
| FieldSpec("event_type", FieldType.CATEGORICAL_FIXED, vocab_size=5), |
| FieldSpec("price", FieldType.NUMERICAL_CONTINUOUS, n_bins=21), |
| FieldSpec("quantity", FieldType.NUMERICAL_DISCRETE, vocab_size=11), |
| FieldSpec("category_l1", FieldType.CATEGORICAL_FIXED, vocab_size=30), |
| FieldSpec("category_l2", FieldType.CATEGORICAL_FIXED, vocab_size=200), |
| FieldSpec("timestamp", FieldType.TEMPORAL, |
| calendar_fields=["month", "dow", "dom", "hour"]), |
| FieldSpec("product_title", FieldType.TEXT), |
| ] |
| ) |
| ``` |
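
As a quick sanity check of the token budget this schema implies (the numbers follow directly from the definitions above):

```python
# Finance schema: 2 sign tokens + 21 amount bins
# + calendar tokens (12 months + 7 weekdays + 31 days + 24 hours) = 97
print(FINANCE_SCHEMA.special_token_count)  # 97
```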
| |
| #### Step 2: Per-Field Tokenizers |
|
|
| Implement each field type tokenizer as a standalone module: |
|
|
| ```python |
| # src/tokenizers/field_tokenizers.py |
| |
| import numpy as np |
| from typing import List |
| |
| class SignTokenizer: |
| """Tokenizes sign of a numerical value (credit/debit, inflow/outflow).""" |
| |
| def __init__(self, prefix: str = "SIGN"): |
| self.tokens = [f"[{prefix}_POS]", f"[{prefix}_NEG]"] |
| |
| def __call__(self, value: float) -> str: |
| return self.tokens[0] if value >= 0 else self.tokens[1] |
| |
| @property |
| def vocab(self) -> List[str]: |
| return self.tokens |
| |
| |
| class MagnitudeBucketTokenizer: |
| """Quantizes continuous values into bins (Nubank-style). |
| |
| Uses quantile-based binning on the training distribution. |
| Follows the Relative Magnitude Tokenization principle from TP-BERTa. |
| """ |
| |
| def __init__(self, n_bins: int = 21, prefix: str = "AMT"): |
| self.n_bins = n_bins |
| self.prefix = prefix |
| self.bin_edges = None # fitted from data |
| |
| def fit(self, values: np.ndarray): |
| """Compute bin edges from training data using quantiles.""" |
| # Use absolute values for magnitude binning |
| abs_vals = np.abs(values[~np.isnan(values)]) |
| quantiles = np.linspace(0, 100, self.n_bins + 1) |
| self.bin_edges = np.percentile(abs_vals, quantiles) |
| return self |
| |
| def __call__(self, value: float) -> str: |
| if self.bin_edges is None: |
| raise ValueError("Tokenizer not fitted. Call .fit() first.") |
| bin_idx = np.searchsorted(self.bin_edges[1:-1], abs(value)) |
| return f"[{self.prefix}_{bin_idx:02d}]" |
| |
| @property |
| def vocab(self) -> List[str]: |
| return [f"[{self.prefix}_{i:02d}]" for i in range(self.n_bins)] |
| |
| |
| class CalendarTokenizer: |
| """Decomposes timestamps into calendar components (Nubank-style).""" |
| |
| FIELD_VOCABS = { |
| "month": ([f"[MON_{i:02d}]" for i in range(1, 13)], lambda dt: dt.month - 1), |
| "dow": ([f"[DOW_{i}]" for i in range(7)], lambda dt: dt.weekday()), |
| "dom": ([f"[DOM_{i:02d}]" for i in range(1, 32)], lambda dt: dt.day - 1), |
| "hour": ([f"[HOUR_{i:02d}]" for i in range(24)], lambda dt: dt.hour), |
| } |
| |
| def __init__(self, fields: List[str] = None): |
| self.fields = fields or ["month", "dow", "dom", "hour"] |
| |
| def __call__(self, timestamp) -> List[str]: |
| tokens = [] |
| for field_name in self.fields: |
| vocab, extractor = self.FIELD_VOCABS[field_name] |
| idx = extractor(timestamp) |
| tokens.append(vocab[idx]) |
| return tokens |
| |
| @property |
| def vocab(self) -> List[str]: |
| all_tokens = [] |
| for field_name in self.fields: |
| all_tokens.extend(self.FIELD_VOCABS[field_name][0]) |
| return all_tokens |
| |
| |
| class CategoricalTokenizer: |
| """Maps categorical values to fixed vocabulary tokens.""" |
| |
| def __init__(self, categories: List[str], prefix: str = "CAT"): |
| self.prefix = prefix |
| self.token_map = {cat: f"[{prefix}_{i:03d}]" for i, cat in enumerate(categories)} |
| self.unk_token = f"[{prefix}_UNK]" |
| |
| def __call__(self, value: str) -> str: |
| return self.token_map.get(value, self.unk_token) |
| |
| @property |
| def vocab(self) -> List[str]: |
| return list(self.token_map.values()) + [self.unk_token] |
| ``` |
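
A small usage sketch of the field tokenizers above (the fitted bin index depends on the sampled data):

```python
from datetime import datetime

import numpy as np

amt = MagnitudeBucketTokenizer(n_bins=21, prefix="AMT").fit(
    np.random.lognormal(mean=3.0, sigma=1.0, size=10_000)
)
print(amt(-42.50))                   # e.g. "[AMT_13]" – bin depends on the fitted quantiles
print(SignTokenizer("AMT")(-42.50))  # "[AMT_NEG]"
print(CalendarTokenizer()(datetime(2026, 4, 29, 14, 5)))
# ['[MON_04]', '[DOW_2]', '[DOM_29]', '[HOUR_14]']
```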
|
|
| #### Step 3: Composite Domain Tokenizer |
|
|
| Assemble per-field tokenizers into a complete domain tokenizer, wrapped as `PreTrainedTokenizerFast`: |
|
|
| ```python |
| # src/tokenizers/domain_tokenizer.py |
| |
import numpy as np
from typing import List

from tokenizers import Tokenizer, models, trainers, pre_tokenizers
from transformers import PreTrainedTokenizerFast

# Assumes the modules from Steps 1 and 2 live in the same package.
from .schema import DomainSchema, FieldType
from .field_tokenizers import (
    SignTokenizer,
    MagnitudeBucketTokenizer,
    CalendarTokenizer,
    CategoricalTokenizer,
)
| |
| class DomainTokenizerBuilder: |
| """Builds a HuggingFace-compatible tokenizer from a DomainSchema.""" |
| |
| def __init__(self, schema: DomainSchema): |
| self.schema = schema |
        self.field_tokenizers = {}  # name → field tokenizer instance
| self._build_field_tokenizers() |
| |
| def _build_field_tokenizers(self): |
| for field_spec in self.schema.fields: |
| if field_spec.field_type == FieldType.SIGN: |
| self.field_tokenizers[field_spec.name] = SignTokenizer(field_spec.name.upper()) |
| elif field_spec.field_type == FieldType.NUMERICAL_CONTINUOUS: |
| self.field_tokenizers[field_spec.name] = MagnitudeBucketTokenizer( |
| n_bins=field_spec.n_bins, prefix=field_spec.name.upper() |
| ) |
| elif field_spec.field_type == FieldType.TEMPORAL: |
| self.field_tokenizers[field_spec.name] = CalendarTokenizer(field_spec.calendar_fields) |
| # ... other types |
| |
| def fit(self, data): |
| """Fit data-dependent tokenizers (magnitude bins, etc.).""" |
| for field_spec in self.schema.fields: |
| if field_spec.field_type == FieldType.NUMERICAL_CONTINUOUS: |
| values = [getattr(event, field_spec.name) for event in data] |
| self.field_tokenizers[field_spec.name].fit(np.array(values)) |
| return self |
| |
| def build_hf_tokenizer(self, text_corpus=None, bpe_vocab_size=8000) -> PreTrainedTokenizerFast: |
| """Build a complete HuggingFace tokenizer. |
| |
| 1. Collect all domain special tokens |
| 2. Train BPE on text fields (if any) |
| 3. Merge into a single PreTrainedTokenizerFast |
| """ |
| # Collect all special tokens from field tokenizers |
| all_special_tokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "[BOS]", "[EOS]"] |
| for name, tok in self.field_tokenizers.items(): |
| if hasattr(tok, 'vocab'): |
| all_special_tokens.extend(tok.vocab) |
| |
| # Train BPE on text fields |
| bpe_tokenizer = Tokenizer(models.BPE(unk_token="[UNK]")) |
| bpe_tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False) |
| trainer = trainers.BpeTrainer( |
| vocab_size=bpe_vocab_size, |
| special_tokens=all_special_tokens, |
| min_frequency=2, |
| ) |
| if text_corpus: |
| bpe_tokenizer.train_from_iterator(text_corpus, trainer=trainer) |
| |
| # Wrap as HuggingFace tokenizer |
| hf_tokenizer = PreTrainedTokenizerFast( |
| tokenizer_object=bpe_tokenizer, |
| bos_token="[BOS]", |
| eos_token="[EOS]", |
| pad_token="[PAD]", |
| unk_token="[UNK]", |
| mask_token="[MASK]", |
        )

        # Register the domain tokens as additional special tokens so the BPE
        # pre-tokenizer never splits them into sub-pieces.
        hf_tokenizer.add_special_tokens(
            {"additional_special_tokens": [t for t in all_special_tokens
                                           if t not in hf_tokenizer.all_special_tokens]}
        )

        return hf_tokenizer
| |
| def tokenize_event(self, event) -> List[str]: |
| """Convert a single domain event into a list of token strings.""" |
| tokens = [] |
| for field_spec in self.schema.fields: |
| value = getattr(event, field_spec.name, None) |
| if value is None: |
| tokens.append("[UNK]") |
| continue |
| tok = self.field_tokenizers[field_spec.name] |
| result = tok(value) |
| if isinstance(result, list): |
| tokens.extend(result) |
| else: |
| tokens.append(result) |
| return tokens |
| ``` |
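
A hypothetical end-to-end sketch, assuming `events` is a list of objects whose attributes match the `FINANCE_SCHEMA` field names (and that the elided `# ... other types` branch also covers `TEXT` fields):

```python
builder = DomainTokenizerBuilder(FINANCE_SCHEMA).fit(events)
hf_tok = builder.build_hf_tokenizer(
    text_corpus=(e.description for e in events),
    bpe_vocab_size=8000,
)

print(builder.tokenize_event(events[0]))
# e.g. ['[AMOUNT_SIGN_NEG]', '[AMOUNT_07]', '[MON_04]', '[DOW_2]', '[DOM_29]', '[HOUR_14]', ...]

hf_tok.push_to_hub("org/finance-domain-tokenizer")  # requires a Hub token / org access
```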
|
|
### Phase 2B: Model Architecture (Weeks 3–5)
|
|
| #### Step 4: GPT-style Causal Transformer (NoPE) |
|
|
| Implement as a HuggingFace-compatible `PreTrainedModel`: |
|
|
| ```python |
| # src/models/configuration_domain_transformer.py |
| |
| from transformers import PretrainedConfig |
| |
| class DomainTransformerConfig(PretrainedConfig): |
| model_type = "domain_transformer" |
| |
| def __init__( |
| self, |
| vocab_size: int = 32000, |
| hidden_size: int = 256, # 256 = 24M params, 1024 = 330M (Nubank sizes) |
| num_hidden_layers: int = 24, # Nubank uses 24 for both sizes |
| num_attention_heads: int = 16, # Nubank uses 16 for both sizes |
| intermediate_size: int = None, # defaults to 4 * hidden_size |
| max_position_embeddings: int = 2048, |
| dropout: float = 0.1, |
| use_positional_encoding: bool = False, # NoPE by default! |
| **kwargs |
| ): |
| self.vocab_size = vocab_size |
| self.hidden_size = hidden_size |
| self.num_hidden_layers = num_hidden_layers |
| self.num_attention_heads = num_attention_heads |
| self.intermediate_size = intermediate_size or 4 * hidden_size |
| self.max_position_embeddings = max_position_embeddings |
| self.dropout = dropout |
| self.use_positional_encoding = use_positional_encoding |
| super().__init__(**kwargs) |
| ``` |
|
|
| ```python |
| # src/models/modeling_domain_transformer.py |
| |
| import torch |
| import torch.nn as nn |
| from transformers import PreTrainedModel |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
| |
| class DomainTransformerBlock(nn.Module): |
| def __init__(self, config): |
| super().__init__() |
| self.ln1 = nn.LayerNorm(config.hidden_size) |
| self.attn = nn.MultiheadAttention( |
| config.hidden_size, config.num_attention_heads, |
| dropout=config.dropout, batch_first=True |
| ) |
| self.ln2 = nn.LayerNorm(config.hidden_size) |
| self.mlp = nn.Sequential( |
| nn.Linear(config.hidden_size, config.intermediate_size), |
| nn.GELU(), |
| nn.Linear(config.intermediate_size, config.hidden_size), |
| nn.Dropout(config.dropout), |
| ) |
| |
    def forward(self, x, key_padding_mask=None):
        # Pre-norm architecture with an explicit causal mask.
        # nn.MultiheadAttention treats `is_causal` only as a hint, so we build
        # the (T, T) boolean mask ourselves and pass padding separately.
        seq_len = x.size(1)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
            diagonal=1,
        )
        h = self.ln1(x)
        h, _ = self.attn(
            h, h, h,
            attn_mask=causal_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        x = x + h
        x = x + self.mlp(self.ln2(x))
        return x
| |
| |
| class DomainTransformerForCausalLM(PreTrainedModel): |
| config_class = DomainTransformerConfig |
| |
| def __init__(self, config): |
| super().__init__(config) |
| self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) |
| |
| # NoPE: no positional encoding by default (Kazemnejad et al. 2023) |
| if config.use_positional_encoding: |
| self.embed_positions = nn.Embedding( |
| config.max_position_embeddings, config.hidden_size |
| ) |
| else: |
| self.embed_positions = None |
| |
| self.drop = nn.Dropout(config.dropout) |
| self.blocks = nn.ModuleList([ |
| DomainTransformerBlock(config) |
| for _ in range(config.num_hidden_layers) |
| ]) |
| self.ln_f = nn.LayerNorm(config.hidden_size) |
| self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) |
| |
| # Weight tying (standard for small models) |
| self.lm_head.weight = self.embed_tokens.weight |
| |
| self.post_init() |
| |
| def forward(self, input_ids, attention_mask=None, labels=None, **kwargs): |
| x = self.embed_tokens(input_ids) |
| |
| if self.embed_positions is not None: |
| positions = torch.arange(input_ids.size(1), device=input_ids.device) |
| x = x + self.embed_positions(positions) |
| |
| x = self.drop(x) |
| |
        # HF-style attention_mask (1 = real token, 0 = padding) becomes a
        # key_padding_mask (True = ignore) for nn.MultiheadAttention.
        key_padding_mask = (attention_mask == 0) if attention_mask is not None else None

        for block in self.blocks:
            x = block(x, key_padding_mask=key_padding_mask)
| |
| x = self.ln_f(x) |
| logits = self.lm_head(x) |
| |
| loss = None |
| if labels is not None: |
| shift_logits = logits[..., :-1, :].contiguous() |
| shift_labels = labels[..., 1:].contiguous() |
| loss = nn.functional.cross_entropy( |
| shift_logits.view(-1, shift_logits.size(-1)), |
| shift_labels.view(-1), |
| ignore_index=-100, |
| ) |
| |
        # Expose the final hidden state so downstream heads (e.g. the joint
        # fusion model in Step 8) can use it instead of vocabulary logits.
        return CausalLMOutputWithPast(loss=loss, logits=logits, hidden_states=(x,))
| |
| # Register with AutoClass for Hub compatibility |
| DomainTransformerConfig.register_for_auto_class() |
| DomainTransformerForCausalLM.register_for_auto_class("AutoModelForCausalLM") |
| ``` |
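
For orientation, instantiating the small configuration and counting parameters (a sketch; exact counts depend on the tokenizer's final vocabulary size):

```python
config = DomainTransformerConfig(
    vocab_size=8_192, hidden_size=256,
    num_hidden_layers=24, num_attention_heads=16,
)
model = DomainTransformerForCausalLM(config)

n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.1f}M parameters")
# roughly 20–25M at hidden_size=256; hidden_size=1024 with a ~32k vocab lands near 330M
```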
|
|
| #### Step 5: PLR Numerical Embeddings (for Joint Fusion) |
|
|
| Port from Yandex's implementation: |
|
|
| ```python |
| # src/models/plr_embeddings.py |
| |
| import torch |
| import torch.nn as nn |
| import math |
| |
| class PeriodicLinearReLU(nn.Module): |
| """PLR numerical embeddings (Gorishniy et al. 2022). |
| |
    Maps scalar x → [sin(2π·w·x + b), cos(2π·w·x + b)] → Linear → ReLU
| Frequencies w and phases b are LEARNED parameters. |
| """ |
| |
| def __init__(self, n_features: int, n_frequencies: int = 64, embedding_dim: int = 64): |
| super().__init__() |
| self.n_features = n_features |
| self.n_frequencies = n_frequencies |
| |
| # Learnable frequencies and phases (per feature) |
| self.frequencies = nn.Parameter( |
| torch.randn(n_features, n_frequencies) * 0.01 |
| ) |
| self.phases = nn.Parameter( |
| torch.zeros(n_features, n_frequencies) |
| ) |
| |
        # Linear projection: 2*n_frequencies → embedding_dim
| self.linear = nn.Linear(2 * n_frequencies, embedding_dim) |
| |
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| """ |
| Args: |
            x: (batch, n_features) – raw scalar feature values
| Returns: |
| (batch, n_features, embedding_dim) |
| """ |
        # x: (B, F) → (B, F, 1)
| x = x.unsqueeze(-1) |
| |
| # Periodic encoding: (B, F, n_freq) |
| angles = 2 * math.pi * self.frequencies.unsqueeze(0) * x + self.phases.unsqueeze(0) |
| periodic = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1) # (B, F, 2*n_freq) |
| |
| # Linear + ReLU: (B, F, embedding_dim) |
| return torch.relu(self.linear(periodic)) |
| ``` |
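
Shape check for the PLR module (a minimal sketch):

```python
plr = PeriodicLinearReLU(n_features=6, n_frequencies=64, embedding_dim=64)
x = torch.randn(32, 6)   # batch of 32 rows, 6 numerical features
print(plr(x).shape)      # torch.Size([32, 6, 64])
```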
|
|
### Phase 2C: Pre-training (Weeks 5–7)
|
|
| #### Step 6: Data Pipeline |
|
|
| ```python |
| # src/training/data_pipeline.py |
| |
| from torch.utils.data import Dataset |
| from typing import List |
| |
| class DomainSequenceDataset(Dataset): |
| """Converts user event sequences into token sequences for CLM training.""" |
| |
| def __init__(self, user_sequences, tokenizer_builder, hf_tokenizer, max_length=2048): |
| self.user_sequences = user_sequences |
| self.tokenizer_builder = tokenizer_builder |
| self.hf_tokenizer = hf_tokenizer |
| self.max_length = max_length |
| |
| def __len__(self): |
| return len(self.user_sequences) |
| |
| def __getitem__(self, idx): |
| events = self.user_sequences[idx] |
| |
| # Tokenize each event into token strings |
| token_strings = [] |
| for event in events: |
| event_tokens = self.tokenizer_builder.tokenize_event(event) |
| token_strings.extend(event_tokens) |
| |
| # Convert token strings to IDs via HF tokenizer |
| encoding = self.hf_tokenizer( |
| " ".join(token_strings), |
| max_length=self.max_length, |
| truncation=True, |
| padding="max_length", |
| return_tensors="pt", |
| ) |
| |
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        # CLM: the model shifts labels internally; mask padding with -100 so it
        # is ignored by the loss.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
        }
| ``` |
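
Assumed wiring with a standard `DataLoader` (`sequences`, `builder`, and `hf_tok` come from the earlier steps and are illustrative names):

```python
from torch.utils.data import DataLoader

dataset = DomainSequenceDataset(sequences, builder, hf_tok, max_length=2048)
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=2)

batch = next(iter(loader))
print(batch["input_ids"].shape)  # torch.Size([8, 2048])
```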
|
|
| #### Step 7: Pre-training with HuggingFace Trainer |
|
|
| ```python |
| # src/training/pretrain.py |
| |
| from transformers import Trainer, TrainingArguments |
| |
| def pretrain_domain_model( |
| model, |
| train_dataset, |
| eval_dataset=None, |
| output_dir="./checkpoints", |
| hub_model_id="org/domain-model-24m", |
| num_epochs=3, |
| batch_size=64, |
| learning_rate=3e-4, |
| context_length=2048, |
| ): |
| training_args = TrainingArguments( |
| output_dir=output_dir, |
| num_train_epochs=num_epochs, |
| per_device_train_batch_size=batch_size, |
| gradient_accumulation_steps=4, |
| learning_rate=learning_rate, |
| lr_scheduler_type="cosine", |
| warmup_ratio=0.05, |
| weight_decay=0.01, |
| logging_strategy="steps", |
| logging_steps=100, |
| logging_first_step=True, |
| disable_tqdm=True, # plain text logging for cloud |
| eval_strategy="steps" if eval_dataset else "no", |
| eval_steps=500, |
| save_strategy="steps", |
| save_steps=1000, |
| save_total_limit=3, |
| push_to_hub=True, |
| hub_model_id=hub_model_id, |
| bf16=True, |
| dataloader_num_workers=4, |
| report_to="trackio", |
| ) |
| |
| trainer = Trainer( |
| model=model, |
| args=training_args, |
| train_dataset=train_dataset, |
| eval_dataset=eval_dataset, |
| ) |
| |
| trainer.train() |
| trainer.push_to_hub() |
| ``` |
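
Illustrative invocation, reusing the dataset and tokenizer from Step 6 (`push_to_hub=True` requires Hub authentication):

```python
config = DomainTransformerConfig(vocab_size=len(hf_tok), hidden_size=256)
model = DomainTransformerForCausalLM(config)

pretrain_domain_model(
    model,
    train_dataset=dataset,
    hub_model_id="org/domain-model-24m",
)
```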
|
|
### Phase 2D: Joint Fusion Fine-tuning (Weeks 7–9)
|
|
| #### Step 8: nuFormer-style Joint Fusion |
|
|
| ```python |
| # src/models/joint_fusion.py |
| |
import torch
import torch.nn as nn

# Assumes plr_embeddings.py from Step 5 lives in the same package.
from .plr_embeddings import PeriodicLinearReLU
| |
| class DCNv2CrossLayer(nn.Module): |
| """Single cross layer from DCN V2 (Wang et al. 2021).""" |
| |
| def __init__(self, dim): |
| super().__init__() |
| self.weight = nn.Linear(dim, dim, bias=True) |
| |
| def forward(self, x0, x): |
| return x0 * self.weight(x) + x # element-wise product with anchor |
| |
| |
| class JointFusionModel(nn.Module): |
    """nuFormer-style: Transaction Transformer + DCNv2(PLR) → Joint Prediction.

    Architecture:
        Transaction Sequence → Pre-trained DomainTransformer → user_embedding
        Tabular Features → PLR → DCNv2 → tab_embedding
        Concatenate → MLP Head → prediction
| """ |
| |
| def __init__(self, transformer_model, n_tabular_features, n_classes=1, |
| plr_frequencies=64, dcn_layers=3, hidden_dim=256): |
| super().__init__() |
| |
| self.transformer = transformer_model # pre-trained, unfrozen for fine-tuning |
| transformer_dim = transformer_model.config.hidden_size |
| |
        # Tabular branch: PLR → DCNv2
| self.plr = PeriodicLinearReLU(n_tabular_features, plr_frequencies, hidden_dim) |
| tab_input_dim = n_tabular_features * hidden_dim |
| |
| self.dcn_layers = nn.ModuleList([ |
| DCNv2CrossLayer(tab_input_dim) for _ in range(dcn_layers) |
| ]) |
| self.dcn_deep = nn.Sequential( |
| nn.Linear(tab_input_dim, hidden_dim), |
| nn.ReLU(), |
| nn.Linear(hidden_dim, hidden_dim), |
| nn.ReLU(), |
| ) |
| |
| # Joint head |
| self.head = nn.Sequential( |
| nn.Linear(transformer_dim + hidden_dim, hidden_dim), |
| nn.ReLU(), |
| nn.Dropout(0.1), |
| nn.Linear(hidden_dim, n_classes), |
| ) |
| |
| def forward(self, input_ids, attention_mask, tabular_features, labels=None): |
        # Transaction branch: embed the last *non-padded* token.
        # Use the transformer's hidden states (hidden_size dim), not its logits
        # (vocab_size dim), so dimensions match the fusion head.
        transformer_output = self.transformer(input_ids, attention_mask=attention_mask)
        hidden = transformer_output.hidden_states[-1]              # (B, T, hidden_size)
        last_idx = attention_mask.sum(dim=1).long() - 1            # index of last real token
        batch_idx = torch.arange(hidden.size(0), device=hidden.device)
        user_embedding = hidden[batch_idx, last_idx]               # (B, hidden_size)
| |
        # Tabular branch: PLR → flatten → DCNv2
| tab_embedded = self.plr(tabular_features) # (B, F, D) |
| tab_flat = tab_embedded.view(tab_embedded.size(0), -1) # (B, F*D) |
| |
| x0 = tab_flat |
| x = tab_flat |
| for cross_layer in self.dcn_layers: |
| x = cross_layer(x0, x) |
| tab_output = self.dcn_deep(x) # (B, hidden_dim) |
| |
| # Joint fusion |
| combined = torch.cat([user_embedding, tab_output], dim=-1) |
| logits = self.head(combined) |
| |
| loss = None |
| if labels is not None: |
| loss = nn.functional.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float()) |
| |
| return {"loss": loss, "logits": logits} |
| ``` |
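
Illustrative fine-tuning step, reusing a pre-trained `model` and a batch from Step 6 (the tabular tensors are stand-ins):

```python
fusion = JointFusionModel(model, n_tabular_features=12, n_classes=1)

out = fusion(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    tabular_features=torch.randn(batch["input_ids"].size(0), 12),
    labels=torch.randint(0, 2, (batch["input_ids"].size(0),)),
)
out["loss"].backward()
```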
|
|
### Phase 3: Domain Demos (Weeks 9–12)
|
|
| Week | Deliverable | Hardware |
|------|------------|----------|
| 9–10 | **Finance demo:** Transaction tokenizer + 24M model pre-trained on synthetic/public financial data + fraud detection fine-tuning | a10g-large |
| 10–11 | **E-commerce demo:** Event tokenizer + 24M model pre-trained on Amazon review sequences + next-purchase prediction | a10g-large |
| 11–12 | **Evaluation & benchmarking:** Compare domain tokenizer vs. text serialization vs. LightGBM baselines on each domain | a10g-large |
|
|
| ### Phase 4: Scale & Optimize (Weeks 12+) |
|
|
| Task | Details |
|------|---------|
| Scale to 330M params | Increase `hidden_size` to 1024, train on a100-large |
| `torch.compile()` | Apply to attention and MLP blocks for 10–30% speedup |
| ONNX export | `torch.onnx.export()` for production serving (see the sketch below) |
| Context window experiments | Ablate 512/1024/2048/4096 context lengths |
| Data source ablation | Test impact of different event types (Nubank found adding low-signal sources hurts) |
| ActionPiece vocabulary | Implement BPE-like cross-field merging on top of per-field tokens |
|
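A minimal sketch of the ONNX export path referenced above (in practice HF's `optimum` exporters handle model-output details; axis names are illustrative):

```python
import torch

model.eval()
dummy = torch.randint(0, model.config.vocab_size, (1, 128))
torch.onnx.export(
    model,
    (dummy,),
    "domain_transformer.onnx",
    input_names=["input_ids"],
    output_names=["logits"],
    dynamic_axes={"input_ids": {0: "batch", 1: "sequence"}},
)
```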
|
| --- |
|
|
| ## Appendix A: Framework Usage Across Reference Papers |
|
|
| | Paper | ArXiv | Framework | Verbatim Evidence | |
| |-------|-------|-----------|-------------------| |
| | nuFormer (Nubank) | 2507.23267 | PyTorch + HF (inferred) | All dependencies are PyTorch-based | |
| | TIGER (Google) | 2305.05065 | JAX + T5X | "We use the open-sourced T5X framework" | |
| | ActionPiece (DeepMind) | 2502.13581 | PyTorch + HF | "HuggingFace Transformers and PyTorch" (Appendix H) | |
| RecFormer | 2305.13731 | PyTorch + HF Longformer | "Longformer implemented by Huggingface" (§3.1.4) |
| | Banking TF | 2410.08243 | PyTorch | "Pytorch backend is used" (Appendix B) | |
| | PLR Embeddings (Yandex) | 2203.05556 | PyTorch | Repository: pure PyTorch + scikit-learn | |
| | KL3M Tokenizers | 2503.17247 | HF `tokenizers` + PyTorch | "tokenizers" BPE for HF compatibility | |
|
|
| --- |
|
|
| ## Appendix B: Head-to-Head Comparison Matrix |
|
|
| Criterion | PyTorch + HF | JAX/Flax NNX | Keras 3 + JAX |
|-----------|-------------|-------------|---------------|
| Custom domain tokenizer | ✅ `PreTrainedTokenizerFast` | ❌ Build from scratch | ❌ Build from scratch |
| HF Trainer integration | ✅ Native | ❌ Not compatible | ❌ Not compatible |
| Hub push/pull | ✅ `push_to_hub()` | ⚠️ Works, fragmented | ⚠️ Limited |
| PEFT/LoRA | ✅ Drop-in | ❌ Manual | ❌ Manual |
| ONNX export | ✅ One-line | ❌ Multi-hop | ⚠️ TF backend required |
| TGI/vLLM serving | ✅ First-class | ❌ Not supported | ❌ Not supported |
| TPU training | ⚠️ PyTorch/XLA (overhead) | ✅ Native | ✅ Native |
| JIT compilation | ✅ `torch.compile()` | ✅ `@nnx.jit` | ✅ XLA via JAX |
| Dynamic shapes (NLP) | ✅ Native | ❌ Recompiles | ❌ Recompiles |
| Debugging | ✅ Eager mode, std debugger | ⚠️ Challenging | ⚠️ Challenging |
| Reference implementations | 5/6 papers | 1/6 papers | 0/6 papers |
| Community/hiring pool | 🟢 Large | 🟡 Small | 🟡 Small |
|
|
| --- |
|
|
| ## Appendix C: Key Code Patterns |
|
|
| ### Adding Domain Special Tokens to an Existing Tokenizer |
|
|
| ```python |
| from transformers import AutoTokenizer |
| |
| tokenizer = AutoTokenizer.from_pretrained("gpt2") # or any base tokenizer |
| |
| # Add all domain special tokens |
| special_tokens = { |
| "additional_special_tokens": [ |
| # Amount tokens (Nubank-style) |
| "[AMT_POS]", "[AMT_NEG]", |
| *[f"[AMT_{i:02d}]" for i in range(21)], |
| # Calendar tokens |
| *[f"[MON_{i:02d}]" for i in range(1, 13)], |
| *[f"[DOW_{i}]" for i in range(7)], |
| *[f"[DOM_{i:02d}]" for i in range(1, 32)], |
| *[f"[HOUR_{i:02d}]" for i in range(24)], |
| ] |
| } |
| num_added = tokenizer.add_special_tokens(special_tokens) |
| print(f"Added {num_added} domain tokens. Vocab size: {len(tokenizer)}") |
| |
| # CRITICAL: resize model embeddings |
| model.resize_token_embeddings(len(tokenizer)) |
| ``` |
|
|
| ### Registering a Custom Model for Hub Deployment |
|
|
| ```python |
| # In your package's __init__.py or a registration script: |
| from transformers import AutoConfig, AutoModelForCausalLM |
| |
| from .configuration_domain_transformer import DomainTransformerConfig |
| from .modeling_domain_transformer import DomainTransformerForCausalLM |
| |
| # Register so AutoClass can find your model |
| AutoConfig.register("domain_transformer", DomainTransformerConfig) |
| AutoModelForCausalLM.register(DomainTransformerConfig, DomainTransformerForCausalLM) |
| |
| # Enable push_to_hub with custom code |
| DomainTransformerConfig.register_for_auto_class() |
| DomainTransformerForCausalLM.register_for_auto_class("AutoModelForCausalLM") |
| |
| # Push: uploads configuration.py, modeling.py, config.json, model.safetensors |
| model.push_to_hub("org/domain-transformer-24m") |
| |
| # Load anywhere: |
| model = AutoModelForCausalLM.from_pretrained("org/domain-transformer-24m", trust_remote_code=True) |
| ``` |
|
|
| --- |
|
|
| *This ADR is a living document and will be updated as implementation progresses.* |