Add ADR-001: Implementation framework decision with detailed roadmap
docs/adr/ADR-001-implementation-framework.md
ADDED
|
@@ -0,0 +1,978 @@
| 1 |
+
# ADR-001: Implementation Framework for domainTokenizer
|
| 2 |
+
|
| 3 |
+
> **Status:** Accepted
|
| 4 |
+
> **Date:** April 29, 2026
|
| 5 |
+
> **Decision:** PyTorch + HuggingFace Transformers as primary framework, with JAX/Flax NNX as future scaling path
|
| 6 |
+
> **Deciders:** domainTokenizer core team
|
| 7 |
+
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
## Table of Contents
|
| 11 |
+
|
| 12 |
+
1. [Context](#1-context)
|
| 13 |
+
2. [Goal](#2-goal)
|
| 14 |
+
3. [Options Evaluated](#3-options-evaluated)
|
| 15 |
+
4. [Decision](#4-decision)
|
| 16 |
+
5. [Trade-offs and Justification](#5-trade-offs-and-justification)
|
| 17 |
+
6. [Consequences](#6-consequences)
|
| 18 |
+
7. [Implementation Roadmap](#7-implementation-roadmap)
|
| 19 |
+
8. [Appendix A: Framework Usage Across Reference Papers](#appendix-a-framework-usage-across-reference-papers)
|
| 20 |
+
9. [Appendix B: Head-to-Head Comparison Matrix](#appendix-b-head-to-head-comparison-matrix)
|
| 21 |
+
10. [Appendix C: Key Code Patterns](#appendix-c-key-code-patterns)
|
| 22 |
+
|
| 23 |
+
---
|
| 24 |
+
|
| 25 |
+
## 1. Context
|
| 26 |
+
|
| 27 |
+
### What We're Building
|
| 28 |
+
|
| 29 |
+
domainTokenizer is a library for building **small Transformer models (24M–330M parameters)** that process **domain-specific tokens** (financial transactions, e-commerce events, healthcare records) instead of natural language text. The architecture follows the validated pattern from Nubank's nuFormer ([arXiv: 2507.23267](https://arxiv.org/abs/2507.23267)):
|
| 30 |
+
|
| 31 |
+
```
|
| 32 |
+
Domain Events → Custom Tokenizer → GPT-style Transformer → Foundation Model → Downstream Tasks
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
### The Implementation Question
|
| 36 |
+
|
| 37 |
+
A video from Google for Developers presents the **Keras 3 + JAX/Flax NNX integration** as a potential framework, citing:
|
| 38 |
+
- **Explicit state management** (Flax NNX) – useful for tracking sequential transaction state
|
| 39 |
+
- **Custom training loops** – Keras structure + JAX/Optax for domain-specific gradient control
|
| 40 |
+
- **JIT compilation** (`@nnx.jit`) – high-performance processing of millions of transactions
|
| 41 |
+
- **Paradigm mixing** – Keras layers for standard components + NNX for custom sequential encoders
|
| 42 |
+
|
| 43 |
+
The question: **Is Keras + JAX/Flax NNX the right framework for domainTokenizer, or is there a better choice?**
|
| 44 |
+
|
| 45 |
+
### Constraints
|
| 46 |
+
|
| 47 |
+
1. **Custom tokenizer required:** We need a tokenizer that maps structured fields (amounts, dates, categories) to special tokens – not a standard text tokenizer
|
| 48 |
+
2. **Small models:** 24M–330M parameters, not 70B+ – framework overhead matters less than developer velocity
|
| 49 |
+
3. **Production deployment:** Models must be servable with low latency for real-time applications (fraud detection, recommendations)
|
| 50 |
+
4. **GPU hardware:** Development on A100/A10G GPUs, not TPUs (standard cloud environment)
|
| 51 |
+
5. **Team context:** ML engineers familiar with Python, PyTorch, and the HuggingFace ecosystem
|
| 52 |
+
6. **Iteration speed:** Need to prototype quickly across multiple domains (finance, e-commerce, healthcare)
|
| 53 |
+
|
| 54 |
+
---
|
| 55 |
+
|
| 56 |
+
## 2. Goal
|
| 57 |
+
|
| 58 |
+
Choose an implementation framework that:
|
| 59 |
+
|
| 60 |
+
1. **Minimizes time from research to working prototype** – weeks, not months
|
| 61 |
+
2. **Supports custom domain tokenizers** as first-class citizens
|
| 62 |
+
3. **Integrates with the HuggingFace Hub** for model sharing, versioning, and community
|
| 63 |
+
4. **Enables production deployment** via standard serving infrastructure (ONNX, TGI, vLLM, etc.)
|
| 64 |
+
5. **Scales to 330M parameters** on 4–8 GPUs without heroic engineering
|
| 65 |
+
6. **Does not preclude future migration** to JAX/TPU if we need to scale beyond 1B parameters
|
| 66 |
+
|
| 67 |
+
---
|
| 68 |
+
|
| 69 |
+
## 3. Options Evaluated
|
| 70 |
+
|
| 71 |
+
### Option A: PyTorch + HuggingFace Transformers
|
| 72 |
+
|
| 73 |
+
The dominant ecosystem for custom NLP/sequential models. Provides `PreTrainedModel`, `PreTrainedTokenizerFast`, `Trainer`, `push_to_hub`, ONNX export, and integration with TRL, PEFT, Accelerate, DeepSpeed.
|
| 74 |
+
|
| 75 |
+
### Option B: Keras 3 + JAX Backend + Flax NNX
|
| 76 |
+
|
| 77 |
+
Google's multi-backend framework. Keras provides high-level APIs; JAX provides XLA compilation and functional transforms; Flax NNX provides PyTorch-like stateful modules on top of JAX.
|
| 78 |
+
|
| 79 |
+
### Option C: Pure JAX + Flax NNX + Optax
|
| 80 |
+
|
| 81 |
+
Skip Keras entirely. Use Flax NNX for model definition, Optax for optimization, Orbax for checkpointing, and Grain/tf.data for data loading. Google's MaxText framework follows this pattern.
|
| 82 |
+
|
| 83 |
+
### Option D: PyTorch + Custom (no HuggingFace)
|
| 84 |
+
|
| 85 |
+
Use PyTorch directly without the HuggingFace abstraction layer. Full control but no ecosystem integration.
|
| 86 |
+
|
| 87 |
+
---
|
| 88 |
+
|
| 89 |
+
## 4. Decision
|
| 90 |
+
|
| 91 |
+
### Primary: PyTorch + HuggingFace Transformers (Option A)
|
| 92 |
+
|
| 93 |
+
### Future scaling path: JAX/Flax NNX (Option C) – if and when we need TPU training at >1B parameters
|
| 94 |
+
|
| 95 |
+
---
|
| 96 |
+
|
| 97 |
+
## 5. Trade-offs and Justification
|
| 98 |
+
|
| 99 |
+
### 5.1 What the Reference Papers Actually Use
|
| 100 |
+
|
| 101 |
+
We audited the frameworks used by every paper in the domainTokenizer research corpus. The result is overwhelming:
|
| 102 |
+
|
| 103 |
+
| Paper | Framework | Confidence |
|
| 104 |
+
|-------|-----------|------------|
|
| 105 |
+
| **nuFormer** (Nubank) | PyTorch + HF Transformers (inferred) | ~90% |
|
| 106 |
+
| **TIGER** (Google) | JAX + T5X (official); PyTorch (community reimpl) | 100% |
|
| 107 |
+
| **ActionPiece** (Google DeepMind) | **PyTorch + HF Transformers** (stated verbatim in paper) | 100% |
|
| 108 |
+
| **RecFormer** (UCSD/Amazon) | **PyTorch + HF Transformers (Longformer)** (stated verbatim) | 100% |
|
| 109 |
+
| **Banking Transaction Flow** | **PyTorch** (stated verbatim in appendix) | 100% |
|
| 110 |
+
| **PLR Embeddings** (Yandex) | **PyTorch** + scikit-learn + Optuna | 100% |
|
| 111 |
+
|
| 112 |
+
**5 of 6 papers use PyTorch.** The sole JAX user (TIGER) was a Google-internal project using T5X, and even its most popular community reimplementation (781 GitHub stars) is in PyTorch.
|
| 113 |
+
|
| 114 |
+
Even **Google DeepMind's own ActionPiece**, the paper most relevant to our domain tokenization approach, uses PyTorch + HuggingFace. This is the strongest signal possible.
|
| 115 |
+
|
| 116 |
+
### 5.2 Custom Tokenizer Story
|
| 117 |
+
|
| 118 |
+
This is the **decisive factor**. domainTokenizer's core innovation is the tokenizer itself. The framework must provide first-class support for custom token vocabularies.
|
| 119 |
+
|
| 120 |
+
**PyTorch + HuggingFace:**
|
| 121 |
+
- Train custom BPE tokenizer via `tokenizers` library (Rust-backed, fast)
|
| 122 |
+
- Wrap in `PreTrainedTokenizerFast` – full Trainer compatibility
|
| 123 |
+
- Add domain special tokens via `add_special_tokens()`, then resize model embeddings with `resize_token_embeddings()`
|
| 124 |
+
- Push tokenizer to Hub: `tokenizer.push_to_hub("org/my-tokenizer")`
|
| 125 |
+
- Load anywhere: `AutoTokenizer.from_pretrained("org/my-tokenizer")`
|
| 126 |
+
- **KL3M** ([arXiv: 2503.17247](https://arxiv.org/abs/2503.17247)), the gold standard for financial domain tokenizers, is built entirely on this stack
|
| 127 |
+
|
| 128 |
+
**Keras + JAX/Flax NNX:**
|
| 129 |
+
- No equivalent to `PreTrainedTokenizerFast`
|
| 130 |
+
- No Hub-integrated tokenizer format
|
| 131 |
+
- Must build custom tokenizer from scratch with no ecosystem support
|
| 132 |
+
- No standard serialization/deserialization for domain vocabularies
|
| 133 |
+
|
| 134 |
+
**Verdict:** PyTorch/HF has a **complete, production-tested** custom tokenizer pipeline. JAX/Keras has **nothing** – you'd build everything from scratch.
|
| 135 |
+
|
| 136 |
+
### 5.3 Production Deployment
|
| 137 |
+
|
| 138 |
+
| Path | PyTorch | JAX/Keras |
|
| 139 |
+
|------|---------|-----------|
|
| 140 |
+
| ONNX export | `torch.onnx.export()` – one line | Requires TF backend intermediate or experimental `jax.export` |
|
| 141 |
+
| TensorRT | ONNX → TRT (standard) | Multi-hop, fragile |
|
| 142 |
+
| TGI (HuggingFace inference) | First-class | Not supported |
|
| 143 |
+
| vLLM | First-class | Not supported |
|
| 144 |
+
| Triton Inference Server | Direct ONNX/TorchScript | Via ONNX (workaround) |
|
| 145 |
+
| BentoML | Supported | Supported |
|
| 146 |
+
| Model Hub sharing | `push_to_hub()` / `from_pretrained()` | Works but fragmented (`.msgpack` weights, no Trainer compat) |
|
| 147 |
+
|
| 148 |
+
**Verdict:** PyTorch has **direct, tested paths** to every major serving framework. JAX requires **multiple intermediate conversions**, each introducing failure points.
|
| 149 |
+
|
| 150 |
+
### 5.4 Training Speed
|
| 151 |
+
|
| 152 |
+
At our scale (24M–330M parameters on 4–8 A100s):
|
| 153 |
+
|
| 154 |
+
| Scenario | PyTorch | JAX |
|
| 155 |
+
|----------|---------|-----|
|
| 156 |
+
| Steady-state training throughput | **Comparable** (`torch.compile`) | **Comparable** (XLA JIT) |
|
| 157 |
+
| Variable-length sequences | **Native** – dynamic shapes | **Problematic** – recompiles on new shapes; must pad to buckets |
|
| 158 |
+
| Multi-GPU (FSDP) | `accelerate` + FSDP2 – mature | `pmap`/`shard_map` – works but harder to configure |
|
| 159 |
+
| First-run compilation | Instant (eager mode) | 5–20s JIT compilation overhead |
|
| 160 |
+
| Debugging | Standard Python debugger | `print` debugging; cryptic XLA errors |
|
| 161 |
+
|
| 162 |
+
**Verdict:** At 330M parameters, training speed is a **wash**. JAX's advantages (XLA kernel fusion, TPU native) only matter at 10B+ parameters on 256+ accelerators. At our scale, **developer velocity dominates throughput**.
|
| 163 |
+
|
| 164 |
+
### 5.5 The JAX Advantage: When It Would Win
|
| 165 |
+
|
| 166 |
+
JAX/Flax NNX would be the right choice **if**:
|
| 167 |
+
|
| 168 |
+
1. **Training exclusively on Google TPUs** – JAX targets TPUs natively through XLA; PyTorch/XLA is a port with overhead
|
| 169 |
+
2. **Models >1B parameters** – XLA's whole-program optimization shines at scale
|
| 170 |
+
3. **Fixed-shape workloads** – images, fixed-length token sequences (no variable-length padding issues)
|
| 171 |
+
4. **Need functional transforms** – `vmap` (per-sample gradients), `pmap` (data parallelism), `grad` (higher-order derivatives)
|
| 172 |
+
5. **Google Cloud infrastructure** – Vertex AI, TPU VMs, GCS integration
|
| 173 |
+
|
| 174 |
+
For domainTokenizer's current scope (24M–330M parameters, GPUs, variable-length sequences, fast iteration), **none of these conditions apply**.
|
| 175 |
+
|
| 176 |
+
### 5.6 The Keras + JAX Mixing Argument
|
| 177 |
+
|
| 178 |
+
The Google for Developers video argues for mixing Keras layers (high-level) with NNX modules (custom, high-performance). In theory, this lets you:
|
| 179 |
+
- Use Keras for standard Transformer layers
|
| 180 |
+
- Use NNX for custom sequential transaction encoders
|
| 181 |
+
- Get JIT compilation on the NNX parts
|
| 182 |
+
|
| 183 |
+
**In practice, this creates problems:**
|
| 184 |
+
1. **Two mental models:** Keras (layer-oriented, `fit/compile`) vs. NNX (functional, explicit state) – context switching slows development
|
| 185 |
+
2. **Limited interop documentation:** Keras ↔ NNX examples are thin; edge cases are poorly documented
|
| 186 |
+
3. **No HF ecosystem integration:** You lose Trainer, push_to_hub, PEFT, TRL, Accelerate – the entire ecosystem Nubank and ActionPiece rely on
|
| 187 |
+
4. **Debugging complexity:** Errors at the Keras–NNX boundary are hard to diagnose
|
| 188 |
+
|
| 189 |
+
**Better approach with PyTorch:** Use `torch.compile()` on performance-critical modules to get JIT compilation benefits without leaving the PyTorch ecosystem. Write custom `nn.Module` subclasses for domain-specific components. This gives you the same "standard parts + custom parts" architecture without framework mixing.
|
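For illustration, a minimal sketch of that pattern; `TransactionEncoder` is a hypothetical module used only for this example, not part of domainTokenizer:

```python
import torch
import torch.nn as nn

class TransactionEncoder(nn.Module):
    """Hypothetical custom encoder for a sequence of domain-event tokens."""

    def __init__(self, vocab_size: int, hidden_size: int = 256):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        # Mean-pool the token embeddings into one sequence-level representation
        return self.proj(self.embed(token_ids).mean(dim=1))

encoder = torch.compile(TransactionEncoder(vocab_size=32_000))  # JIT-compile the hot path
out = encoder(torch.randint(0, 32_000, (8, 128)))               # (batch, seq) -> (8, 256)
```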
| 190 |
+
|
| 191 |
+
---
|
| 192 |
+
|
| 193 |
+
## 6. Consequences
|
| 194 |
+
|
| 195 |
+
### What We Gain
|
| 196 |
+
|
| 197 |
+
1. **Immediate access to the entire HuggingFace ecosystem:** Trainer, Accelerate, PEFT (LoRA), TRL, Evaluate, push_to_hub, from_pretrained, ONNX export, TGI serving
|
| 198 |
+
2. **Copy-paste from reference implementations:** ActionPiece, RecFormer, Banking TF, and PLR embeddings are all PyTorch – we can directly reuse their code
|
| 199 |
+
3. **KL3M tokenizer as starting point:** The best financial domain tokenizer already exists in PyTorch/HF format at `alea-institute/kl3m-004-128k-cased`
|
| 200 |
+
4. **Standard production deployment:** ONNX → TensorRT → Triton, or direct TGI/vLLM serving
|
| 201 |
+
5. **Community and hiring:** PyTorch is the dominant ML framework; finding contributors and documentation is easy
|
| 202 |
+
6. **`torch.compile()` for performance:** When we need JIT compilation on hot paths, `torch.compile()` provides 10–30% speedups without leaving the ecosystem
|
| 203 |
+
|
| 204 |
+
### What We Accept
|
| 205 |
+
|
| 206 |
+
1. **No native TPU support:** If we later need to train on Google TPUs, we'll need PyTorch/XLA (slower than native JAX) or migrate the model code
|
| 207 |
+
2. **No functional transforms:** `vmap` (per-sample gradients) is only available through `torch.func` (the former functorch), which is less mature than JAX's transforms. If we need advanced gradient manipulation for meta-learning or Nested Learning (HOPE-style), JAX would be better
|
| 208 |
+
3. **Potential future migration cost:** If we scale beyond 1B parameters and move to TPUs, we'll need to rewrite model code in Flax NNX. This is mitigated by keeping model definitions clean and modular
|
| 209 |
+
|
| 210 |
+
### Migration Strategy (If Needed Later)
|
| 211 |
+
|
| 212 |
+
If domainTokenizer grows to >1B parameters and we need TPU training:
|
| 213 |
+
|
| 214 |
+
1. **Tokenizer layer stays in Python/HF:** The tokenizer is framework-agnostic – it produces integer sequences regardless of whether the model is PyTorch or JAX
|
| 215 |
+
2. **Model architecture translates 1:1:** The PyTorch `nn.Module` → Flax NNX `nnx.Module` mapping is straightforward for standard Transformer components (see the sketch after this list)
|
| 216 |
+
3. **Training loop changes:** PyTorch Trainer → custom Flax NNX training loop with Optax
|
| 217 |
+
4. **Reference:** Google's MaxText (`github.com/google/maxtext`) provides production-grade JAX Transformer patterns we can follow
|
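To make point 2 concrete, a rough sketch of the correspondence; the NNX side follows the public `flax.nnx` API and is illustrative only, not code from this repository:

```python
import torch.nn as nn
from flax import nnx

# PyTorch module (today)
class ResidualMLP(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return x + self.proj(self.norm(x))

# Flax NNX equivalent (possible future TPU path): same structure, explicit RNGs
class ResidualMLPNNX(nnx.Module):
    def __init__(self, dim: int, rngs: nnx.Rngs):
        self.norm = nnx.LayerNorm(dim, rngs=rngs)
        self.proj = nnx.Linear(dim, dim, rngs=rngs)

    def __call__(self, x):
        return x + self.proj(self.norm(x))
```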
| 218 |
+
|
| 219 |
+
**Estimated migration effort:** 2–4 weeks for a clean, well-separated codebase.
|
| 220 |
+
|
| 221 |
+
---
|
| 222 |
+
|
| 223 |
+
## 7. Implementation Roadmap
|
| 224 |
+
|
| 225 |
+
### Phase 2A: Core Tokenizer Library (Weeks 1–3)
|
| 226 |
+
|
| 227 |
+
#### Step 1: Domain Schema Definition
|
| 228 |
+
|
| 229 |
+
Create a declarative schema format that describes the fields in a domain's event data:
|
| 230 |
+
|
| 231 |
+
```python
|
| 232 |
+
# src/tokenizers/schema.py
|
| 233 |
+
|
| 234 |
+
from dataclasses import dataclass, field
|
| 235 |
+
from enum import Enum
|
| 236 |
+
from typing import List, Optional
|
| 237 |
+
|
| 238 |
+
class FieldType(Enum):
|
| 239 |
+
    NUMERICAL_CONTINUOUS = "numerical_continuous"  # prices, amounts → magnitude bins
|
| 240 |
+
    NUMERICAL_DISCRETE = "numerical_discrete"  # quantities → small fixed vocab
|
| 241 |
+
    CATEGORICAL_FIXED = "categorical_fixed"  # categories, days of week → direct mapping
|
| 242 |
+
    CATEGORICAL_ENTITY = "categorical_entity"  # products, merchants → Semantic IDs (RQ-VAE)
|
| 243 |
+
    TEMPORAL = "temporal"  # timestamps → calendar decomposition
|
| 244 |
+
    TEXT = "text"  # descriptions → BPE subwords
|
| 245 |
+
    SIGN = "sign"  # credit/debit → 2 tokens
|
| 246 |
+
|
| 247 |
+
@dataclass
|
| 248 |
+
class FieldSpec:
|
| 249 |
+
name: str
|
| 250 |
+
field_type: FieldType
|
| 251 |
+
vocab_size: Optional[int] = None # for fixed categorical
|
| 252 |
+
n_bins: int = 21 # for numerical (Nubank uses 21)
|
| 253 |
+
calendar_fields: List[str] = field( # for temporal
|
| 254 |
+
default_factory=lambda: ["month", "dow", "dom", "hour"]
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
@dataclass
|
| 258 |
+
class DomainSchema:
|
| 259 |
+
name: str # e.g., "ecommerce", "finance"
|
| 260 |
+
fields: List[FieldSpec] # ordered list of fields per event
|
| 261 |
+
|
| 262 |
+
@property
|
| 263 |
+
def special_token_count(self) -> int:
|
| 264 |
+
"""Total domain-specific special tokens needed."""
|
| 265 |
+
count = 0
|
| 266 |
+
for f in self.fields:
|
| 267 |
+
if f.field_type == FieldType.SIGN:
|
| 268 |
+
count += 2
|
| 269 |
+
elif f.field_type == FieldType.NUMERICAL_CONTINUOUS:
|
| 270 |
+
count += f.n_bins
|
| 271 |
+
elif f.field_type == FieldType.CATEGORICAL_FIXED:
|
| 272 |
+
count += f.vocab_size
|
| 273 |
+
elif f.field_type == FieldType.TEMPORAL:
|
| 274 |
+
count += sum({
|
| 275 |
+
"month": 12, "dow": 7, "dom": 31, "hour": 24,
|
| 276 |
+
"quarter": 4, "year": 10
|
| 277 |
+
}.get(cf, 0) for cf in f.calendar_fields)
|
| 278 |
+
return count
|
| 279 |
+
|
| 280 |
+
# Example: Nubank-style financial schema
|
| 281 |
+
FINANCE_SCHEMA = DomainSchema(
|
| 282 |
+
name="finance",
|
| 283 |
+
fields=[
|
| 284 |
+
FieldSpec("amount_sign", FieldType.SIGN),
|
| 285 |
+
FieldSpec("amount", FieldType.NUMERICAL_CONTINUOUS, n_bins=21),
|
| 286 |
+
FieldSpec("timestamp", FieldType.TEMPORAL,
|
| 287 |
+
calendar_fields=["month", "dow", "dom", "hour"]),
|
| 288 |
+
FieldSpec("description", FieldType.TEXT),
|
| 289 |
+
]
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
# Example: E-commerce schema
|
| 293 |
+
ECOMMERCE_SCHEMA = DomainSchema(
|
| 294 |
+
name="ecommerce",
|
| 295 |
+
fields=[
|
| 296 |
+
FieldSpec("event_type", FieldType.CATEGORICAL_FIXED, vocab_size=5),
|
| 297 |
+
FieldSpec("price", FieldType.NUMERICAL_CONTINUOUS, n_bins=21),
|
| 298 |
+
FieldSpec("quantity", FieldType.NUMERICAL_DISCRETE, vocab_size=11),
|
| 299 |
+
FieldSpec("category_l1", FieldType.CATEGORICAL_FIXED, vocab_size=30),
|
| 300 |
+
FieldSpec("category_l2", FieldType.CATEGORICAL_FIXED, vocab_size=200),
|
| 301 |
+
FieldSpec("timestamp", FieldType.TEMPORAL,
|
| 302 |
+
calendar_fields=["month", "dow", "dom", "hour"]),
|
| 303 |
+
FieldSpec("product_title", FieldType.TEXT),
|
| 304 |
+
]
|
| 305 |
+
)
|
| 306 |
+
```
|
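As a quick sanity check, a schema can be inspected before any tokenizer is trained (usage sketch; the count follows the `special_token_count` logic above):

```python
# 2 (sign) + 21 (amount bins) + 12 + 7 + 31 + 24 (calendar) = 97 special tokens;
# the free-text description field contributes none.
print(FINANCE_SCHEMA.special_token_count)  # 97

for spec in FINANCE_SCHEMA.fields:
    print(spec.name, spec.field_type.value)
```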
| 307 |
+
|
| 308 |
+
#### Step 2: Per-Field Tokenizers
|
| 309 |
+
|
| 310 |
+
Implement each field type tokenizer as a standalone module:
|
| 311 |
+
|
| 312 |
+
```python
|
| 313 |
+
# src/tokenizers/field_tokenizers.py
|
| 314 |
+
|
| 315 |
+
import numpy as np
|
| 316 |
+
from typing import List
|
| 317 |
+
|
| 318 |
+
class SignTokenizer:
|
| 319 |
+
"""Tokenizes sign of a numerical value (credit/debit, inflow/outflow)."""
|
| 320 |
+
|
| 321 |
+
def __init__(self, prefix: str = "SIGN"):
|
| 322 |
+
self.tokens = [f"[{prefix}_POS]", f"[{prefix}_NEG]"]
|
| 323 |
+
|
| 324 |
+
def __call__(self, value: float) -> str:
|
| 325 |
+
return self.tokens[0] if value >= 0 else self.tokens[1]
|
| 326 |
+
|
| 327 |
+
@property
|
| 328 |
+
def vocab(self) -> List[str]:
|
| 329 |
+
return self.tokens
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
class MagnitudeBucketTokenizer:
|
| 333 |
+
"""Quantizes continuous values into bins (Nubank-style).
|
| 334 |
+
|
| 335 |
+
Uses quantile-based binning on the training distribution.
|
| 336 |
+
Follows the Relative Magnitude Tokenization principle from TP-BERTa.
|
| 337 |
+
"""
|
| 338 |
+
|
| 339 |
+
def __init__(self, n_bins: int = 21, prefix: str = "AMT"):
|
| 340 |
+
self.n_bins = n_bins
|
| 341 |
+
self.prefix = prefix
|
| 342 |
+
self.bin_edges = None # fitted from data
|
| 343 |
+
|
| 344 |
+
def fit(self, values: np.ndarray):
|
| 345 |
+
"""Compute bin edges from training data using quantiles."""
|
| 346 |
+
# Use absolute values for magnitude binning
|
| 347 |
+
abs_vals = np.abs(values[~np.isnan(values)])
|
| 348 |
+
quantiles = np.linspace(0, 100, self.n_bins + 1)
|
| 349 |
+
self.bin_edges = np.percentile(abs_vals, quantiles)
|
| 350 |
+
return self
|
| 351 |
+
|
| 352 |
+
def __call__(self, value: float) -> str:
|
| 353 |
+
if self.bin_edges is None:
|
| 354 |
+
raise ValueError("Tokenizer not fitted. Call .fit() first.")
|
| 355 |
+
bin_idx = np.searchsorted(self.bin_edges[1:-1], abs(value))
|
| 356 |
+
return f"[{self.prefix}_{bin_idx:02d}]"
|
| 357 |
+
|
| 358 |
+
@property
|
| 359 |
+
def vocab(self) -> List[str]:
|
| 360 |
+
return [f"[{self.prefix}_{i:02d}]" for i in range(self.n_bins)]
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
class CalendarTokenizer:
|
| 364 |
+
"""Decomposes timestamps into calendar components (Nubank-style)."""
|
| 365 |
+
|
| 366 |
+
FIELD_VOCABS = {
|
| 367 |
+
"month": ([f"[MON_{i:02d}]" for i in range(1, 13)], lambda dt: dt.month - 1),
|
| 368 |
+
"dow": ([f"[DOW_{i}]" for i in range(7)], lambda dt: dt.weekday()),
|
| 369 |
+
"dom": ([f"[DOM_{i:02d}]" for i in range(1, 32)], lambda dt: dt.day - 1),
|
| 370 |
+
"hour": ([f"[HOUR_{i:02d}]" for i in range(24)], lambda dt: dt.hour),
|
| 371 |
+
}
|
| 372 |
+
|
| 373 |
+
def __init__(self, fields: List[str] = None):
|
| 374 |
+
self.fields = fields or ["month", "dow", "dom", "hour"]
|
| 375 |
+
|
| 376 |
+
def __call__(self, timestamp) -> List[str]:
|
| 377 |
+
tokens = []
|
| 378 |
+
for field_name in self.fields:
|
| 379 |
+
vocab, extractor = self.FIELD_VOCABS[field_name]
|
| 380 |
+
idx = extractor(timestamp)
|
| 381 |
+
tokens.append(vocab[idx])
|
| 382 |
+
return tokens
|
| 383 |
+
|
| 384 |
+
@property
|
| 385 |
+
def vocab(self) -> List[str]:
|
| 386 |
+
all_tokens = []
|
| 387 |
+
for field_name in self.fields:
|
| 388 |
+
all_tokens.extend(self.FIELD_VOCABS[field_name][0])
|
| 389 |
+
return all_tokens
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
class CategoricalTokenizer:
|
| 393 |
+
"""Maps categorical values to fixed vocabulary tokens."""
|
| 394 |
+
|
| 395 |
+
def __init__(self, categories: List[str], prefix: str = "CAT"):
|
| 396 |
+
self.prefix = prefix
|
| 397 |
+
self.token_map = {cat: f"[{prefix}_{i:03d}]" for i, cat in enumerate(categories)}
|
| 398 |
+
self.unk_token = f"[{prefix}_UNK]"
|
| 399 |
+
|
| 400 |
+
def __call__(self, value: str) -> str:
|
| 401 |
+
return self.token_map.get(value, self.unk_token)
|
| 402 |
+
|
| 403 |
+
@property
|
| 404 |
+
def vocab(self) -> List[str]:
|
| 405 |
+
return list(self.token_map.values()) + [self.unk_token]
|
| 406 |
+
```
|
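A short usage sketch of the field tokenizers above (the amounts are illustrative, and the exact magnitude bin depends on the fitted quantiles):

```python
import numpy as np
from datetime import datetime

amt = MagnitudeBucketTokenizer(n_bins=21, prefix="AMT")
amt.fit(np.array([3.50, 12.00, 89.99, 1200.00, -45.10]))
print(amt(-45.10))   # a token such as "[AMT_10]"

cal = CalendarTokenizer(["month", "dow", "hour"])
print(cal(datetime(2026, 4, 29, 14, 30)))   # ['[MON_04]', '[DOW_2]', '[HOUR_14]']
```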
| 407 |
+
|
| 408 |
+
#### Step 3: Composite Domain Tokenizer
|
| 409 |
+
|
| 410 |
+
Assemble per-field tokenizers into a complete domain tokenizer, wrapped as `PreTrainedTokenizerFast`:
|
| 411 |
+
|
| 412 |
+
```python
|
| 413 |
+
# src/tokenizers/domain_tokenizer.py
|
| 414 |
+
|
| 415 |
+
from tokenizers import Tokenizer, models, trainers, pre_tokenizers
|
| 416 |
+
from transformers import PreTrainedTokenizerFast

import numpy as np
from typing import List

# These live alongside this module under src/tokenizers/
from .schema import DomainSchema, FieldType
from .field_tokenizers import (
    CalendarTokenizer,
    CategoricalTokenizer,
    MagnitudeBucketTokenizer,
    SignTokenizer,
)
|
| 417 |
+
|
| 418 |
+
class DomainTokenizerBuilder:
|
| 419 |
+
"""Builds a HuggingFace-compatible tokenizer from a DomainSchema."""
|
| 420 |
+
|
| 421 |
+
def __init__(self, schema: DomainSchema):
|
| 422 |
+
self.schema = schema
|
| 423 |
+
        self.field_tokenizers = {}  # name → field tokenizer instance
|
| 424 |
+
self._build_field_tokenizers()
|
| 425 |
+
|
| 426 |
+
def _build_field_tokenizers(self):
|
| 427 |
+
for field_spec in self.schema.fields:
|
| 428 |
+
if field_spec.field_type == FieldType.SIGN:
|
| 429 |
+
self.field_tokenizers[field_spec.name] = SignTokenizer(field_spec.name.upper())
|
| 430 |
+
elif field_spec.field_type == FieldType.NUMERICAL_CONTINUOUS:
|
| 431 |
+
self.field_tokenizers[field_spec.name] = MagnitudeBucketTokenizer(
|
| 432 |
+
n_bins=field_spec.n_bins, prefix=field_spec.name.upper()
|
| 433 |
+
)
|
| 434 |
+
elif field_spec.field_type == FieldType.TEMPORAL:
|
| 435 |
+
self.field_tokenizers[field_spec.name] = CalendarTokenizer(field_spec.calendar_fields)
|
| 436 |
+
# ... other types
|
| 437 |
+
|
| 438 |
+
def fit(self, data):
|
| 439 |
+
"""Fit data-dependent tokenizers (magnitude bins, etc.)."""
|
| 440 |
+
for field_spec in self.schema.fields:
|
| 441 |
+
if field_spec.field_type == FieldType.NUMERICAL_CONTINUOUS:
|
| 442 |
+
values = [getattr(event, field_spec.name) for event in data]
|
| 443 |
+
self.field_tokenizers[field_spec.name].fit(np.array(values))
|
| 444 |
+
return self
|
| 445 |
+
|
| 446 |
+
def build_hf_tokenizer(self, text_corpus=None, bpe_vocab_size=8000) -> PreTrainedTokenizerFast:
|
| 447 |
+
"""Build a complete HuggingFace tokenizer.
|
| 448 |
+
|
| 449 |
+
1. Collect all domain special tokens
|
| 450 |
+
2. Train BPE on text fields (if any)
|
| 451 |
+
3. Merge into a single PreTrainedTokenizerFast
|
| 452 |
+
"""
|
| 453 |
+
# Collect all special tokens from field tokenizers
|
| 454 |
+
all_special_tokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "[BOS]", "[EOS]"]
|
| 455 |
+
for name, tok in self.field_tokenizers.items():
|
| 456 |
+
if hasattr(tok, 'vocab'):
|
| 457 |
+
all_special_tokens.extend(tok.vocab)
|
| 458 |
+
|
| 459 |
+
# Train BPE on text fields
|
| 460 |
+
bpe_tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
|
| 461 |
+
bpe_tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
|
| 462 |
+
trainer = trainers.BpeTrainer(
|
| 463 |
+
vocab_size=bpe_vocab_size,
|
| 464 |
+
special_tokens=all_special_tokens,
|
| 465 |
+
min_frequency=2,
|
| 466 |
+
)
|
| 467 |
+
if text_corpus:
|
| 468 |
+
bpe_tokenizer.train_from_iterator(text_corpus, trainer=trainer)
|
| 469 |
+
|
| 470 |
+
# Wrap as HuggingFace tokenizer
|
| 471 |
+
hf_tokenizer = PreTrainedTokenizerFast(
|
| 472 |
+
tokenizer_object=bpe_tokenizer,
|
| 473 |
+
bos_token="[BOS]",
|
| 474 |
+
eos_token="[EOS]",
|
| 475 |
+
pad_token="[PAD]",
|
| 476 |
+
unk_token="[UNK]",
|
| 477 |
+
mask_token="[MASK]",
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
return hf_tokenizer
|
| 481 |
+
|
| 482 |
+
def tokenize_event(self, event) -> List[str]:
|
| 483 |
+
"""Convert a single domain event into a list of token strings."""
|
| 484 |
+
tokens = []
|
| 485 |
+
for field_spec in self.schema.fields:
|
| 486 |
+
value = getattr(event, field_spec.name, None)
|
| 487 |
+
if value is None:
|
| 488 |
+
tokens.append("[UNK]")
|
| 489 |
+
continue
|
| 490 |
+
tok = self.field_tokenizers[field_spec.name]
|
| 491 |
+
result = tok(value)
|
| 492 |
+
if isinstance(result, list):
|
| 493 |
+
tokens.extend(result)
|
| 494 |
+
else:
|
| 495 |
+
tokens.append(result)
|
| 496 |
+
return tokens
|
| 497 |
+
```
|
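End to end, the builder would be used roughly as follows; `events`, `descriptions`, and the Hub repo name are placeholders. One step worth making explicit: the domain tokens should also be registered on the wrapper as `additional_special_tokens` so BPE never splits them:

```python
builder = DomainTokenizerBuilder(FINANCE_SCHEMA).fit(events)
hf_tok = builder.build_hf_tokenizer(text_corpus=descriptions, bpe_vocab_size=8000)

# Register every per-field token as a special token on the HF wrapper
domain_tokens = [t for ft in builder.field_tokenizers.values() for t in ft.vocab]
hf_tok.add_special_tokens({"additional_special_tokens": domain_tokens})

hf_tok.push_to_hub("org/finance-domain-tokenizer")  # placeholder repo name
```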
| 498 |
+
|
| 499 |
+
### Phase 2B: Model Architecture (Weeks 3–5)
|
| 500 |
+
|
| 501 |
+
#### Step 4: GPT-style Causal Transformer (NoPE)
|
| 502 |
+
|
| 503 |
+
Implement as a HuggingFace-compatible `PreTrainedModel`:
|
| 504 |
+
|
| 505 |
+
```python
|
| 506 |
+
# src/models/configuration_domain_transformer.py
|
| 507 |
+
|
| 508 |
+
from transformers import PretrainedConfig
|
| 509 |
+
|
| 510 |
+
class DomainTransformerConfig(PretrainedConfig):
|
| 511 |
+
model_type = "domain_transformer"
|
| 512 |
+
|
| 513 |
+
def __init__(
|
| 514 |
+
self,
|
| 515 |
+
vocab_size: int = 32000,
|
| 516 |
+
hidden_size: int = 256, # 256 = 24M params, 1024 = 330M (Nubank sizes)
|
| 517 |
+
num_hidden_layers: int = 24, # Nubank uses 24 for both sizes
|
| 518 |
+
num_attention_heads: int = 16, # Nubank uses 16 for both sizes
|
| 519 |
+
intermediate_size: int = None, # defaults to 4 * hidden_size
|
| 520 |
+
max_position_embeddings: int = 2048,
|
| 521 |
+
dropout: float = 0.1,
|
| 522 |
+
use_positional_encoding: bool = False, # NoPE by default!
|
| 523 |
+
**kwargs
|
| 524 |
+
):
|
| 525 |
+
self.vocab_size = vocab_size
|
| 526 |
+
self.hidden_size = hidden_size
|
| 527 |
+
self.num_hidden_layers = num_hidden_layers
|
| 528 |
+
self.num_attention_heads = num_attention_heads
|
| 529 |
+
self.intermediate_size = intermediate_size or 4 * hidden_size
|
| 530 |
+
self.max_position_embeddings = max_position_embeddings
|
| 531 |
+
self.dropout = dropout
|
| 532 |
+
self.use_positional_encoding = use_positional_encoding
|
| 533 |
+
super().__init__(**kwargs)
|
| 534 |
+
```
|
| 535 |
+
|
| 536 |
+
```python
|
| 537 |
+
# src/models/modeling_domain_transformer.py
|
| 538 |
+
|
| 539 |
+
import torch
|
| 540 |
+
import torch.nn as nn
|
| 541 |
+
from transformers import PreTrainedModel
|
| 542 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_domain_transformer import DomainTransformerConfig
|
| 543 |
+
|
| 544 |
+
class DomainTransformerBlock(nn.Module):
|
| 545 |
+
def __init__(self, config):
|
| 546 |
+
super().__init__()
|
| 547 |
+
self.ln1 = nn.LayerNorm(config.hidden_size)
|
| 548 |
+
self.attn = nn.MultiheadAttention(
|
| 549 |
+
config.hidden_size, config.num_attention_heads,
|
| 550 |
+
dropout=config.dropout, batch_first=True
|
| 551 |
+
)
|
| 552 |
+
self.ln2 = nn.LayerNorm(config.hidden_size)
|
| 553 |
+
self.mlp = nn.Sequential(
|
| 554 |
+
nn.Linear(config.hidden_size, config.intermediate_size),
|
| 555 |
+
nn.GELU(),
|
| 556 |
+
nn.Linear(config.intermediate_size, config.hidden_size),
|
| 557 |
+
nn.Dropout(config.dropout),
|
| 558 |
+
)
|
| 559 |
+
|
| 560 |
+
def forward(self, x, attn_mask=None):
|
| 561 |
+
# Pre-norm architecture
|
| 562 |
+
h = self.ln1(x)
|
| 563 |
+
        # Build an explicit causal mask; treat the incoming mask as key padding (1 = attend, 0 = pad)
        seq_len = h.size(1)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=h.device), diagonal=1
        )
        key_padding_mask = (attn_mask == 0) if attn_mask is not None else None
        h, _ = self.attn(h, h, h, attn_mask=causal_mask,
                         key_padding_mask=key_padding_mask, need_weights=False)
|
| 564 |
+
x = x + h
|
| 565 |
+
x = x + self.mlp(self.ln2(x))
|
| 566 |
+
return x
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
class DomainTransformerForCausalLM(PreTrainedModel):
|
| 570 |
+
config_class = DomainTransformerConfig
|
| 571 |
+
|
| 572 |
+
def __init__(self, config):
|
| 573 |
+
super().__init__(config)
|
| 574 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
| 575 |
+
|
| 576 |
+
# NoPE: no positional encoding by default (Kazemnejad et al. 2023)
|
| 577 |
+
if config.use_positional_encoding:
|
| 578 |
+
self.embed_positions = nn.Embedding(
|
| 579 |
+
config.max_position_embeddings, config.hidden_size
|
| 580 |
+
)
|
| 581 |
+
else:
|
| 582 |
+
self.embed_positions = None
|
| 583 |
+
|
| 584 |
+
self.drop = nn.Dropout(config.dropout)
|
| 585 |
+
self.blocks = nn.ModuleList([
|
| 586 |
+
DomainTransformerBlock(config)
|
| 587 |
+
for _ in range(config.num_hidden_layers)
|
| 588 |
+
])
|
| 589 |
+
self.ln_f = nn.LayerNorm(config.hidden_size)
|
| 590 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 591 |
+
|
| 592 |
+
# Weight tying (standard for small models)
|
| 593 |
+
self.lm_head.weight = self.embed_tokens.weight
|
| 594 |
+
|
| 595 |
+
self.post_init()
|
| 596 |
+
|
| 597 |
+
def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
|
| 598 |
+
x = self.embed_tokens(input_ids)
|
| 599 |
+
|
| 600 |
+
if self.embed_positions is not None:
|
| 601 |
+
positions = torch.arange(input_ids.size(1), device=input_ids.device)
|
| 602 |
+
x = x + self.embed_positions(positions)
|
| 603 |
+
|
| 604 |
+
x = self.drop(x)
|
| 605 |
+
|
| 606 |
+
for block in self.blocks:
|
| 607 |
+
x = block(x, attn_mask=attention_mask)
|
| 608 |
+
|
| 609 |
+
x = self.ln_f(x)
|
| 610 |
+
logits = self.lm_head(x)
|
| 611 |
+
|
| 612 |
+
loss = None
|
| 613 |
+
if labels is not None:
|
| 614 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
| 615 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 616 |
+
loss = nn.functional.cross_entropy(
|
| 617 |
+
shift_logits.view(-1, shift_logits.size(-1)),
|
| 618 |
+
shift_labels.view(-1),
|
| 619 |
+
ignore_index=-100,
|
| 620 |
+
)
|
| 621 |
+
|
| 622 |
+
        # Expose the final hidden states so downstream heads (e.g. joint fusion) can use them
        return CausalLMOutputWithPast(loss=loss, logits=logits, hidden_states=(x,))
|
| 623 |
+
|
| 624 |
+
# Register with AutoClass for Hub compatibility
|
| 625 |
+
DomainTransformerConfig.register_for_auto_class()
|
| 626 |
+
DomainTransformerForCausalLM.register_for_auto_class("AutoModelForCausalLM")
|
| 627 |
+
```
|
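A quick instantiation check; the parameter count is approximate and depends on the final vocabulary size:

```python
config = DomainTransformerConfig(vocab_size=32_000, hidden_size=256, num_hidden_layers=24)
model = DomainTransformerForCausalLM(config)

n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.1f}M parameters")  # on the order of 25–30M for this configuration
```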
| 628 |
+
|
| 629 |
+
#### Step 5: PLR Numerical Embeddings (for Joint Fusion)
|
| 630 |
+
|
| 631 |
+
Port from Yandex's implementation:
|
| 632 |
+
|
| 633 |
+
```python
|
| 634 |
+
# src/models/plr_embeddings.py
|
| 635 |
+
|
| 636 |
+
import torch
|
| 637 |
+
import torch.nn as nn
|
| 638 |
+
import math
|
| 639 |
+
|
| 640 |
+
class PeriodicLinearReLU(nn.Module):
|
| 641 |
+
"""PLR numerical embeddings (Gorishniy et al. 2022).
|
| 642 |
+
|
| 643 |
+
    Maps scalar x → [sin(2π·w·x + b), cos(2π·w·x + b)] → Linear → ReLU
|
| 644 |
+
Frequencies w and phases b are LEARNED parameters.
|
| 645 |
+
"""
|
| 646 |
+
|
| 647 |
+
def __init__(self, n_features: int, n_frequencies: int = 64, embedding_dim: int = 64):
|
| 648 |
+
super().__init__()
|
| 649 |
+
self.n_features = n_features
|
| 650 |
+
self.n_frequencies = n_frequencies
|
| 651 |
+
|
| 652 |
+
# Learnable frequencies and phases (per feature)
|
| 653 |
+
self.frequencies = nn.Parameter(
|
| 654 |
+
torch.randn(n_features, n_frequencies) * 0.01
|
| 655 |
+
)
|
| 656 |
+
self.phases = nn.Parameter(
|
| 657 |
+
torch.zeros(n_features, n_frequencies)
|
| 658 |
+
)
|
| 659 |
+
|
| 660 |
+
        # Linear projection: 2*n_frequencies → embedding_dim
|
| 661 |
+
self.linear = nn.Linear(2 * n_frequencies, embedding_dim)
|
| 662 |
+
|
| 663 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 664 |
+
"""
|
| 665 |
+
Args:
|
| 666 |
+
            x: (batch, n_features) – raw scalar feature values
|
| 667 |
+
Returns:
|
| 668 |
+
(batch, n_features, embedding_dim)
|
| 669 |
+
"""
|
| 670 |
+
        # x: (B, F) → (B, F, 1)
|
| 671 |
+
x = x.unsqueeze(-1)
|
| 672 |
+
|
| 673 |
+
# Periodic encoding: (B, F, n_freq)
|
| 674 |
+
angles = 2 * math.pi * self.frequencies.unsqueeze(0) * x + self.phases.unsqueeze(0)
|
| 675 |
+
periodic = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1) # (B, F, 2*n_freq)
|
| 676 |
+
|
| 677 |
+
# Linear + ReLU: (B, F, embedding_dim)
|
| 678 |
+
return torch.relu(self.linear(periodic))
|
| 679 |
+
```
|
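A shape check for the PLR module (sizes are illustrative):

```python
import torch

plr = PeriodicLinearReLU(n_features=8, n_frequencies=64, embedding_dim=64)
x = torch.randn(32, 8)        # 32 rows, 8 scalar tabular features
print(plr(x).shape)           # torch.Size([32, 8, 64])
```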
| 680 |
+
|
| 681 |
+
### Phase 2C: Pre-training (Weeks 5–7)
|
| 682 |
+
|
| 683 |
+
#### Step 6: Data Pipeline
|
| 684 |
+
|
| 685 |
+
```python
|
| 686 |
+
# src/training/data_pipeline.py
|
| 687 |
+
|
| 688 |
+
from torch.utils.data import Dataset
|
| 689 |
+
from typing import List
|
| 690 |
+
|
| 691 |
+
class DomainSequenceDataset(Dataset):
|
| 692 |
+
"""Converts user event sequences into token sequences for CLM training."""
|
| 693 |
+
|
| 694 |
+
def __init__(self, user_sequences, tokenizer_builder, hf_tokenizer, max_length=2048):
|
| 695 |
+
self.user_sequences = user_sequences
|
| 696 |
+
self.tokenizer_builder = tokenizer_builder
|
| 697 |
+
self.hf_tokenizer = hf_tokenizer
|
| 698 |
+
self.max_length = max_length
|
| 699 |
+
|
| 700 |
+
def __len__(self):
|
| 701 |
+
return len(self.user_sequences)
|
| 702 |
+
|
| 703 |
+
def __getitem__(self, idx):
|
| 704 |
+
events = self.user_sequences[idx]
|
| 705 |
+
|
| 706 |
+
# Tokenize each event into token strings
|
| 707 |
+
token_strings = []
|
| 708 |
+
for event in events:
|
| 709 |
+
event_tokens = self.tokenizer_builder.tokenize_event(event)
|
| 710 |
+
token_strings.extend(event_tokens)
|
| 711 |
+
|
| 712 |
+
# Convert token strings to IDs via HF tokenizer
|
| 713 |
+
encoding = self.hf_tokenizer(
|
| 714 |
+
" ".join(token_strings),
|
| 715 |
+
max_length=self.max_length,
|
| 716 |
+
truncation=True,
|
| 717 |
+
padding="max_length",
|
| 718 |
+
return_tensors="pt",
|
| 719 |
+
)
|
| 720 |
+
|
| 721 |
+
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        # Mask out padding positions so they don't contribute to the CLM loss
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
|
| 722 |
+
|
| 723 |
+
return {
|
| 724 |
+
"input_ids": input_ids,
|
| 725 |
+
"labels": input_ids.clone(), # CLM: labels = input shifted by 1
|
| 726 |
+
"attention_mask": encoding["attention_mask"].squeeze(0),
|
| 727 |
+
}
|
| 728 |
+
```
|
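A usage sketch, assuming the builder and HF tokenizer from Phase 2A; `user_sequences` is a placeholder for the per-user event lists:

```python
train_ds = DomainSequenceDataset(user_sequences, builder, hf_tok, max_length=2048)

sample = train_ds[0]
print(sample["input_ids"].shape)            # torch.Size([2048])
print(int(sample["attention_mask"].sum()))  # number of non-padding positions
```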
| 729 |
+
|
| 730 |
+
#### Step 7: Pre-training with HuggingFace Trainer
|
| 731 |
+
|
| 732 |
+
```python
|
| 733 |
+
# src/training/pretrain.py
|
| 734 |
+
|
| 735 |
+
from transformers import Trainer, TrainingArguments
|
| 736 |
+
|
| 737 |
+
def pretrain_domain_model(
|
| 738 |
+
model,
|
| 739 |
+
train_dataset,
|
| 740 |
+
eval_dataset=None,
|
| 741 |
+
output_dir="./checkpoints",
|
| 742 |
+
hub_model_id="org/domain-model-24m",
|
| 743 |
+
num_epochs=3,
|
| 744 |
+
batch_size=64,
|
| 745 |
+
learning_rate=3e-4,
|
| 746 |
+
context_length=2048,
|
| 747 |
+
):
|
| 748 |
+
training_args = TrainingArguments(
|
| 749 |
+
output_dir=output_dir,
|
| 750 |
+
num_train_epochs=num_epochs,
|
| 751 |
+
per_device_train_batch_size=batch_size,
|
| 752 |
+
gradient_accumulation_steps=4,
|
| 753 |
+
learning_rate=learning_rate,
|
| 754 |
+
lr_scheduler_type="cosine",
|
| 755 |
+
warmup_ratio=0.05,
|
| 756 |
+
weight_decay=0.01,
|
| 757 |
+
logging_strategy="steps",
|
| 758 |
+
logging_steps=100,
|
| 759 |
+
logging_first_step=True,
|
| 760 |
+
disable_tqdm=True, # plain text logging for cloud
|
| 761 |
+
eval_strategy="steps" if eval_dataset else "no",
|
| 762 |
+
eval_steps=500,
|
| 763 |
+
save_strategy="steps",
|
| 764 |
+
save_steps=1000,
|
| 765 |
+
save_total_limit=3,
|
| 766 |
+
push_to_hub=True,
|
| 767 |
+
hub_model_id=hub_model_id,
|
| 768 |
+
bf16=True,
|
| 769 |
+
dataloader_num_workers=4,
|
| 770 |
+
report_to="trackio",
|
| 771 |
+
)
|
| 772 |
+
|
| 773 |
+
trainer = Trainer(
|
| 774 |
+
model=model,
|
| 775 |
+
args=training_args,
|
| 776 |
+
train_dataset=train_dataset,
|
| 777 |
+
eval_dataset=eval_dataset,
|
| 778 |
+
)
|
| 779 |
+
|
| 780 |
+
trainer.train()
|
| 781 |
+
trainer.push_to_hub()
|
| 782 |
+
```
|
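Wired together with the previous steps, a training run would look roughly like this (the Hub repo name is a placeholder, and `push_to_hub=True` assumes you are logged in to the Hub):

```python
config = DomainTransformerConfig(vocab_size=len(hf_tok), hidden_size=256, num_hidden_layers=24)
model = DomainTransformerForCausalLM(config)

pretrain_domain_model(
    model,
    train_dataset=train_ds,
    hub_model_id="org/finance-domain-24m",  # placeholder
    num_epochs=3,
)
```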
| 783 |
+
|
| 784 |
+
### Phase 2D: Joint Fusion Fine-tuning (Weeks 7–9)
|
| 785 |
+
|
| 786 |
+
#### Step 8: nuFormer-style Joint Fusion
|
| 787 |
+
|
| 788 |
+
```python
|
| 789 |
+
# src/models/joint_fusion.py
|
| 790 |
+
|
| 791 |
+
import torch
|
| 792 |
+
import torch.nn as nn

from .plr_embeddings import PeriodicLinearReLU
|
| 793 |
+
|
| 794 |
+
class DCNv2CrossLayer(nn.Module):
|
| 795 |
+
"""Single cross layer from DCN V2 (Wang et al. 2021)."""
|
| 796 |
+
|
| 797 |
+
def __init__(self, dim):
|
| 798 |
+
super().__init__()
|
| 799 |
+
self.weight = nn.Linear(dim, dim, bias=True)
|
| 800 |
+
|
| 801 |
+
def forward(self, x0, x):
|
| 802 |
+
return x0 * self.weight(x) + x # element-wise product with anchor
|
| 803 |
+
|
| 804 |
+
|
| 805 |
+
class JointFusionModel(nn.Module):
|
| 806 |
+
"""nuFormer-style: Transaction Transformer + DCNv2(PLR) β Joint Prediction.
|
| 807 |
+
|
| 808 |
+
Architecture:
|
| 809 |
+
        Transaction Sequence → Pre-trained DomainTransformer → user_embedding
|
| 810 |
+
        Tabular Features → PLR → DCNv2 → tab_embedding
|
| 811 |
+
        Concatenate → MLP Head → prediction
|
| 812 |
+
"""
|
| 813 |
+
|
| 814 |
+
def __init__(self, transformer_model, n_tabular_features, n_classes=1,
|
| 815 |
+
plr_frequencies=64, dcn_layers=3, hidden_dim=256):
|
| 816 |
+
super().__init__()
|
| 817 |
+
|
| 818 |
+
self.transformer = transformer_model # pre-trained, unfrozen for fine-tuning
|
| 819 |
+
transformer_dim = transformer_model.config.hidden_size
|
| 820 |
+
|
| 821 |
+
        # Tabular branch: PLR → DCNv2
|
| 822 |
+
self.plr = PeriodicLinearReLU(n_tabular_features, plr_frequencies, hidden_dim)
|
| 823 |
+
tab_input_dim = n_tabular_features * hidden_dim
|
| 824 |
+
|
| 825 |
+
self.dcn_layers = nn.ModuleList([
|
| 826 |
+
DCNv2CrossLayer(tab_input_dim) for _ in range(dcn_layers)
|
| 827 |
+
])
|
| 828 |
+
self.dcn_deep = nn.Sequential(
|
| 829 |
+
nn.Linear(tab_input_dim, hidden_dim),
|
| 830 |
+
nn.ReLU(),
|
| 831 |
+
nn.Linear(hidden_dim, hidden_dim),
|
| 832 |
+
nn.ReLU(),
|
| 833 |
+
)
|
| 834 |
+
|
| 835 |
+
# Joint head
|
| 836 |
+
self.head = nn.Sequential(
|
| 837 |
+
nn.Linear(transformer_dim + hidden_dim, hidden_dim),
|
| 838 |
+
nn.ReLU(),
|
| 839 |
+
nn.Dropout(0.1),
|
| 840 |
+
nn.Linear(hidden_dim, n_classes),
|
| 841 |
+
)
|
| 842 |
+
|
| 843 |
+
def forward(self, input_ids, attention_mask, tabular_features, labels=None):
|
| 844 |
+
# Transaction branch: get last-token embedding
|
| 845 |
+
transformer_output = self.transformer(input_ids, attention_mask=attention_mask)
|
| 846 |
+
        user_embedding = transformer_output.hidden_states[-1][:, -1, :]  # final hidden state of the last token
|
| 847 |
+
|
| 848 |
+
        # Tabular branch: PLR → flatten → DCNv2
|
| 849 |
+
tab_embedded = self.plr(tabular_features) # (B, F, D)
|
| 850 |
+
tab_flat = tab_embedded.view(tab_embedded.size(0), -1) # (B, F*D)
|
| 851 |
+
|
| 852 |
+
x0 = tab_flat
|
| 853 |
+
x = tab_flat
|
| 854 |
+
for cross_layer in self.dcn_layers:
|
| 855 |
+
x = cross_layer(x0, x)
|
| 856 |
+
tab_output = self.dcn_deep(x) # (B, hidden_dim)
|
| 857 |
+
|
| 858 |
+
# Joint fusion
|
| 859 |
+
combined = torch.cat([user_embedding, tab_output], dim=-1)
|
| 860 |
+
logits = self.head(combined)
|
| 861 |
+
|
| 862 |
+
loss = None
|
| 863 |
+
if labels is not None:
|
| 864 |
+
loss = nn.functional.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
|
| 865 |
+
|
| 866 |
+
return {"loss": loss, "logits": logits}
|
| 867 |
+
```
|
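A forward-pass sketch for the fusion model; `model` is the pre-trained transformer from the previous phase and `batch` is a placeholder for a collated batch with tabular features attached:

```python
fusion = JointFusionModel(model, n_tabular_features=12, n_classes=1)

out = fusion(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    tabular_features=batch["tabular_features"],  # (batch, 12) float tensor
    labels=batch["label"],
)
print(out["loss"], out["logits"].shape)          # scalar loss, (batch, 1)
```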
| 868 |
+
|
| 869 |
+
### Phase 3: Domain Demos (Weeks 9–12)
|
| 870 |
+
|
| 871 |
+
| Week | Deliverable | Hardware |
|
| 872 |
+
|------|------------|----------|
|
| 873 |
+
| 9–10 | **Finance demo:** Transaction tokenizer + 24M model pre-trained on synthetic/public financial data + fraud detection fine-tuning | a10g-large |
|
| 874 |
+
| 10–11 | **E-commerce demo:** Event tokenizer + 24M model pre-trained on Amazon review sequences + next-purchase prediction | a10g-large |
|
| 875 |
+
| 11–12 | **Evaluation & benchmarking:** Compare domain tokenizer vs. text serialization vs. LightGBM baselines on each domain | a10g-large |
|
| 876 |
+
|
| 877 |
+
### Phase 4: Scale & Optimize (Weeks 12+)
|
| 878 |
+
|
| 879 |
+
| Task | Details |
|
| 880 |
+
|------|---------|
|
| 881 |
+
| Scale to 330M params | Increase `hidden_size` to 1024, train on a100-large |
|
| 882 |
+
| `torch.compile()` | Apply to attention and MLP blocks for a 10–30% speedup |
|
| 883 |
+
| ONNX export | `torch.onnx.export()` for production serving |
|
| 884 |
+
| Context window experiments | Ablate 512/1024/2048/4096 context lengths |
|
| 885 |
+
| Data source ablation | Test impact of different event types (Nubank found adding low-signal sources hurts) |
|
| 886 |
+
| ActionPiece vocabulary | Implement BPE-like cross-field merging on top of per-field tokens |
|
| 887 |
+
|
| 888 |
+
---
|
| 889 |
+
|
| 890 |
+
## Appendix A: Framework Usage Across Reference Papers
|
| 891 |
+
|
| 892 |
+
| Paper | ArXiv | Framework | Verbatim Evidence |
|
| 893 |
+
|-------|-------|-----------|-------------------|
|
| 894 |
+
| nuFormer (Nubank) | 2507.23267 | PyTorch + HF (inferred) | All dependencies are PyTorch-based |
|
| 895 |
+
| TIGER (Google) | 2305.05065 | JAX + T5X | "We use the open-sourced T5X framework" |
|
| 896 |
+
| ActionPiece (DeepMind) | 2502.13581 | PyTorch + HF | "HuggingFace Transformers and PyTorch" (Appendix H) |
|
| 897 |
+
| RecFormer | 2305.13731 | PyTorch + HF Longformer | "Longformer implemented by Huggingface" (§3.1.4) |
|
| 898 |
+
| Banking TF | 2410.08243 | PyTorch | "Pytorch backend is used" (Appendix B) |
|
| 899 |
+
| PLR Embeddings (Yandex) | 2203.05556 | PyTorch | Repository: pure PyTorch + scikit-learn |
|
| 900 |
+
| KL3M Tokenizers | 2503.17247 | HF `tokenizers` + PyTorch | "tokenizers" BPE for HF compatibility |
|
| 901 |
+
|
| 902 |
+
---
|
| 903 |
+
|
| 904 |
+
## Appendix B: Head-to-Head Comparison Matrix
|
| 905 |
+
|
| 906 |
+
| Criterion | PyTorch + HF | JAX/Flax NNX | Keras 3 + JAX |
|
| 907 |
+
|-----------|-------------|-------------|---------------|
|
| 908 |
+
| Custom domain tokenizer | ✅ `PreTrainedTokenizerFast` | ❌ Build from scratch | ❌ Build from scratch |
| HF Trainer integration | ✅ Native | ❌ Not compatible | ❌ Not compatible |
| Hub push/pull | ✅ `push_to_hub()` | ⚠️ Works, fragmented | ⚠️ Limited |
| PEFT/LoRA | ✅ Drop-in | ❌ Manual | ❌ Manual |
| ONNX export | ✅ One-line | ❌ Multi-hop | ⚠️ TF backend required |
| TGI/vLLM serving | ✅ First-class | ❌ Not supported | ❌ Not supported |
| TPU training | ⚠️ PyTorch/XLA (overhead) | ✅ Native | ✅ Native |
| JIT compilation | ✅ `torch.compile()` | ✅ `@nnx.jit` | ✅ XLA via JAX |
| Dynamic shapes (NLP) | ✅ Native | ❌ Recompiles | ❌ Recompiles |
| Debugging | ✅ Eager mode, std debugger | ⚠️ Challenging | ⚠️ Challenging |
| Reference implementations | 5/6 papers | 1/6 papers | 0/6 papers |
| Community/hiring pool | 🟢 Large | 🟡 Small | 🟡 Small |
|
| 920 |
+
|
| 921 |
+
---
|
| 922 |
+
|
| 923 |
+
## Appendix C: Key Code Patterns
|
| 924 |
+
|
| 925 |
+
### Adding Domain Special Tokens to an Existing Tokenizer
|
| 926 |
+
|
| 927 |
+
```python
|
| 928 |
+
from transformers import AutoTokenizer
|
| 929 |
+
|
| 930 |
+
tokenizer = AutoTokenizer.from_pretrained("gpt2") # or any base tokenizer
|
| 931 |
+
|
| 932 |
+
# Add all domain special tokens
|
| 933 |
+
special_tokens = {
|
| 934 |
+
"additional_special_tokens": [
|
| 935 |
+
# Amount tokens (Nubank-style)
|
| 936 |
+
"[AMT_POS]", "[AMT_NEG]",
|
| 937 |
+
*[f"[AMT_{i:02d}]" for i in range(21)],
|
| 938 |
+
# Calendar tokens
|
| 939 |
+
*[f"[MON_{i:02d}]" for i in range(1, 13)],
|
| 940 |
+
*[f"[DOW_{i}]" for i in range(7)],
|
| 941 |
+
*[f"[DOM_{i:02d}]" for i in range(1, 32)],
|
| 942 |
+
*[f"[HOUR_{i:02d}]" for i in range(24)],
|
| 943 |
+
]
|
| 944 |
+
}
|
| 945 |
+
num_added = tokenizer.add_special_tokens(special_tokens)
|
| 946 |
+
print(f"Added {num_added} domain tokens. Vocab size: {len(tokenizer)}")
|
| 947 |
+
|
| 948 |
+
# CRITICAL: resize model embeddings
|
| 949 |
+
model.resize_token_embeddings(len(tokenizer))
|
| 950 |
+
```
|
| 951 |
+
|
| 952 |
+
### Registering a Custom Model for Hub Deployment
|
| 953 |
+
|
| 954 |
+
```python
|
| 955 |
+
# In your package's __init__.py or a registration script:
|
| 956 |
+
from transformers import AutoConfig, AutoModelForCausalLM
|
| 957 |
+
|
| 958 |
+
from .configuration_domain_transformer import DomainTransformerConfig
|
| 959 |
+
from .modeling_domain_transformer import DomainTransformerForCausalLM
|
| 960 |
+
|
| 961 |
+
# Register so AutoClass can find your model
|
| 962 |
+
AutoConfig.register("domain_transformer", DomainTransformerConfig)
|
| 963 |
+
AutoModelForCausalLM.register(DomainTransformerConfig, DomainTransformerForCausalLM)
|
| 964 |
+
|
| 965 |
+
# Enable push_to_hub with custom code
|
| 966 |
+
DomainTransformerConfig.register_for_auto_class()
|
| 967 |
+
DomainTransformerForCausalLM.register_for_auto_class("AutoModelForCausalLM")
|
| 968 |
+
|
| 969 |
+
# Push: uploads configuration.py, modeling.py, config.json, model.safetensors
|
| 970 |
+
model.push_to_hub("org/domain-transformer-24m")
|
| 971 |
+
|
| 972 |
+
# Load anywhere:
|
| 973 |
+
model = AutoModelForCausalLM.from_pretrained("org/domain-transformer-24m", trust_remote_code=True)
|
| 974 |
+
```
|
| 975 |
+
|
| 976 |
+
---
|
| 977 |
+
|
| 978 |
+
*This ADR is a living document and will be updated as implementation progresses.*
|