# ADR-001: Implementation Framework for domainTokenizer > **Status:** Accepted > **Date:** April 29, 2026 > **Decision:** PyTorch + HuggingFace Transformers as primary framework, with JAX/Flax NNX as future scaling path > **Deciders:** domainTokenizer core team --- ## Table of Contents 1. [Context](#1-context) 2. [Goal](#2-goal) 3. [Options Evaluated](#3-options-evaluated) 4. [Decision](#4-decision) 5. [Trade-offs and Justification](#5-trade-offs-and-justification) 6. [Consequences](#6-consequences) 7. [Implementation Roadmap](#7-implementation-roadmap) 8. [Appendix A: Framework Usage Across Reference Papers](#appendix-a-framework-usage-across-reference-papers) 9. [Appendix B: Head-to-Head Comparison Matrix](#appendix-b-head-to-head-comparison-matrix) 10. [Appendix C: Key Code Patterns](#appendix-c-key-code-patterns) --- ## 1. Context ### What We're Building domainTokenizer is a library for building **small Transformer models (24M–330M parameters)** that process **domain-specific tokens** — financial transactions, e-commerce events, healthcare records — instead of natural language text. The architecture follows the validated pattern from Nubank's nuFormer ([arXiv: 2507.23267](https://arxiv.org/abs/2507.23267)): ``` Domain Events → Custom Tokenizer → GPT-style Transformer → Foundation Model → Downstream Tasks ``` ### The Implementation Question A video from Google for Developers presents the **Keras 3 + JAX/Flax NNX integration** as a potential framework, citing: - **Explicit state management** (Flax NNX) — useful for tracking sequential transaction state - **Custom training loops** — Keras structure + JAX/Optax for domain-specific gradient control - **JIT compilation** (`@nnx.jit`) — high-performance processing of millions of transactions - **Paradigm mixing** — Keras layers for standard components + NNX for custom sequential encoders The question: **Is Keras + JAX/Flax NNX the right framework for domainTokenizer, or is there a better choice?** ### Constraints 1. **Custom tokenizer required:** We need a tokenizer that maps structured fields (amounts, dates, categories) to special tokens — not a standard text tokenizer 2. **Small models:** 24M–330M parameters, not 70B+ — framework overhead matters less than developer velocity 3. **Production deployment:** Models must be servable with low latency for real-time applications (fraud detection, recommendations) 4. **GPU hardware:** Development on A100/A10G GPUs, not TPUs (standard cloud environment) 5. **Team context:** ML engineers familiar with Python, PyTorch, and the HuggingFace ecosystem 6. **Iteration speed:** Need to prototype quickly across multiple domains (finance, e-commerce, healthcare) --- ## 2. Goal Choose an implementation framework that: 1. **Minimizes time from research to working prototype** — weeks, not months 2. **Supports custom domain tokenizers** as first-class citizens 3. **Integrates with the HuggingFace Hub** for model sharing, versioning, and community 4. **Enables production deployment** via standard serving infrastructure (ONNX, TGI, vLLM, etc.) 5. **Scales to 330M parameters** on 4–8 GPUs without heroic engineering 6. **Does not preclude future migration** to JAX/TPU if we need to scale beyond 1B parameters --- ## 3. Options Evaluated ### Option A: PyTorch + HuggingFace Transformers The dominant ecosystem for custom NLP/sequential models. Provides `PreTrainedModel`, `PreTrainedTokenizerFast`, `Trainer`, `push_to_hub`, ONNX export, and integration with TRL, PEFT, Accelerate, DeepSpeed. 
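As a rough sketch of what the Option A workflow looks like end to end (everything below is a placeholder: the Hub repos do not exist yet, and `train_dataset` stands in for the Dataset built in Step 6 of the roadmap):

```python
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments

# Placeholder Hub repos; Section 7 shows how the real tokenizer and model are built
tokenizer = AutoTokenizer.from_pretrained("org/domain-tokenizer")
model = AutoModelForCausalLM.from_pretrained("org/domain-transformer-24m", trust_remote_code=True)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="./ckpt", push_to_hub=True, hub_model_id="org/domain-transformer-24m"),
    train_dataset=train_dataset,  # any torch Dataset yielding input_ids / attention_mask / labels
)
trainer.train()
trainer.push_to_hub()  # weights, config, and model card land on the Hub
```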
### Option B: Keras 3 + JAX Backend + Flax NNX Google's multi-backend framework. Keras provides high-level APIs; JAX provides XLA compilation and functional transforms; Flax NNX provides PyTorch-like stateful modules on top of JAX. ### Option C: Pure JAX + Flax NNX + Optax Skip Keras entirely. Use Flax NNX for model definition, Optax for optimization, Orbax for checkpointing, and Grain/tf.data for data loading. Google's MaxText framework follows this pattern. ### Option D: PyTorch + Custom (no HuggingFace) Use PyTorch directly without the HuggingFace abstraction layer. Full control but no ecosystem integration. --- ## 4. Decision ### Primary: PyTorch + HuggingFace Transformers (Option A) ### Future scaling path: JAX/Flax NNX (Option C) — if and when we need TPU training at >1B parameters --- ## 5. Trade-offs and Justification ### 5.1 What the Reference Papers Actually Use We audited the frameworks used by every paper in the domainTokenizer research corpus. The result is overwhelming: | Paper | Framework | Confidence | |-------|-----------|------------| | **nuFormer** (Nubank) | PyTorch + HF Transformers (inferred) | ~90% | | **TIGER** (Google) | JAX + T5X (official); PyTorch (community reimpl) | 100% | | **ActionPiece** (Google DeepMind) | **PyTorch + HF Transformers** (stated verbatim in paper) | 100% | | **RecFormer** (UCSD/Amazon) | **PyTorch + HF Transformers (Longformer)** (stated verbatim) | 100% | | **Banking Transaction Flow** | **PyTorch** (stated verbatim in appendix) | 100% | | **PLR Embeddings** (Yandex) | **PyTorch** + scikit-learn + Optuna | 100% | **5 of 6 papers use PyTorch.** The sole JAX user (TIGER) was a Google-internal project using T5X, and even its most popular community reimplementation (781⭐) is in PyTorch. Even **Google DeepMind's own ActionPiece** — the paper most relevant to our domain tokenization approach — uses PyTorch + HuggingFace. This is the strongest signal possible. ### 5.2 Custom Tokenizer Story This is the **decisive factor**. domainTokenizer's core innovation is the tokenizer itself. The framework must provide first-class support for custom token vocabularies. **PyTorch + HuggingFace:** - Train custom BPE tokenizer via `tokenizers` library (Rust-backed, fast) - Wrap in `PreTrainedTokenizerFast` → full Trainer compatibility - Add domain special tokens via `add_special_tokens()` → auto-resize embeddings - Push tokenizer to Hub: `tokenizer.push_to_hub("org/my-tokenizer")` - Load anywhere: `AutoTokenizer.from_pretrained("org/my-tokenizer")` - **KL3M** ([arXiv: 2503.17247](https://arxiv.org/abs/2503.17247)) — the gold standard for financial domain tokenizers — is built entirely on this stack **Keras + JAX/Flax NNX:** - No equivalent to `PreTrainedTokenizerFast` - No Hub-integrated tokenizer format - Must build custom tokenizer from scratch with no ecosystem support - No standard serialization/deserialization for domain vocabularies **Verdict:** PyTorch/HF has a **complete, production-tested** custom tokenizer pipeline. JAX/Keras has **nothing** — you'd build everything from scratch. 
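A minimal sketch of the pipeline those bullets describe (the corpus, token names, and Hub repo below are illustrative placeholders):

```python
from tokenizers import Tokenizer, models, trainers, pre_tokenizers
from transformers import PreTrainedTokenizerFast

# 1. Train a small BPE vocabulary on free-text fields (toy corpus for illustration)
bpe = Tokenizer(models.BPE(unk_token="[UNK]"))
bpe.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
bpe.train_from_iterator(
    ["coffee shop downtown", "wire transfer fee"],
    trainer=trainers.BpeTrainer(vocab_size=1000, special_tokens=["[PAD]", "[UNK]", "[BOS]", "[EOS]"]),
)

# 2. Wrap it so Trainer, push_to_hub, and AutoTokenizer all work
tok = PreTrainedTokenizerFast(tokenizer_object=bpe, unk_token="[UNK]", pad_token="[PAD]",
                              bos_token="[BOS]", eos_token="[EOS]")

# 3. Register domain special tokens (Nubank-style amount bins, as an example)
tok.add_special_tokens({"additional_special_tokens": [f"[AMT_{i:02d}]" for i in range(21)]})

# 4. Share and reload anywhere (requires Hub credentials)
# tok.push_to_hub("org/my-domain-tokenizer")
# AutoTokenizer.from_pretrained("org/my-domain-tokenizer")
```

After resizing the model's embeddings (`model.resize_token_embeddings(len(tok))`, see Appendix C), the new tokens plug directly into the Trainer workflow.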
### 5.3 Production Deployment | Path | PyTorch | JAX/Keras | |------|---------|-----------| | ONNX export | `torch.onnx.export()` — one line | Requires TF backend intermediate or experimental `jax.export` | | TensorRT | ONNX → TRT (standard) | Multi-hop, fragile | | TGI (HuggingFace inference) | First-class | Not supported | | vLLM | First-class | Not supported | | Triton Inference Server | Direct ONNX/TorchScript | Via ONNX (workaround) | | BentoML | Supported | Supported | | Model Hub sharing | `push_to_hub()` → `from_pretrained()` | Works but fragmented (`.msgpack` weights, no Trainer compat) | **Verdict:** PyTorch has **direct, tested paths** to every major serving framework. JAX requires **multiple intermediate conversions**, each introducing failure points. ### 5.4 Training Speed At our scale (24M–330M parameters on 4–8 A100s): | Scenario | PyTorch | JAX | |----------|---------|-----| | Steady-state training throughput | **Comparable** (`torch.compile`) | **Comparable** (XLA JIT) | | Variable-length sequences | **Native** — dynamic shapes | **Problematic** — recompiles on new shapes; must pad to buckets | | Multi-GPU (FSDP) | `accelerate` + FSDP2 — mature | `pmap`/`shard_map` — works but harder to configure | | First-run compilation | Instant (eager mode) | 5–20s JIT compilation overhead | | Debugging | Standard Python debugger | `print` debugging; cryptic XLA errors | **Verdict:** At 330M parameters, training speed is a **wash**. JAX's advantages (XLA kernel fusion, TPU native) only matter at 10B+ parameters on 256+ accelerators. At our scale, **developer velocity dominates throughput**. ### 5.5 The JAX Advantage: When It Would Win JAX/Flax NNX would be the right choice **if**: 1. **Training exclusively on Google TPUs** — JAX is the native TPU compiler; PyTorch/XLA is a port with overhead 2. **Models >1B parameters** — XLA's whole-program optimization shines at scale 3. **Fixed-shape workloads** — images, fixed-length token sequences (no variable-length padding issues) 4. **Need functional transforms** — `vmap` (per-sample gradients), `pmap` (data parallelism), `grad` (higher-order derivatives) 5. **Google Cloud infrastructure** — Vertex AI, TPU VMs, GCS integration For domainTokenizer's current scope (24M–330M, GPU, variable-length sequences, fast iteration), **none of these conditions apply**. ### 5.6 The Keras + JAX Mixing Argument The Google for Developers video argues for mixing Keras layers (high-level) with NNX modules (custom, high-performance). In theory, this lets you: - Use Keras for standard Transformer layers - Use NNX for custom sequential transaction encoders - Get JIT compilation on the NNX parts **In practice, this creates problems:** 1. **Two mental models:** Keras (layer-oriented, `fit/compile`) vs. NNX (functional, explicit state) — context switching slows development 2. **Limited interop documentation:** Keras ↔ NNX examples are thin; edge cases are poorly documented 3. **No HF ecosystem integration:** You lose Trainer, push_to_hub, PEFT, TRL, Accelerate — the entire ecosystem Nubank and ActionPiece rely on 4. **Debugging complexity:** Errors in the Keras↔NNX boundary are hard to diagnose **Better approach with PyTorch:** Use `torch.compile()` on performance-critical modules to get JIT compilation benefits without leaving the PyTorch ecosystem. Write custom `nn.Module` subclasses for domain-specific components. This gives you the same "standard parts + custom parts" architecture without framework mixing. --- ## 6. Consequences ### What We Gain 1. 
**Immediate access to the entire HuggingFace ecosystem:** Trainer, Accelerate, PEFT (LoRA), TRL, Evaluate, push_to_hub, from_pretrained, ONNX export, TGI serving
2. **Copy-paste from reference implementations:** ActionPiece, RecFormer, Banking TF, and PLR embeddings are all PyTorch — we can directly reuse their code
3. **KL3M tokenizer as starting point:** The best financial domain tokenizer already exists in PyTorch/HF format at `alea-institute/kl3m-004-128k-cased`
4. **Standard production deployment:** ONNX → TensorRT → Triton, or direct TGI/vLLM serving
5. **Community and hiring:** PyTorch is the dominant ML framework; finding contributors and documentation is easy
6. **`torch.compile()` for performance:** When we need JIT compilation on hot paths, `torch.compile()` provides 10–30% speedups without leaving the ecosystem

### What We Accept

1. **No native TPU support:** If we later need to train on Google TPUs, we'll need PyTorch/XLA (slower than native JAX) or migrate the model code
2. **Weaker functional transforms:** `vmap`-style per-sample gradients live in `torch.func` (the successor to the experimental `functorch`) and remain less mature than JAX's native `vmap`/`grad`. If we need advanced gradient manipulation for meta-learning or Nested Learning (HOPE-style), JAX would be better
3. **Potential future migration cost:** If we scale beyond 1B parameters and move to TPUs, we'll need to rewrite model code in Flax NNX. This is mitigated by keeping model definitions clean and modular

### Migration Strategy (If Needed Later)

If domainTokenizer grows to >1B parameters and we need TPU training:

1. **Tokenizer layer stays in Python/HF:** The tokenizer is framework-agnostic — it produces integer sequences regardless of whether the model is PyTorch or JAX
2. **Model architecture translates 1:1:** The PyTorch `nn.Module` → Flax NNX `nnx.Module` mapping is straightforward for standard Transformer components
3. **Training loop changes:** PyTorch Trainer → custom Flax NNX training loop with Optax
4. **Reference:** Google's MaxText (`github.com/google/maxtext`) provides production-grade JAX Transformer patterns we can follow

**Estimated migration effort:** 2–4 weeks for a clean, well-separated codebase.

---

## 7.
Implementation Roadmap ### Phase 2A: Core Tokenizer Library (Weeks 1–3) #### Step 1: Domain Schema Definition Create a declarative schema format that describes the fields in a domain's event data: ```python # src/tokenizers/schema.py from dataclasses import dataclass, field from enum import Enum from typing import List, Optional class FieldType(Enum): NUMERICAL_CONTINUOUS = "numerical_continuous" # prices, amounts → magnitude bins NUMERICAL_DISCRETE = "numerical_discrete" # quantities → small fixed vocab CATEGORICAL_FIXED = "categorical_fixed" # categories, days of week → direct mapping CATEGORICAL_ENTITY = "categorical_entity" # products, merchants → Semantic IDs (RQ-VAE) TEMPORAL = "temporal" # timestamps → calendar decomposition TEXT = "text" # descriptions → BPE subwords SIGN = "sign" # credit/debit → 2 tokens @dataclass class FieldSpec: name: str field_type: FieldType vocab_size: Optional[int] = None # for fixed categorical n_bins: int = 21 # for numerical (Nubank uses 21) calendar_fields: List[str] = field( # for temporal default_factory=lambda: ["month", "dow", "dom", "hour"] ) @dataclass class DomainSchema: name: str # e.g., "ecommerce", "finance" fields: List[FieldSpec] # ordered list of fields per event @property def special_token_count(self) -> int: """Total domain-specific special tokens needed.""" count = 0 for f in self.fields: if f.field_type == FieldType.SIGN: count += 2 elif f.field_type == FieldType.NUMERICAL_CONTINUOUS: count += f.n_bins elif f.field_type == FieldType.CATEGORICAL_FIXED: count += f.vocab_size elif f.field_type == FieldType.TEMPORAL: count += sum({ "month": 12, "dow": 7, "dom": 31, "hour": 24, "quarter": 4, "year": 10 }.get(cf, 0) for cf in f.calendar_fields) return count # Example: Nubank-style financial schema FINANCE_SCHEMA = DomainSchema( name="finance", fields=[ FieldSpec("amount_sign", FieldType.SIGN), FieldSpec("amount", FieldType.NUMERICAL_CONTINUOUS, n_bins=21), FieldSpec("timestamp", FieldType.TEMPORAL, calendar_fields=["month", "dow", "dom", "hour"]), FieldSpec("description", FieldType.TEXT), ] ) # Example: E-commerce schema ECOMMERCE_SCHEMA = DomainSchema( name="ecommerce", fields=[ FieldSpec("event_type", FieldType.CATEGORICAL_FIXED, vocab_size=5), FieldSpec("price", FieldType.NUMERICAL_CONTINUOUS, n_bins=21), FieldSpec("quantity", FieldType.NUMERICAL_DISCRETE, vocab_size=11), FieldSpec("category_l1", FieldType.CATEGORICAL_FIXED, vocab_size=30), FieldSpec("category_l2", FieldType.CATEGORICAL_FIXED, vocab_size=200), FieldSpec("timestamp", FieldType.TEMPORAL, calendar_fields=["month", "dow", "dom", "hour"]), FieldSpec("product_title", FieldType.TEXT), ] ) ``` #### Step 2: Per-Field Tokenizers Implement each field type tokenizer as a standalone module: ```python # src/tokenizers/field_tokenizers.py import numpy as np from typing import List class SignTokenizer: """Tokenizes sign of a numerical value (credit/debit, inflow/outflow).""" def __init__(self, prefix: str = "SIGN"): self.tokens = [f"[{prefix}_POS]", f"[{prefix}_NEG]"] def __call__(self, value: float) -> str: return self.tokens[0] if value >= 0 else self.tokens[1] @property def vocab(self) -> List[str]: return self.tokens class MagnitudeBucketTokenizer: """Quantizes continuous values into bins (Nubank-style). Uses quantile-based binning on the training distribution. Follows the Relative Magnitude Tokenization principle from TP-BERTa. 
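    Illustrative usage (the exact bin depends on the fitted quantiles; the sign
    is handled separately by SignTokenizer):

        tok = MagnitudeBucketTokenizer(n_bins=21, prefix="AMT").fit(train_amounts)
        tok(-37.50)   # -> e.g. "[AMT_07]"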
""" def __init__(self, n_bins: int = 21, prefix: str = "AMT"): self.n_bins = n_bins self.prefix = prefix self.bin_edges = None # fitted from data def fit(self, values: np.ndarray): """Compute bin edges from training data using quantiles.""" # Use absolute values for magnitude binning abs_vals = np.abs(values[~np.isnan(values)]) quantiles = np.linspace(0, 100, self.n_bins + 1) self.bin_edges = np.percentile(abs_vals, quantiles) return self def __call__(self, value: float) -> str: if self.bin_edges is None: raise ValueError("Tokenizer not fitted. Call .fit() first.") bin_idx = np.searchsorted(self.bin_edges[1:-1], abs(value)) return f"[{self.prefix}_{bin_idx:02d}]" @property def vocab(self) -> List[str]: return [f"[{self.prefix}_{i:02d}]" for i in range(self.n_bins)] class CalendarTokenizer: """Decomposes timestamps into calendar components (Nubank-style).""" FIELD_VOCABS = { "month": ([f"[MON_{i:02d}]" for i in range(1, 13)], lambda dt: dt.month - 1), "dow": ([f"[DOW_{i}]" for i in range(7)], lambda dt: dt.weekday()), "dom": ([f"[DOM_{i:02d}]" for i in range(1, 32)], lambda dt: dt.day - 1), "hour": ([f"[HOUR_{i:02d}]" for i in range(24)], lambda dt: dt.hour), } def __init__(self, fields: List[str] = None): self.fields = fields or ["month", "dow", "dom", "hour"] def __call__(self, timestamp) -> List[str]: tokens = [] for field_name in self.fields: vocab, extractor = self.FIELD_VOCABS[field_name] idx = extractor(timestamp) tokens.append(vocab[idx]) return tokens @property def vocab(self) -> List[str]: all_tokens = [] for field_name in self.fields: all_tokens.extend(self.FIELD_VOCABS[field_name][0]) return all_tokens class CategoricalTokenizer: """Maps categorical values to fixed vocabulary tokens.""" def __init__(self, categories: List[str], prefix: str = "CAT"): self.prefix = prefix self.token_map = {cat: f"[{prefix}_{i:03d}]" for i, cat in enumerate(categories)} self.unk_token = f"[{prefix}_UNK]" def __call__(self, value: str) -> str: return self.token_map.get(value, self.unk_token) @property def vocab(self) -> List[str]: return list(self.token_map.values()) + [self.unk_token] ``` #### Step 3: Composite Domain Tokenizer Assemble per-field tokenizers into a complete domain tokenizer, wrapped as `PreTrainedTokenizerFast`: ```python # src/tokenizers/domain_tokenizer.py from tokenizers import Tokenizer, models, trainers, pre_tokenizers from transformers import PreTrainedTokenizerFast class DomainTokenizerBuilder: """Builds a HuggingFace-compatible tokenizer from a DomainSchema.""" def __init__(self, schema: DomainSchema): self.schema = schema self.field_tokenizers = {} # name → field tokenizer instance self._build_field_tokenizers() def _build_field_tokenizers(self): for field_spec in self.schema.fields: if field_spec.field_type == FieldType.SIGN: self.field_tokenizers[field_spec.name] = SignTokenizer(field_spec.name.upper()) elif field_spec.field_type == FieldType.NUMERICAL_CONTINUOUS: self.field_tokenizers[field_spec.name] = MagnitudeBucketTokenizer( n_bins=field_spec.n_bins, prefix=field_spec.name.upper() ) elif field_spec.field_type == FieldType.TEMPORAL: self.field_tokenizers[field_spec.name] = CalendarTokenizer(field_spec.calendar_fields) # ... 
other types

    def fit(self, data):
        """Fit data-dependent tokenizers (magnitude bins, etc.)."""
        for field_spec in self.schema.fields:
            if field_spec.field_type == FieldType.NUMERICAL_CONTINUOUS:
                values = [getattr(event, field_spec.name) for event in data]
                self.field_tokenizers[field_spec.name].fit(np.array(values))
        return self

    def build_hf_tokenizer(self, text_corpus=None, bpe_vocab_size=8000) -> PreTrainedTokenizerFast:
        """Build a complete HuggingFace tokenizer.

        1. Collect all domain special tokens
        2. Train BPE on text fields (if any)
        3. Merge into a single PreTrainedTokenizerFast
        """
        # Collect all special tokens from field tokenizers
        standard_tokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "[BOS]", "[EOS]"]
        domain_tokens = []
        for name, tok in self.field_tokenizers.items():
            if hasattr(tok, 'vocab'):
                domain_tokens.extend(tok.vocab)
        all_special_tokens = standard_tokens + domain_tokens

        # Train BPE on text fields
        bpe_tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
        bpe_tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
        trainer = trainers.BpeTrainer(
            vocab_size=bpe_vocab_size,
            special_tokens=all_special_tokens,
            min_frequency=2,
        )
        if text_corpus:
            bpe_tokenizer.train_from_iterator(text_corpus, trainer=trainer)
        # Ensure the special tokens exist in the vocab even when there is no text corpus
        bpe_tokenizer.add_special_tokens(all_special_tokens)

        # Wrap as HuggingFace tokenizer
        hf_tokenizer = PreTrainedTokenizerFast(
            tokenizer_object=bpe_tokenizer,
            bos_token="[BOS]",
            eos_token="[EOS]",
            pad_token="[PAD]",
            unk_token="[UNK]",
            mask_token="[MASK]",
        )
        # Register domain tokens as special tokens so they are never split or merged
        hf_tokenizer.add_special_tokens({"additional_special_tokens": domain_tokens})
        return hf_tokenizer

    def tokenize_event(self, event) -> List[str]:
        """Convert a single domain event into a list of token strings."""
        tokens = []
        for field_spec in self.schema.fields:
            value = getattr(event, field_spec.name, None)
            if value is None:
                tokens.append("[UNK]")
                continue
            tok = self.field_tokenizers[field_spec.name]
            result = tok(value)
            if isinstance(result, list):
                tokens.extend(result)
            else:
                tokens.append(result)
        return tokens
```

### Phase 2B: Model Architecture (Weeks 3–5)

#### Step 4: GPT-style Causal Transformer (NoPE)

Implement as a HuggingFace-compatible `PreTrainedModel`:

```python
# src/models/configuration_domain_transformer.py
from transformers import PretrainedConfig

class DomainTransformerConfig(PretrainedConfig):
    model_type = "domain_transformer"

    def __init__(
        self,
        vocab_size: int = 32000,
        hidden_size: int = 256,          # 256 = 24M params, 1024 = 330M (Nubank sizes)
        num_hidden_layers: int = 24,     # Nubank uses 24 for both sizes
        num_attention_heads: int = 16,   # Nubank uses 16 for both sizes
        intermediate_size: int = None,   # defaults to 4 * hidden_size
        max_position_embeddings: int = 2048,
        dropout: float = 0.1,
        use_positional_encoding: bool = False,  # NoPE by default!
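        # When True, the model below adds learned absolute position embeddings;
        # the NoPE default relies on causal attention order alone (Kazemnejad et al. 2023).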
**kwargs
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size or 4 * hidden_size
        self.max_position_embeddings = max_position_embeddings
        self.dropout = dropout
        self.use_positional_encoding = use_positional_encoding
        super().__init__(**kwargs)
```

```python
# src/models/modeling_domain_transformer.py
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_domain_transformer import DomainTransformerConfig


class DomainTransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.hidden_size)
        self.attn = nn.MultiheadAttention(
            config.hidden_size, config.num_attention_heads,
            dropout=config.dropout, batch_first=True
        )
        self.ln2 = nn.LayerNorm(config.hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size),
            nn.GELU(),
            nn.Linear(config.intermediate_size, config.hidden_size),
            nn.Dropout(config.dropout),
        )

    def forward(self, x, causal_mask=None, key_padding_mask=None):
        # Pre-norm architecture
        h = self.ln1(x)
        h, _ = self.attn(
            h, h, h,
            attn_mask=causal_mask,              # (L, L) bool, True = blocked (future positions)
            key_padding_mask=key_padding_mask,  # (B, L) bool, True = padding position
            need_weights=False,
        )
        x = x + h
        x = x + self.mlp(self.ln2(x))
        return x


class DomainTransformerForCausalLM(PreTrainedModel):
    config_class = DomainTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

        # NoPE: no positional encoding by default (Kazemnejad et al. 2023)
        if config.use_positional_encoding:
            self.embed_positions = nn.Embedding(
                config.max_position_embeddings, config.hidden_size
            )
        else:
            self.embed_positions = None

        self.drop = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([
            DomainTransformerBlock(config) for _ in range(config.num_hidden_layers)
        ])
        self.ln_f = nn.LayerNorm(config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Weight tying (standard for small models)
        self.lm_head.weight = self.embed_tokens.weight
        self.post_init()

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        x = self.embed_tokens(input_ids)
        if self.embed_positions is not None:
            positions = torch.arange(input_ids.size(1), device=input_ids.device)
            x = x + self.embed_positions(positions)
        x = self.drop(x)

        # Causal mask: True above the diagonal, so position i cannot attend to j > i
        seq_len = input_ids.size(1)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=input_ids.device),
            diagonal=1,
        )
        # HF-style attention_mask (1 = real token, 0 = padding) → key padding mask
        key_padding_mask = (attention_mask == 0) if attention_mask is not None else None

        for block in self.blocks:
            x = block(x, causal_mask=causal_mask, key_padding_mask=key_padding_mask)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )
        # Expose the final hidden states so the joint-fusion model (Step 8) can reuse them
        return CausalLMOutputWithPast(loss=loss, logits=logits, hidden_states=(x,))


# Register with AutoClass for Hub compatibility
DomainTransformerConfig.register_for_auto_class()
DomainTransformerForCausalLM.register_for_auto_class("AutoModelForCausalLM")
```

#### Step 5: PLR Numerical Embeddings (for Joint Fusion)

Port from Yandex's implementation:

```python
# src/models/plr_embeddings.py
import torch
import torch.nn as nn
import math

class PeriodicLinearReLU(nn.Module):
    """PLR numerical embeddings (Gorishniy et al. 2022).

    Maps scalar x → [sin(2π·w·x + b), cos(2π·w·x + b)] → Linear → ReLU
    Frequencies w and phases b are LEARNED parameters.
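    Illustrative shapes (values are arbitrary):

        plr = PeriodicLinearReLU(n_features=8, n_frequencies=64, embedding_dim=64)
        plr(torch.randn(32, 8)).shape   # torch.Size([32, 8, 64])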
""" def __init__(self, n_features: int, n_frequencies: int = 64, embedding_dim: int = 64): super().__init__() self.n_features = n_features self.n_frequencies = n_frequencies # Learnable frequencies and phases (per feature) self.frequencies = nn.Parameter( torch.randn(n_features, n_frequencies) * 0.01 ) self.phases = nn.Parameter( torch.zeros(n_features, n_frequencies) ) # Linear projection: 2*n_frequencies → embedding_dim self.linear = nn.Linear(2 * n_frequencies, embedding_dim) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x: (batch, n_features) — raw scalar feature values Returns: (batch, n_features, embedding_dim) """ # x: (B, F) → (B, F, 1) x = x.unsqueeze(-1) # Periodic encoding: (B, F, n_freq) angles = 2 * math.pi * self.frequencies.unsqueeze(0) * x + self.phases.unsqueeze(0) periodic = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1) # (B, F, 2*n_freq) # Linear + ReLU: (B, F, embedding_dim) return torch.relu(self.linear(periodic)) ``` ### Phase 2C: Pre-training (Weeks 5–7) #### Step 6: Data Pipeline ```python # src/training/data_pipeline.py from torch.utils.data import Dataset from typing import List class DomainSequenceDataset(Dataset): """Converts user event sequences into token sequences for CLM training.""" def __init__(self, user_sequences, tokenizer_builder, hf_tokenizer, max_length=2048): self.user_sequences = user_sequences self.tokenizer_builder = tokenizer_builder self.hf_tokenizer = hf_tokenizer self.max_length = max_length def __len__(self): return len(self.user_sequences) def __getitem__(self, idx): events = self.user_sequences[idx] # Tokenize each event into token strings token_strings = [] for event in events: event_tokens = self.tokenizer_builder.tokenize_event(event) token_strings.extend(event_tokens) # Convert token strings to IDs via HF tokenizer encoding = self.hf_tokenizer( " ".join(token_strings), max_length=self.max_length, truncation=True, padding="max_length", return_tensors="pt", ) input_ids = encoding["input_ids"].squeeze(0) return { "input_ids": input_ids, "labels": input_ids.clone(), # CLM: labels = input shifted by 1 "attention_mask": encoding["attention_mask"].squeeze(0), } ``` #### Step 7: Pre-training with HuggingFace Trainer ```python # src/training/pretrain.py from transformers import Trainer, TrainingArguments def pretrain_domain_model( model, train_dataset, eval_dataset=None, output_dir="./checkpoints", hub_model_id="org/domain-model-24m", num_epochs=3, batch_size=64, learning_rate=3e-4, context_length=2048, ): training_args = TrainingArguments( output_dir=output_dir, num_train_epochs=num_epochs, per_device_train_batch_size=batch_size, gradient_accumulation_steps=4, learning_rate=learning_rate, lr_scheduler_type="cosine", warmup_ratio=0.05, weight_decay=0.01, logging_strategy="steps", logging_steps=100, logging_first_step=True, disable_tqdm=True, # plain text logging for cloud eval_strategy="steps" if eval_dataset else "no", eval_steps=500, save_strategy="steps", save_steps=1000, save_total_limit=3, push_to_hub=True, hub_model_id=hub_model_id, bf16=True, dataloader_num_workers=4, report_to="trackio", ) trainer = Trainer( model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset, ) trainer.train() trainer.push_to_hub() ``` ### Phase 2D: Joint Fusion Fine-tuning (Weeks 7–9) #### Step 8: nuFormer-style Joint Fusion ```python # src/models/joint_fusion.py import torch import torch.nn as nn class DCNv2CrossLayer(nn.Module): """Single cross layer from DCN V2 (Wang et al. 
2021)."""

    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Linear(dim, dim, bias=True)

    def forward(self, x0, x):
        return x0 * self.weight(x) + x  # element-wise product with anchor


class JointFusionModel(nn.Module):
    """nuFormer-style: Transaction Transformer + DCNv2(PLR) → Joint Prediction.

    Architecture:
        Transaction Sequence → Pre-trained DomainTransformer → user_embedding
        Tabular Features → PLR → DCNv2 → tab_embedding
        Concatenate → MLP Head → prediction
    """

    def __init__(self, transformer_model, n_tabular_features, n_classes=1,
                 plr_frequencies=64, dcn_layers=3, hidden_dim=256):
        super().__init__()
        self.transformer = transformer_model  # pre-trained, unfrozen for fine-tuning
        transformer_dim = transformer_model.config.hidden_size

        # Tabular branch: PLR → DCNv2
        self.plr = PeriodicLinearReLU(n_tabular_features, plr_frequencies, hidden_dim)
        tab_input_dim = n_tabular_features * hidden_dim
        self.dcn_layers = nn.ModuleList([
            DCNv2CrossLayer(tab_input_dim) for _ in range(dcn_layers)
        ])
        self.dcn_deep = nn.Sequential(
            nn.Linear(tab_input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )

        # Joint head
        self.head = nn.Sequential(
            nn.Linear(transformer_dim + hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, n_classes),
        )

    def forward(self, input_ids, attention_mask, tabular_features, labels=None):
        # Transaction branch: final hidden state of the last real (non-padding) token
        transformer_output = self.transformer(input_ids, attention_mask=attention_mask)
        hidden = transformer_output.hidden_states[-1]            # (B, L, hidden_size)
        last_token_idx = attention_mask.long().sum(dim=1) - 1    # (B,)
        batch_idx = torch.arange(hidden.size(0), device=hidden.device)
        user_embedding = hidden[batch_idx, last_token_idx]       # (B, hidden_size)

        # Tabular branch: PLR → flatten → DCNv2
        tab_embedded = self.plr(tabular_features)                # (B, F, D)
        tab_flat = tab_embedded.view(tab_embedded.size(0), -1)   # (B, F*D)
        x0 = tab_flat
        x = tab_flat
        for cross_layer in self.dcn_layers:
            x = cross_layer(x0, x)
        tab_output = self.dcn_deep(x)                            # (B, hidden_dim)

        # Joint fusion
        combined = torch.cat([user_embedding, tab_output], dim=-1)
        logits = self.head(combined)

        loss = None
        if labels is not None:
            loss = nn.functional.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
        return {"loss": loss, "logits": logits}
```

### Phase 3: Domain Demos (Weeks 9–12)

| Week | Deliverable | Hardware |
|------|------------|----------|
| 9–10 | **Finance demo:** Transaction tokenizer + 24M model pre-trained on synthetic/public financial data + fraud detection fine-tuning | a10g-large |
| 10–11 | **E-commerce demo:** Event tokenizer + 24M model pre-trained on Amazon review sequences + next-purchase prediction | a10g-large |
| 11–12 | **Evaluation & benchmarking:** Compare domain tokenizer vs. text serialization vs.
LightGBM baselines on each domain | a10g-large | ### Phase 4: Scale & Optimize (Weeks 12+) | Task | Details | |------|---------| | Scale to 330M params | Increase `hidden_size` to 1024, train on a100-large | | `torch.compile()` | Apply to attention and MLP blocks for 10–30% speedup | | ONNX export | `torch.onnx.export()` for production serving | | Context window experiments | Ablate 512/1024/2048/4096 context lengths | | Data source ablation | Test impact of different event types (Nubank found adding low-signal sources hurts) | | ActionPiece vocabulary | Implement BPE-like cross-field merging on top of per-field tokens | --- ## Appendix A: Framework Usage Across Reference Papers | Paper | ArXiv | Framework | Verbatim Evidence | |-------|-------|-----------|-------------------| | nuFormer (Nubank) | 2507.23267 | PyTorch + HF (inferred) | All dependencies are PyTorch-based | | TIGER (Google) | 2305.05065 | JAX + T5X | "We use the open-sourced T5X framework" | | ActionPiece (DeepMind) | 2502.13581 | PyTorch + HF | "HuggingFace Transformers and PyTorch" (Appendix H) | | RecFormer | 2305.13731 | PyTorch + HF Longformer | "Longformer implemented by Huggingface" (§3.1.4) | | Banking TF | 2410.08243 | PyTorch | "Pytorch backend is used" (Appendix B) | | PLR Embeddings (Yandex) | 2203.05556 | PyTorch | Repository: pure PyTorch + scikit-learn | | KL3M Tokenizers | 2503.17247 | HF `tokenizers` + PyTorch | "tokenizers" BPE for HF compatibility | --- ## Appendix B: Head-to-Head Comparison Matrix | Criterion | PyTorch + HF | JAX/Flax NNX | Keras 3 + JAX | |-----------|-------------|-------------|---------------| | Custom domain tokenizer | ✅ `PreTrainedTokenizerFast` | ❌ Build from scratch | ❌ Build from scratch | | HF Trainer integration | ✅ Native | ❌ Not compatible | ❌ Not compatible | | Hub push/pull | ✅ `push_to_hub()` | ⚠️ Works, fragmented | ⚠️ Limited | | PEFT/LoRA | ✅ Drop-in | ❌ Manual | ❌ Manual | | ONNX export | ✅ One-line | ❌ Multi-hop | ⚠️ TF backend required | | TGI/vLLM serving | ✅ First-class | ❌ Not supported | ❌ Not supported | | TPU training | ⚠️ PyTorch/XLA (overhead) | ✅ Native | ✅ Native | | JIT compilation | ✅ `torch.compile()` | ✅ `@nnx.jit` | ✅ XLA via JAX | | Dynamic shapes (NLP) | ✅ Native | ❌ Recompiles | ❌ Recompiles | | Debugging | ✅ Eager mode, std debugger | ⚠️ Challenging | ⚠️ Challenging | | Reference implementations | 5/6 papers | 1/6 papers | 0/6 papers | | Community/hiring pool | 🟢 Large | 🟡 Small | 🟡 Small | --- ## Appendix C: Key Code Patterns ### Adding Domain Special Tokens to an Existing Tokenizer ```python from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained("gpt2") # or any base tokenizer # Add all domain special tokens special_tokens = { "additional_special_tokens": [ # Amount tokens (Nubank-style) "[AMT_POS]", "[AMT_NEG]", *[f"[AMT_{i:02d}]" for i in range(21)], # Calendar tokens *[f"[MON_{i:02d}]" for i in range(1, 13)], *[f"[DOW_{i}]" for i in range(7)], *[f"[DOM_{i:02d}]" for i in range(1, 32)], *[f"[HOUR_{i:02d}]" for i in range(24)], ] } num_added = tokenizer.add_special_tokens(special_tokens) print(f"Added {num_added} domain tokens. 
Vocab size: {len(tokenizer)}") # CRITICAL: resize model embeddings model.resize_token_embeddings(len(tokenizer)) ``` ### Registering a Custom Model for Hub Deployment ```python # In your package's __init__.py or a registration script: from transformers import AutoConfig, AutoModelForCausalLM from .configuration_domain_transformer import DomainTransformerConfig from .modeling_domain_transformer import DomainTransformerForCausalLM # Register so AutoClass can find your model AutoConfig.register("domain_transformer", DomainTransformerConfig) AutoModelForCausalLM.register(DomainTransformerConfig, DomainTransformerForCausalLM) # Enable push_to_hub with custom code DomainTransformerConfig.register_for_auto_class() DomainTransformerForCausalLM.register_for_auto_class("AutoModelForCausalLM") # Push: uploads configuration.py, modeling.py, config.json, model.safetensors model.push_to_hub("org/domain-transformer-24m") # Load anywhere: model = AutoModelForCausalLM.from_pretrained("org/domain-transformer-24m", trust_remote_code=True) ``` --- *This ADR is a living document and will be updated as implementation progresses.*