# ADR-001: Implementation Framework for domainTokenizer
> **Status:** Accepted
> **Date:** April 29, 2026
> **Decision:** PyTorch + HuggingFace Transformers as primary framework, with JAX/Flax NNX as future scaling path
> **Deciders:** domainTokenizer core team
---
## Table of Contents
1. [Context](#1-context)
2. [Goal](#2-goal)
3. [Options Evaluated](#3-options-evaluated)
4. [Decision](#4-decision)
5. [Trade-offs and Justification](#5-trade-offs-and-justification)
6. [Consequences](#6-consequences)
7. [Implementation Roadmap](#7-implementation-roadmap)
8. [Appendix A: Framework Usage Across Reference Papers](#appendix-a-framework-usage-across-reference-papers)
9. [Appendix B: Head-to-Head Comparison Matrix](#appendix-b-head-to-head-comparison-matrix)
10. [Appendix C: Key Code Patterns](#appendix-c-key-code-patterns)
---
## 1. Context
### What We're Building
domainTokenizer is a library for building **small Transformer models (24M–330M parameters)** that process **domain-specific tokens** (financial transactions, e-commerce events, healthcare records) instead of natural language text. The architecture follows the validated pattern from Nubank's nuFormer ([arXiv: 2507.23267](https://arxiv.org/abs/2507.23267)):
```
Domain Events → Custom Tokenizer → GPT-style Transformer → Foundation Model → Downstream Tasks
```
### The Implementation Question
A video from Google for Developers presents the **Keras 3 + JAX/Flax NNX integration** as a potential framework, citing:
- **Explicit state management** (Flax NNX), useful for tracking sequential transaction state
- **Custom training loops**: Keras structure + JAX/Optax for domain-specific gradient control
- **JIT compilation** (`@nnx.jit`) for high-performance processing of millions of transactions
- **Paradigm mixing**: Keras layers for standard components + NNX for custom sequential encoders
The question: **Is Keras + JAX/Flax NNX the right framework for domainTokenizer, or is there a better choice?**
### Constraints
1. **Custom tokenizer required:** We need a tokenizer that maps structured fields (amounts, dates, categories) to special tokens, not a standard text tokenizer
2. **Small models:** 24M–330M parameters, not 70B+; framework overhead matters less than developer velocity
3. **Production deployment:** Models must be servable with low latency for real-time applications (fraud detection, recommendations)
4. **GPU hardware:** Development on A100/A10G GPUs, not TPUs (standard cloud environment)
5. **Team context:** ML engineers familiar with Python, PyTorch, and the HuggingFace ecosystem
6. **Iteration speed:** Need to prototype quickly across multiple domains (finance, e-commerce, healthcare)
---
## 2. Goal
Choose an implementation framework that:
1. **Minimizes time from research to working prototype**: weeks, not months
2. **Supports custom domain tokenizers** as first-class citizens
3. **Integrates with the HuggingFace Hub** for model sharing, versioning, and community
4. **Enables production deployment** via standard serving infrastructure (ONNX, TGI, vLLM, etc.)
5. **Scales to 330M parameters** on 4–8 GPUs without heroic engineering
6. **Does not preclude future migration** to JAX/TPU if we need to scale beyond 1B parameters
---
## 3. Options Evaluated
### Option A: PyTorch + HuggingFace Transformers
The dominant ecosystem for custom NLP/sequential models. Provides `PreTrainedModel`, `PreTrainedTokenizerFast`, `Trainer`, `push_to_hub`, ONNX export, and integration with TRL, PEFT, Accelerate, DeepSpeed.
### Option B: Keras 3 + JAX Backend + Flax NNX
Google's multi-backend framework. Keras provides high-level APIs; JAX provides XLA compilation and functional transforms; Flax NNX provides PyTorch-like stateful modules on top of JAX.
### Option C: Pure JAX + Flax NNX + Optax
Skip Keras entirely. Use Flax NNX for model definition, Optax for optimization, Orbax for checkpointing, and Grain/tf.data for data loading. Google's MaxText framework follows this pattern.
### Option D: PyTorch + Custom (no HuggingFace)
Use PyTorch directly without the HuggingFace abstraction layer. Full control but no ecosystem integration.
---
## 4. Decision
### Primary: PyTorch + HuggingFace Transformers (Option A)
### Future scaling path: JAX/Flax NNX (Option C), if and when we need TPU training at >1B parameters
---
## 5. Trade-offs and Justification
### 5.1 What the Reference Papers Actually Use
We audited the frameworks used by every paper in the domainTokenizer research corpus. The result is overwhelming:
| Paper | Framework | Confidence |
|-------|-----------|------------|
| **nuFormer** (Nubank) | PyTorch + HF Transformers (inferred) | ~90% |
| **TIGER** (Google) | JAX + T5X (official); PyTorch (community reimpl) | 100% |
| **ActionPiece** (Google DeepMind) | **PyTorch + HF Transformers** (stated verbatim in paper) | 100% |
| **RecFormer** (UCSD/Amazon) | **PyTorch + HF Transformers (Longformer)** (stated verbatim) | 100% |
| **Banking Transaction Flow** | **PyTorch** (stated verbatim in appendix) | 100% |
| **PLR Embeddings** (Yandex) | **PyTorch** + scikit-learn + Optuna | 100% |
**5 of 6 papers use PyTorch.** The sole JAX user (TIGER) was a Google-internal project using T5X, and even its most popular community reimplementation (781 GitHub stars) is in PyTorch.
Even **Google DeepMind's own ActionPiece**, the paper most relevant to our domain tokenization approach, uses PyTorch + HuggingFace. This is the strongest signal possible.
### 5.2 Custom Tokenizer Story
This is the **decisive factor**. domainTokenizer's core innovation is the tokenizer itself. The framework must provide first-class support for custom token vocabularies.
**PyTorch + HuggingFace:**
- Train custom BPE tokenizer via `tokenizers` library (Rust-backed, fast)
- Wrap in `PreTrainedTokenizerFast` → full Trainer compatibility
- Add domain special tokens via `add_special_tokens()`, then resize model embeddings to match
- Push tokenizer to Hub: `tokenizer.push_to_hub("org/my-tokenizer")`
- Load anywhere: `AutoTokenizer.from_pretrained("org/my-tokenizer")`
- **KL3M** ([arXiv: 2503.17247](https://arxiv.org/abs/2503.17247)), the gold standard for financial domain tokenizers, is built entirely on this stack
**Keras + JAX/Flax NNX:**
- No equivalent to `PreTrainedTokenizerFast`
- No Hub-integrated tokenizer format
- Must build custom tokenizer from scratch with no ecosystem support
- No standard serialization/deserialization for domain vocabularies
**Verdict:** PyTorch/HF has a **complete, production-tested** custom tokenizer pipeline. JAX/Keras has **nothing**: you'd build everything from scratch.
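A minimal sketch of that pipeline, assuming an in-memory text corpus; the repo name and corpus here are placeholders, not real artifacts:

```python
# Hedged sketch of the HF custom-tokenizer path described above.
# `text_corpus` and "org/finance-tokenizer" are placeholders.
from tokenizers import Tokenizer, models, trainers, pre_tokenizers
from transformers import PreTrainedTokenizerFast

domain_tokens = [f"[AMT_{i:02d}]" for i in range(21)] + [f"[DOW_{i}]" for i in range(7)]
text_corpus = ["grocery store purchase", "monthly rent payment"]  # stand-in for real descriptions

bpe = Tokenizer(models.BPE(unk_token="[UNK]"))
bpe.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
trainer = trainers.BpeTrainer(
    vocab_size=8000,
    special_tokens=["[PAD]", "[UNK]", "[BOS]", "[EOS]"] + domain_tokens,
)
bpe.train_from_iterator(text_corpus, trainer=trainer)

tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=bpe,
    pad_token="[PAD]", unk_token="[UNK]", bos_token="[BOS]", eos_token="[EOS]",
)
tokenizer.push_to_hub("org/finance-tokenizer")  # later: AutoTokenizer.from_pretrained("org/finance-tokenizer")
```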
### 5.3 Production Deployment
| Path | PyTorch | JAX/Keras |
|------|---------|-----------|
| ONNX export | `torch.onnx.export()` (one line) | Requires TF backend intermediate or experimental `jax.export` |
| TensorRT | ONNX → TRT (standard) | Multi-hop, fragile |
| TGI (HuggingFace inference) | First-class | Not supported |
| vLLM | First-class | Not supported |
| Triton Inference Server | Direct ONNX/TorchScript | Via ONNX (workaround) |
| BentoML | Supported | Supported |
| Model Hub sharing | `push_to_hub()` → `from_pretrained()` | Works but fragmented (`.msgpack` weights, no Trainer compat) |
**Verdict:** PyTorch has **direct, tested paths** to every major serving framework. JAX requires **multiple intermediate conversions**, each introducing failure points.
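For illustration, a hedged sketch of the PyTorch ONNX path; here `model` stands for any causal-LM `nn.Module` with a `.logits` output (e.g. the DomainTransformer defined in the roadmap below), and the file name is a placeholder:

```python
# Hedged sketch of torch.onnx.export for a HF-style causal LM.
import torch
import torch.nn as nn

class LogitsOnly(nn.Module):
    """Wrapper so ONNX sees plain tensors instead of a HF output dataclass."""
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, input_ids):
        return self.model(input_ids).logits

dummy_ids = torch.randint(0, 32000, (1, 128))  # (batch, seq_len) of token IDs
torch.onnx.export(
    LogitsOnly(model).eval(),
    (dummy_ids,),
    "domain_transformer.onnx",
    input_names=["input_ids"],
    output_names=["logits"],
    dynamic_axes={"input_ids": {0: "batch", 1: "seq"}, "logits": {0: "batch", 1: "seq"}},
)
```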
### 5.4 Training Speed
At our scale (24M–330M parameters on 4–8 A100s):
| Scenario | PyTorch | JAX |
|----------|---------|-----|
| Steady-state training throughput | **Comparable** (`torch.compile`) | **Comparable** (XLA JIT) |
| Variable-length sequences | **Native**: dynamic shapes | **Problematic**: recompiles on new shapes; must pad to buckets |
| Multi-GPU (FSDP) | `accelerate` + FSDP2 (mature) | `pmap`/`shard_map` (works but harder to configure) |
| First-run compilation | Instant (eager mode) | 5–20s JIT compilation overhead |
| Debugging | Standard Python debugger | `print` debugging; cryptic XLA errors |
**Verdict:** At 330M parameters, training speed is a **wash**. JAX's advantages (XLA kernel fusion, TPU native) only matter at 10B+ parameters on 256+ accelerators. At our scale, **developer velocity dominates throughput**.
### 5.5 The JAX Advantage: When It Would Win
JAX/Flax NNX would be the right choice **if**:
1. **Training exclusively on Google TPUs**: JAX is the native TPU compiler; PyTorch/XLA is a port with overhead
2. **Models >1B parameters**: XLA's whole-program optimization shines at scale
3. **Fixed-shape workloads**: images, fixed-length token sequences (no variable-length padding issues)
4. **Need functional transforms**: `vmap` (per-sample gradients), `pmap` (data parallelism), `grad` (higher-order derivatives)
5. **Google Cloud infrastructure**: Vertex AI, TPU VMs, GCS integration
For domainTokenizer's current scope (24M–330M parameters, GPU hardware, variable-length sequences, fast iteration), **none of these conditions apply**.
### 5.6 The Keras + JAX Mixing Argument
The Google for Developers video argues for mixing Keras layers (high-level) with NNX modules (custom, high-performance). In theory, this lets you:
- Use Keras for standard Transformer layers
- Use NNX for custom sequential transaction encoders
- Get JIT compilation on the NNX parts
**In practice, this creates problems:**
1. **Two mental models:** Keras (layer-oriented, `fit/compile`) vs. NNX (functional, explicit state); context switching slows development
2. **Limited interop documentation:** Keras ↔ NNX examples are thin; edge cases are poorly documented
3. **No HF ecosystem integration:** You lose Trainer, push_to_hub, PEFT, TRL, Accelerate, the entire ecosystem Nubank and ActionPiece rely on
4. **Debugging complexity:** Errors at the Keras ↔ NNX boundary are hard to diagnose
**Better approach with PyTorch:** Use `torch.compile()` on performance-critical modules to get JIT compilation benefits without leaving the PyTorch ecosystem. Write custom `nn.Module` subclasses for domain-specific components. This gives you the same "standard parts + custom parts" architecture without framework mixing.
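A sketch of that pattern; `CustomTransactionEncoder` is a hypothetical stand-in for a domain-specific component, not part of the current codebase:

```python
# Hedged sketch: compile only the custom hot path; everything else stays eager PyTorch.
import torch
import torch.nn as nn

class CustomTransactionEncoder(nn.Module):
    """Hypothetical domain-specific module standing in for a custom sequential encoder."""
    def __init__(self, dim: int = 256):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.proj(x))

encoder = torch.compile(CustomTransactionEncoder())  # JIT-compiles on first call, cached afterwards
out = encoder(torch.randn(8, 128, 256))               # (batch, seq, dim) -> same shape
```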
---
## 6. Consequences
### What We Gain
1. **Immediate access to the entire HuggingFace ecosystem:** Trainer, Accelerate, PEFT (LoRA), TRL, Evaluate, push_to_hub, from_pretrained, ONNX export, TGI serving
2. **Copy-paste from reference implementations:** ActionPiece, RecFormer, Banking TF, and PLR embeddings are all PyTorch, so we can directly reuse their code
3. **KL3M tokenizer as starting point:** The best financial domain tokenizer already exists in PyTorch/HF format at `alea-institute/kl3m-004-128k-cased`
4. **Standard production deployment:** ONNX → TensorRT → Triton, or direct TGI/vLLM serving
5. **Community and hiring:** PyTorch is the dominant ML framework; finding contributors and documentation is easy
6. **`torch.compile()` for performance:** When we need JIT compilation on hot paths, `torch.compile()` provides 10–30% speedups without leaving the ecosystem
### What We Accept
1. **No native TPU support:** If we later need to train on Google TPUs, we'll need PyTorch/XLA (slower than native JAX) or migrate the model code
2. **Weaker functional transforms:** `vmap` (per-sample gradients) requires `torch.func` (the successor to functorch), which is less mature than JAX's native transforms. If we need advanced gradient manipulation for meta-learning or Nested Learning (HOPE-style), JAX would be better
3. **Potential future migration cost:** If we scale beyond 1B parameters and move to TPUs, we'll need to rewrite model code in Flax NNX. This is mitigated by keeping model definitions clean and modular
### Migration Strategy (If Needed Later)
If domainTokenizer grows to >1B parameters and we need TPU training:
1. **Tokenizer layer stays in Python/HF:** The tokenizer is framework-agnostic; it produces integer sequences regardless of whether the model is PyTorch or JAX
2. **Model architecture translates 1:1:** The PyTorch `nn.Module` → Flax NNX `nnx.Module` mapping is straightforward for standard Transformer components (see the sketch below)
3. **Training loop changes:** PyTorch Trainer → custom Flax NNX training loop with Optax
4. **Reference:** Google's MaxText (`github.com/google/maxtext`) provides production-grade JAX Transformer patterns we can follow
**Estimated migration effort:** 2–4 weeks for a clean, well-separated codebase.
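To make the "translates 1:1" claim concrete, a hedged side-by-side sketch of the same small module in PyTorch and Flax NNX; this is illustrative only and not part of the current codebase:

```python
# Hedged sketch of the PyTorch -> Flax NNX mapping referenced above.
import torch
import torch.nn as torch_nn
import jax
from flax import nnx

class TorchMLP(torch_nn.Module):              # today: PyTorch
    def __init__(self, dim: int):
        super().__init__()
        self.fc = torch_nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.fc(x))

class FlaxMLP(nnx.Module):                    # possible future: Flax NNX, near line-for-line
    def __init__(self, dim: int, rngs: nnx.Rngs):
        self.fc = nnx.Linear(dim, dim, rngs=rngs)

    def __call__(self, x):
        return jax.nn.relu(self.fc(x))
```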
---
## 7. Implementation Roadmap
### Phase 2A: Core Tokenizer Library (Weeks 1–3)
#### Step 1: Domain Schema Definition
Create a declarative schema format that describes the fields in a domain's event data:
```python
# src/tokenizers/schema.py
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional
class FieldType(Enum):
    NUMERICAL_CONTINUOUS = "numerical_continuous"  # prices, amounts → magnitude bins
    NUMERICAL_DISCRETE = "numerical_discrete"      # quantities → small fixed vocab
    CATEGORICAL_FIXED = "categorical_fixed"        # categories, days of week → direct mapping
    CATEGORICAL_ENTITY = "categorical_entity"      # products, merchants → Semantic IDs (RQ-VAE)
    TEMPORAL = "temporal"                          # timestamps → calendar decomposition
    TEXT = "text"                                  # descriptions → BPE subwords
    SIGN = "sign"                                  # credit/debit → 2 tokens
@dataclass
class FieldSpec:
name: str
field_type: FieldType
vocab_size: Optional[int] = None # for fixed categorical
n_bins: int = 21 # for numerical (Nubank uses 21)
calendar_fields: List[str] = field( # for temporal
default_factory=lambda: ["month", "dow", "dom", "hour"]
)
@dataclass
class DomainSchema:
name: str # e.g., "ecommerce", "finance"
fields: List[FieldSpec] # ordered list of fields per event
@property
def special_token_count(self) -> int:
"""Total domain-specific special tokens needed."""
count = 0
for f in self.fields:
if f.field_type == FieldType.SIGN:
count += 2
elif f.field_type == FieldType.NUMERICAL_CONTINUOUS:
count += f.n_bins
elif f.field_type == FieldType.CATEGORICAL_FIXED:
count += f.vocab_size
elif f.field_type == FieldType.TEMPORAL:
count += sum({
"month": 12, "dow": 7, "dom": 31, "hour": 24,
"quarter": 4, "year": 10
}.get(cf, 0) for cf in f.calendar_fields)
return count
# Example: Nubank-style financial schema
FINANCE_SCHEMA = DomainSchema(
name="finance",
fields=[
FieldSpec("amount_sign", FieldType.SIGN),
FieldSpec("amount", FieldType.NUMERICAL_CONTINUOUS, n_bins=21),
FieldSpec("timestamp", FieldType.TEMPORAL,
calendar_fields=["month", "dow", "dom", "hour"]),
FieldSpec("description", FieldType.TEXT),
]
)
# Example: E-commerce schema
ECOMMERCE_SCHEMA = DomainSchema(
name="ecommerce",
fields=[
FieldSpec("event_type", FieldType.CATEGORICAL_FIXED, vocab_size=5),
FieldSpec("price", FieldType.NUMERICAL_CONTINUOUS, n_bins=21),
FieldSpec("quantity", FieldType.NUMERICAL_DISCRETE, vocab_size=11),
FieldSpec("category_l1", FieldType.CATEGORICAL_FIXED, vocab_size=30),
FieldSpec("category_l2", FieldType.CATEGORICAL_FIXED, vocab_size=200),
FieldSpec("timestamp", FieldType.TEMPORAL,
calendar_fields=["month", "dow", "dom", "hour"]),
FieldSpec("product_title", FieldType.TEXT),
]
)
```
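As a quick sanity check, the derived special-token count of the finance schema above can be inspected directly; the TEXT field is handled by the BPE vocabulary and is not counted here:

```python
# Sanity check on the schema above: 2 sign + 21 amount bins + (12 + 7 + 31 + 24) calendar tokens.
print(FINANCE_SCHEMA.special_token_count)  # -> 97
```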
#### Step 2: Per-Field Tokenizers
Implement each field type tokenizer as a standalone module:
```python
# src/tokenizers/field_tokenizers.py
import numpy as np
from typing import List
class SignTokenizer:
"""Tokenizes sign of a numerical value (credit/debit, inflow/outflow)."""
def __init__(self, prefix: str = "SIGN"):
self.tokens = [f"[{prefix}_POS]", f"[{prefix}_NEG]"]
def __call__(self, value: float) -> str:
return self.tokens[0] if value >= 0 else self.tokens[1]
@property
def vocab(self) -> List[str]:
return self.tokens
class MagnitudeBucketTokenizer:
"""Quantizes continuous values into bins (Nubank-style).
Uses quantile-based binning on the training distribution.
Follows the Relative Magnitude Tokenization principle from TP-BERTa.
"""
def __init__(self, n_bins: int = 21, prefix: str = "AMT"):
self.n_bins = n_bins
self.prefix = prefix
self.bin_edges = None # fitted from data
def fit(self, values: np.ndarray):
"""Compute bin edges from training data using quantiles."""
# Use absolute values for magnitude binning
abs_vals = np.abs(values[~np.isnan(values)])
quantiles = np.linspace(0, 100, self.n_bins + 1)
self.bin_edges = np.percentile(abs_vals, quantiles)
return self
def __call__(self, value: float) -> str:
if self.bin_edges is None:
raise ValueError("Tokenizer not fitted. Call .fit() first.")
bin_idx = np.searchsorted(self.bin_edges[1:-1], abs(value))
return f"[{self.prefix}_{bin_idx:02d}]"
@property
def vocab(self) -> List[str]:
return [f"[{self.prefix}_{i:02d}]" for i in range(self.n_bins)]
class CalendarTokenizer:
"""Decomposes timestamps into calendar components (Nubank-style)."""
FIELD_VOCABS = {
"month": ([f"[MON_{i:02d}]" for i in range(1, 13)], lambda dt: dt.month - 1),
"dow": ([f"[DOW_{i}]" for i in range(7)], lambda dt: dt.weekday()),
"dom": ([f"[DOM_{i:02d}]" for i in range(1, 32)], lambda dt: dt.day - 1),
"hour": ([f"[HOUR_{i:02d}]" for i in range(24)], lambda dt: dt.hour),
}
def __init__(self, fields: List[str] = None):
self.fields = fields or ["month", "dow", "dom", "hour"]
def __call__(self, timestamp) -> List[str]:
tokens = []
for field_name in self.fields:
vocab, extractor = self.FIELD_VOCABS[field_name]
idx = extractor(timestamp)
tokens.append(vocab[idx])
return tokens
@property
def vocab(self) -> List[str]:
all_tokens = []
for field_name in self.fields:
all_tokens.extend(self.FIELD_VOCABS[field_name][0])
return all_tokens
class CategoricalTokenizer:
"""Maps categorical values to fixed vocabulary tokens."""
def __init__(self, categories: List[str], prefix: str = "CAT"):
self.prefix = prefix
self.token_map = {cat: f"[{prefix}_{i:03d}]" for i, cat in enumerate(categories)}
self.unk_token = f"[{prefix}_UNK]"
def __call__(self, value: str) -> str:
return self.token_map.get(value, self.unk_token)
@property
def vocab(self) -> List[str]:
return list(self.token_map.values()) + [self.unk_token]
```
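A brief usage sketch of these field tokenizers with illustrative values; the exact amount bin depends on the fitted training distribution:

```python
# Hedged usage sketch for the per-field tokenizers above; values are illustrative.
from datetime import datetime
import numpy as np

amount_tok = MagnitudeBucketTokenizer(n_bins=21, prefix="AMT")
amount_tok.fit(np.random.lognormal(mean=3.0, sigma=1.0, size=10_000))  # synthetic training amounts

print(SignTokenizer("AMT")(-42.50))     # -> "[AMT_NEG]"
print(amount_tok(-42.50))               # -> e.g. "[AMT_09]" (bin depends on the fitted quantiles)
print(CalendarTokenizer(["month", "dow", "hour"])(datetime(2025, 3, 14, 9, 30)))
# -> ["[MON_03]", "[DOW_4]", "[HOUR_09]"]
```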
#### Step 3: Composite Domain Tokenizer
Assemble per-field tokenizers into a complete domain tokenizer, wrapped as `PreTrainedTokenizerFast`:
```python
# src/tokenizers/domain_tokenizer.py
from typing import List

import numpy as np
from tokenizers import Tokenizer, models, trainers, pre_tokenizers
from transformers import PreTrainedTokenizerFast

from .schema import DomainSchema, FieldType
from .field_tokenizers import CalendarTokenizer, MagnitudeBucketTokenizer, SignTokenizer
class DomainTokenizerBuilder:
"""Builds a HuggingFace-compatible tokenizer from a DomainSchema."""
def __init__(self, schema: DomainSchema):
self.schema = schema
self.field_tokenizers = {} # name β†’ field tokenizer instance
self._build_field_tokenizers()
def _build_field_tokenizers(self):
for field_spec in self.schema.fields:
if field_spec.field_type == FieldType.SIGN:
self.field_tokenizers[field_spec.name] = SignTokenizer(field_spec.name.upper())
elif field_spec.field_type == FieldType.NUMERICAL_CONTINUOUS:
self.field_tokenizers[field_spec.name] = MagnitudeBucketTokenizer(
n_bins=field_spec.n_bins, prefix=field_spec.name.upper()
)
elif field_spec.field_type == FieldType.TEMPORAL:
self.field_tokenizers[field_spec.name] = CalendarTokenizer(field_spec.calendar_fields)
# ... other types
def fit(self, data):
"""Fit data-dependent tokenizers (magnitude bins, etc.)."""
for field_spec in self.schema.fields:
if field_spec.field_type == FieldType.NUMERICAL_CONTINUOUS:
values = [getattr(event, field_spec.name) for event in data]
self.field_tokenizers[field_spec.name].fit(np.array(values))
return self
def build_hf_tokenizer(self, text_corpus=None, bpe_vocab_size=8000) -> PreTrainedTokenizerFast:
"""Build a complete HuggingFace tokenizer.
1. Collect all domain special tokens
2. Train BPE on text fields (if any)
3. Merge into a single PreTrainedTokenizerFast
"""
# Collect all special tokens from field tokenizers
all_special_tokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "[BOS]", "[EOS]"]
for name, tok in self.field_tokenizers.items():
if hasattr(tok, 'vocab'):
all_special_tokens.extend(tok.vocab)
# Train BPE on text fields
bpe_tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
bpe_tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
trainer = trainers.BpeTrainer(
vocab_size=bpe_vocab_size,
special_tokens=all_special_tokens,
min_frequency=2,
)
        if text_corpus:
            bpe_tokenizer.train_from_iterator(text_corpus, trainer=trainer)
        else:
            # No text fields to train on: still register the domain special tokens in the vocab
            bpe_tokenizer.add_special_tokens(all_special_tokens)
# Wrap as HuggingFace tokenizer
hf_tokenizer = PreTrainedTokenizerFast(
tokenizer_object=bpe_tokenizer,
bos_token="[BOS]",
eos_token="[EOS]",
pad_token="[PAD]",
unk_token="[UNK]",
mask_token="[MASK]",
)
return hf_tokenizer
def tokenize_event(self, event) -> List[str]:
"""Convert a single domain event into a list of token strings."""
tokens = []
for field_spec in self.schema.fields:
value = getattr(event, field_spec.name, None)
if value is None:
tokens.append("[UNK]")
continue
tok = self.field_tokenizers[field_spec.name]
result = tok(value)
if isinstance(result, list):
tokens.extend(result)
else:
tokens.append(result)
return tokens
```
### Phase 2B: Model Architecture (Weeks 3–5)
#### Step 4: GPT-style Causal Transformer (NoPE)
Implement as a HuggingFace-compatible `PreTrainedModel`:
```python
# src/models/configuration_domain_transformer.py
from transformers import PretrainedConfig
class DomainTransformerConfig(PretrainedConfig):
model_type = "domain_transformer"
def __init__(
self,
vocab_size: int = 32000,
hidden_size: int = 256, # 256 = 24M params, 1024 = 330M (Nubank sizes)
num_hidden_layers: int = 24, # Nubank uses 24 for both sizes
num_attention_heads: int = 16, # Nubank uses 16 for both sizes
intermediate_size: int = None, # defaults to 4 * hidden_size
max_position_embeddings: int = 2048,
dropout: float = 0.1,
use_positional_encoding: bool = False, # NoPE by default!
**kwargs
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size or 4 * hidden_size
self.max_position_embeddings = max_position_embeddings
self.dropout = dropout
self.use_positional_encoding = use_positional_encoding
super().__init__(**kwargs)
```
```python
# src/models/modeling_domain_transformer.py
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
class DomainTransformerBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.ln1 = nn.LayerNorm(config.hidden_size)
self.attn = nn.MultiheadAttention(
config.hidden_size, config.num_attention_heads,
dropout=config.dropout, batch_first=True
)
self.ln2 = nn.LayerNorm(config.hidden_size)
self.mlp = nn.Sequential(
nn.Linear(config.hidden_size, config.intermediate_size),
nn.GELU(),
nn.Linear(config.intermediate_size, config.hidden_size),
nn.Dropout(config.dropout),
)
    def forward(self, x, key_padding_mask=None):
        # Pre-norm architecture; explicit causal mask (True = masked position)
        h = self.ln1(x)
        seq_len = x.size(1)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1
        )
        h, _ = self.attn(h, h, h, attn_mask=causal_mask,
                         key_padding_mask=key_padding_mask, need_weights=False)
        x = x + h
        x = x + self.mlp(self.ln2(x))
        return x
class DomainTransformerForCausalLM(PreTrainedModel):
config_class = DomainTransformerConfig
def __init__(self, config):
super().__init__(config)
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
# NoPE: no positional encoding by default (Kazemnejad et al. 2023)
if config.use_positional_encoding:
self.embed_positions = nn.Embedding(
config.max_position_embeddings, config.hidden_size
)
else:
self.embed_positions = None
self.drop = nn.Dropout(config.dropout)
self.blocks = nn.ModuleList([
DomainTransformerBlock(config)
for _ in range(config.num_hidden_layers)
])
self.ln_f = nn.LayerNorm(config.hidden_size)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Weight tying (standard for small models)
self.lm_head.weight = self.embed_tokens.weight
self.post_init()
    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        x = self.embed_tokens(input_ids)
        if self.embed_positions is not None:
            positions = torch.arange(input_ids.size(1), device=input_ids.device)
            x = x + self.embed_positions(positions)
        x = self.drop(x)
        # HF-style attention_mask (1 = real token, 0 = padding) -> bool key_padding_mask
        key_padding_mask = (attention_mask == 0) if attention_mask is not None else None
        for block in self.blocks:
            x = block(x, key_padding_mask=key_padding_mask)
x = self.ln_f(x)
logits = self.lm_head(x)
loss = None
if labels is not None:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = nn.functional.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
ignore_index=-100,
)
        # Expose the final hidden states so downstream heads (e.g., joint fusion) can consume them
        return CausalLMOutputWithPast(loss=loss, logits=logits, hidden_states=(x,))
# Register with AutoClass for Hub compatibility
DomainTransformerConfig.register_for_auto_class()
DomainTransformerForCausalLM.register_for_auto_class("AutoModelForCausalLM")
```
#### Step 5: PLR Numerical Embeddings (for Joint Fusion)
Port from Yandex's implementation:
```python
# src/models/plr_embeddings.py
import torch
import torch.nn as nn
import math
class PeriodicLinearReLU(nn.Module):
"""PLR numerical embeddings (Gorishniy et al. 2022).
    Maps scalar x → [sin(2π·w·x + b), cos(2π·w·x + b)] → Linear → ReLU
Frequencies w and phases b are LEARNED parameters.
"""
def __init__(self, n_features: int, n_frequencies: int = 64, embedding_dim: int = 64):
super().__init__()
self.n_features = n_features
self.n_frequencies = n_frequencies
# Learnable frequencies and phases (per feature)
self.frequencies = nn.Parameter(
torch.randn(n_features, n_frequencies) * 0.01
)
self.phases = nn.Parameter(
torch.zeros(n_features, n_frequencies)
)
# Linear projection: 2*n_frequencies β†’ embedding_dim
self.linear = nn.Linear(2 * n_frequencies, embedding_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
            x: (batch, n_features), raw scalar feature values
Returns:
(batch, n_features, embedding_dim)
"""
# x: (B, F) β†’ (B, F, 1)
x = x.unsqueeze(-1)
# Periodic encoding: (B, F, n_freq)
angles = 2 * math.pi * self.frequencies.unsqueeze(0) * x + self.phases.unsqueeze(0)
periodic = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1) # (B, F, 2*n_freq)
# Linear + ReLU: (B, F, embedding_dim)
return torch.relu(self.linear(periodic))
```
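A quick shape check for the module above:

```python
# Shape check for PeriodicLinearReLU: raw scalars in, per-feature embeddings out.
import torch

plr = PeriodicLinearReLU(n_features=6, n_frequencies=64, embedding_dim=64)
tabular = torch.randn(32, 6)   # (batch, n_features) raw scalar values
print(plr(tabular).shape)      # torch.Size([32, 6, 64])
```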
### Phase 2C: Pre-training (Weeks 5–7)
#### Step 6: Data Pipeline
```python
# src/training/data_pipeline.py
from torch.utils.data import Dataset
from typing import List
class DomainSequenceDataset(Dataset):
"""Converts user event sequences into token sequences for CLM training."""
def __init__(self, user_sequences, tokenizer_builder, hf_tokenizer, max_length=2048):
self.user_sequences = user_sequences
self.tokenizer_builder = tokenizer_builder
self.hf_tokenizer = hf_tokenizer
self.max_length = max_length
def __len__(self):
return len(self.user_sequences)
def __getitem__(self, idx):
events = self.user_sequences[idx]
# Tokenize each event into token strings
token_strings = []
for event in events:
event_tokens = self.tokenizer_builder.tokenize_event(event)
token_strings.extend(event_tokens)
# Convert token strings to IDs via HF tokenizer
encoding = self.hf_tokenizer(
" ".join(token_strings),
max_length=self.max_length,
truncation=True,
padding="max_length",
return_tensors="pt",
)
input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100  # ignore padding positions in the loss
        return {
            "input_ids": input_ids,
            "labels": labels,  # CLM: the model shifts labels by one position internally
            "attention_mask": attention_mask,
        }
```
#### Step 7: Pre-training with HuggingFace Trainer
```python
# src/training/pretrain.py
from transformers import Trainer, TrainingArguments
def pretrain_domain_model(
model,
train_dataset,
eval_dataset=None,
output_dir="./checkpoints",
hub_model_id="org/domain-model-24m",
num_epochs=3,
batch_size=64,
learning_rate=3e-4,
context_length=2048,
):
training_args = TrainingArguments(
output_dir=output_dir,
num_train_epochs=num_epochs,
per_device_train_batch_size=batch_size,
gradient_accumulation_steps=4,
learning_rate=learning_rate,
lr_scheduler_type="cosine",
warmup_ratio=0.05,
weight_decay=0.01,
logging_strategy="steps",
logging_steps=100,
logging_first_step=True,
disable_tqdm=True, # plain text logging for cloud
eval_strategy="steps" if eval_dataset else "no",
eval_steps=500,
save_strategy="steps",
save_steps=1000,
save_total_limit=3,
push_to_hub=True,
hub_model_id=hub_model_id,
bf16=True,
dataloader_num_workers=4,
report_to="trackio",
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
trainer.train()
trainer.push_to_hub()
```
### Phase 2D: Joint Fusion Fine-tuning (Weeks 7–9)
#### Step 8: nuFormer-style Joint Fusion
```python
# src/models/joint_fusion.py
import torch
import torch.nn as nn
class DCNv2CrossLayer(nn.Module):
"""Single cross layer from DCN V2 (Wang et al. 2021)."""
def __init__(self, dim):
super().__init__()
self.weight = nn.Linear(dim, dim, bias=True)
def forward(self, x0, x):
return x0 * self.weight(x) + x # element-wise product with anchor
class JointFusionModel(nn.Module):
    """nuFormer-style: Transaction Transformer + DCNv2(PLR) → Joint Prediction.
    Architecture:
        Transaction Sequence → Pre-trained DomainTransformer → user_embedding
        Tabular Features → PLR → DCNv2 → tab_embedding
        Concatenate → MLP Head → prediction
"""
def __init__(self, transformer_model, n_tabular_features, n_classes=1,
plr_frequencies=64, dcn_layers=3, hidden_dim=256):
super().__init__()
self.transformer = transformer_model # pre-trained, unfrozen for fine-tuning
transformer_dim = transformer_model.config.hidden_size
# Tabular branch: PLR β†’ DCNv2
self.plr = PeriodicLinearReLU(n_tabular_features, plr_frequencies, hidden_dim)
tab_input_dim = n_tabular_features * hidden_dim
self.dcn_layers = nn.ModuleList([
DCNv2CrossLayer(tab_input_dim) for _ in range(dcn_layers)
])
self.dcn_deep = nn.Sequential(
nn.Linear(tab_input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
)
# Joint head
self.head = nn.Sequential(
nn.Linear(transformer_dim + hidden_dim, hidden_dim),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim, n_classes),
)
def forward(self, input_ids, attention_mask, tabular_features, labels=None):
        # Transaction branch: last-token hidden state as the user embedding
        # (uses the final hidden states exposed by DomainTransformerForCausalLM rather than the
        #  vocab-sized logits; assumes the last position is a real token, i.e. unpadded or left-padded)
        transformer_output = self.transformer(input_ids, attention_mask=attention_mask)
        user_embedding = transformer_output.hidden_states[-1][:, -1, :]  # (B, hidden_size)
# Tabular branch: PLR β†’ flatten β†’ DCNv2
tab_embedded = self.plr(tabular_features) # (B, F, D)
tab_flat = tab_embedded.view(tab_embedded.size(0), -1) # (B, F*D)
x0 = tab_flat
x = tab_flat
for cross_layer in self.dcn_layers:
x = cross_layer(x0, x)
tab_output = self.dcn_deep(x) # (B, hidden_dim)
# Joint fusion
combined = torch.cat([user_embedding, tab_output], dim=-1)
logits = self.head(combined)
loss = None
if labels is not None:
loss = nn.functional.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
return {"loss": loss, "logits": logits}
```
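A short usage sketch for the fusion model; here `model` is assumed to be the pre-trained `DomainTransformerForCausalLM` from Phase 2C, and the batch shapes are illustrative:

```python
# Hedged usage sketch for JointFusionModel; `model` is a pre-trained DomainTransformerForCausalLM.
import torch

fusion = JointFusionModel(transformer_model=model, n_tabular_features=12, n_classes=1)
batch = {
    "input_ids": torch.randint(0, model.config.vocab_size, (8, 512)),
    "attention_mask": torch.ones(8, 512, dtype=torch.long),
    "tabular_features": torch.randn(8, 12),
    "labels": torch.randint(0, 2, (8,)),
}
out = fusion(**batch)
print(out["loss"], out["logits"].shape)   # scalar BCE-with-logits loss, torch.Size([8, 1])
```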
### Phase 3: Domain Demos (Weeks 9–12)
| Week | Deliverable | Hardware |
|------|------------|----------|
| 9–10 | **Finance demo:** Transaction tokenizer + 24M model pre-trained on synthetic/public financial data + fraud detection fine-tuning | a10g-large |
| 10–11 | **E-commerce demo:** Event tokenizer + 24M model pre-trained on Amazon review sequences + next-purchase prediction | a10g-large |
| 11–12 | **Evaluation & benchmarking:** Compare domain tokenizer vs. text serialization vs. LightGBM baselines on each domain | a10g-large |
### Phase 4: Scale & Optimize (Weeks 12+)
| Task | Details |
|------|---------|
| Scale to 330M params | Increase `hidden_size` to 1024, train on a100-large |
| `torch.compile()` | Apply to attention and MLP blocks for 10–30% speedup |
| ONNX export | `torch.onnx.export()` for production serving |
| Context window experiments | Ablate 512/1024/2048/4096 context lengths |
| Data source ablation | Test impact of different event types (Nubank found adding low-signal sources hurts) |
| ActionPiece vocabulary | Implement BPE-like cross-field merging on top of per-field tokens |
---
## Appendix A: Framework Usage Across Reference Papers
| Paper | ArXiv | Framework | Verbatim Evidence |
|-------|-------|-----------|-------------------|
| nuFormer (Nubank) | 2507.23267 | PyTorch + HF (inferred) | All dependencies are PyTorch-based |
| TIGER (Google) | 2305.05065 | JAX + T5X | "We use the open-sourced T5X framework" |
| ActionPiece (DeepMind) | 2502.13581 | PyTorch + HF | "HuggingFace Transformers and PyTorch" (Appendix H) |
| RecFormer | 2305.13731 | PyTorch + HF Longformer | "Longformer implemented by Huggingface" (§3.1.4) |
| Banking TF | 2410.08243 | PyTorch | "Pytorch backend is used" (Appendix B) |
| PLR Embeddings (Yandex) | 2203.05556 | PyTorch | Repository: pure PyTorch + scikit-learn |
| KL3M Tokenizers | 2503.17247 | HF `tokenizers` + PyTorch | "tokenizers" BPE for HF compatibility |
---
## Appendix B: Head-to-Head Comparison Matrix
| Criterion | PyTorch + HF | JAX/Flax NNX | Keras 3 + JAX |
|-----------|-------------|-------------|---------------|
| Custom domain tokenizer | ✅ `PreTrainedTokenizerFast` | ❌ Build from scratch | ❌ Build from scratch |
| HF Trainer integration | ✅ Native | ❌ Not compatible | ❌ Not compatible |
| Hub push/pull | ✅ `push_to_hub()` | ⚠️ Works, fragmented | ⚠️ Limited |
| PEFT/LoRA | ✅ Drop-in | ❌ Manual | ❌ Manual |
| ONNX export | ✅ One-line | ❌ Multi-hop | ⚠️ TF backend required |
| TGI/vLLM serving | ✅ First-class | ❌ Not supported | ❌ Not supported |
| TPU training | ⚠️ PyTorch/XLA (overhead) | ✅ Native | ✅ Native |
| JIT compilation | ✅ `torch.compile()` | ✅ `@nnx.jit` | ✅ XLA via JAX |
| Dynamic shapes (NLP) | ✅ Native | ❌ Recompiles | ❌ Recompiles |
| Debugging | ✅ Eager mode, std debugger | ⚠️ Challenging | ⚠️ Challenging |
| Reference implementations | 5/6 papers | 1/6 papers | 0/6 papers |
| Community/hiring pool | 🟢 Large | 🟡 Small | 🟡 Small |
---
## Appendix C: Key Code Patterns
### Adding Domain Special Tokens to an Existing Tokenizer
```python
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2") # or any base tokenizer
# Add all domain special tokens
special_tokens = {
"additional_special_tokens": [
# Amount tokens (Nubank-style)
"[AMT_POS]", "[AMT_NEG]",
*[f"[AMT_{i:02d}]" for i in range(21)],
# Calendar tokens
*[f"[MON_{i:02d}]" for i in range(1, 13)],
*[f"[DOW_{i}]" for i in range(7)],
*[f"[DOM_{i:02d}]" for i in range(1, 32)],
*[f"[HOUR_{i:02d}]" for i in range(24)],
]
}
num_added = tokenizer.add_special_tokens(special_tokens)
print(f"Added {num_added} domain tokens. Vocab size: {len(tokenizer)}")
# CRITICAL: resize model embeddings
model.resize_token_embeddings(len(tokenizer))
```
### Registering a Custom Model for Hub Deployment
```python
# In your package's __init__.py or a registration script:
from transformers import AutoConfig, AutoModelForCausalLM
from .configuration_domain_transformer import DomainTransformerConfig
from .modeling_domain_transformer import DomainTransformerForCausalLM
# Register so AutoClass can find your model
AutoConfig.register("domain_transformer", DomainTransformerConfig)
AutoModelForCausalLM.register(DomainTransformerConfig, DomainTransformerForCausalLM)
# Enable push_to_hub with custom code
DomainTransformerConfig.register_for_auto_class()
DomainTransformerForCausalLM.register_for_auto_class("AutoModelForCausalLM")
# Push: uploads configuration.py, modeling.py, config.json, model.safetensors
model.push_to_hub("org/domain-transformer-24m")
# Load anywhere:
model = AutoModelForCausalLM.from_pretrained("org/domain-transformer-24m", trust_remote_code=True)
```
---
*This ADR is a living document and will be updated as implementation progresses.*