# Gemma 4 E4B MTP Drafter: Extracted from LiteRT
First public extraction of Google Gemma 4's Multi-Token Prediction (MTP) drafter weights from LiteRT format into standard PyTorch safetensors.
Google trained all Gemma 4 variants with MTP heads for built-in speculative decoding, then deliberately stripped them from the public HuggingFace release, retaining them exclusively in LiteRT on-device format. This repository provides the extracted drafter as a standard PyTorch module with a working speculative decoding implementation.
## Key Results
| Metric | Value |
|---|---|
| Step-0 top-1 acceptance (greedy) | 35% |
| Step-0 top-5 overlap with base model | ~80% |
| Architecture fully reconstructed | 42/42 tensors, all shapes verified |
| Outputs correct | Yes: coherent, accurate text on all test prompts |
| INT4 nibble order | `low_first` confirmed |
| Attention scaling | None (QK-norm with scale=1.0, confirmed correct) |
The 35% top-1 acceptance from dequantized INT4/INT8 mobile weights is the expected ceiling: quantization noise is irreversible. The 80% top-5 overlap makes these weights immediately useful for:
- Tree-based speculative decoding (top-k draft candidates verified in parallel)
- vLLM/SGLang native spec decode (zero-overhead verification eliminates the Python loop bottleneck)
- Architecture reference for training EAGLE3 draft heads from full BF16 activations
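For reference, top-k overlap between drafter and base logits can be measured with a small helper like the following (function name and shapes are illustrative, not part of this package):

```python
import torch

def topk_overlap(draft_logits: torch.Tensor, base_logits: torch.Tensor, k: int = 5) -> float:
    """Mean fraction of the base model's top-k tokens that also appear in the
    drafter's top-k, averaged over positions. Logits have shape [positions, vocab]."""
    d = draft_logits.topk(k, dim=-1).indices   # [T, k] drafter candidates
    b = base_logits.topk(k, dim=-1).indices    # [T, k] base candidates
    # Pairwise equality [T, k, k]; top-k indices are unique per row, so the
    # sum of matches is exactly the intersection size at each position.
    inter = (d.unsqueeze(-1) == b.unsqueeze(-2)).flatten(-2).sum(-1)
    return (inter.float() / k).mean().item()
```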
## Architecture
The MTP drafter is a lightweight 4-layer transformer (78M parameters) designed as a one-step recurrent cell:
```
Input: concat(token_embedding, projected_activation)  [5120]
  → Linear pre-projection [5120 → 256]
  → 4 transformer blocks (256 hidden, 2048 intermediate, GeGLU MLP)
      Blocks 0-2: sliding attention (window=512), head_dim=256, 4Q/2KV heads
      Block 3:    full attention, head_dim=512, 4Q/2KV heads
  → RMSNorm → parallel heads:
      LM head:   [256 → 262144] with tanh soft-cap at 30.0
      Post-proj: [256 → 2560] (recurrent: fed back as next step's activation)
```
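The tanh soft-cap on the LM head smoothly squashes logits into (-cap, cap) while leaving small logits almost unchanged; a minimal sketch (function name is illustrative):

```python
import torch

def soft_cap_logits(logits: torch.Tensor, cap: float = 30.0) -> torch.Tensor:
    # Smoothly bound logits to (-cap, cap); near-identity for |logits| << cap
    return cap * torch.tanh(logits / cap)
```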
Critical architectural details:

- Q-only attention: the drafter owns no K/V projections. It attends over the base model's shared KV banks from layers 22 (sliding) and 23 (full attention).
- No `1/sqrt(d)` attention scaling: Gemma 4 uses QK-norm with `attention_scale=1.0`. Adding traditional scaling degrades accuracy (top-5 drops from 80% to 20%). The fixed-scale query norms (local=0.9916, global=1.0228) serve as the scaling mechanism.
- Heterogeneous head dimensions: blocks 0-2 use head_dim=256, block 3 uses head_dim=512. This is often misidentified as "8 heads" for block 3; it is actually 4 heads at 512 dims, as confirmed by the reshape ops in the LiteRT graph.
- BF16 precision required: Gemma 4's unscaled attention is precision-sensitive. FP16's narrower dynamic range is reported to cause output degeneration after ~50 tokens due to overflow in the unscaled dot products. Always serve in BF16 with F32 intermediate accumulation.
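A sketch of unscaled QK-norm attention scores consistent with the description above (the exact norm placement and epsilon are assumptions; the fixed query scale is the local value quoted above):

```python
import torch

def qk_norm_scores(q: torch.Tensor, k: torch.Tensor,
                   q_scale: float = 0.9916, eps: float = 1e-6) -> torch.Tensor:
    """RMS-normalize queries and keys along head_dim, then take raw dot
    products with no 1/sqrt(d) factor (attention_scale = 1.0)."""
    def rms_norm(x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return (rms_norm(q) * q_scale) @ rms_norm(k).transpose(-1, -2)
```

Because both operands are RMS-normalized, the score magnitudes stay bounded by head_dim regardless of the raw activation scale, which is what stands in for the usual `1/sqrt(d)` factor.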
## Files
| File | Description |
|---|---|
| `mtp_drafter.safetensors` | Extracted weights (297 MB, fp32, 42 tensors, 78M params) |
| `mtp_drafter_config.json` | Architecture configuration |
| `gemma4_litert_mtp/` | Python package: PyTorch module, HF patches, speculative generation |
## Quick Start
```python
import torch
from transformers import AutoModelForConditionalGeneration, AutoTokenizer
from safetensors.torch import load_file

from gemma4_litert_mtp import Gemma4LiteRtMtpDrafter, LiteRtMtpDrafterConfig
from gemma4_litert_mtp.speculative_generate import Gemma4MtpSpeculativeGenerator

# Load base model (BF16 required; FP16 will degenerate)
model = AutoModelForConditionalGeneration.from_pretrained(
    "google/gemma-4-E4B-it", torch_dtype=torch.bfloat16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-4-E4B-it")

# Load extracted MTP drafter
config = LiteRtMtpDrafterConfig.from_json_file("mtp_drafter_config.json")
drafter = Gemma4LiteRtMtpDrafter(config)
drafter.load_state_dict(load_file("mtp_drafter.safetensors"), strict=True)
drafter = drafter.to(dtype=torch.bfloat16, device=model.device).eval()

# Generate with speculative decoding
generator = Gemma4MtpSpeculativeGenerator(model, drafter, tokenizer)
result = generator.generate("What is the capital of France?")
print(result.text)  # "The capital of France is **Paris**."
print(f"Acceptance: {result.acceptance_rate:.0%}")
```
**Important:** Gemma 4 E4B is instruction-tuned. Always use the chat template (the generator handles this automatically). Raw prompts without the template produce degenerate repetitive output.
## Benchmark Results (DGX Spark GB10, BF16)
### Acceptance Rate by Draft Steps (Greedy)
| Draft Steps | Step-0 Acceptance | Overall Acceptance | Notes |
|---|---|---|---|
| 1 | 35% | 35% | Best standalone config |
| 2 | 35% | 22% | Steps 1+ degrade (stale KV) |
| 3 | 35% | 14% | Further degradation |
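The greedy acceptance rule these numbers reflect is simple: a draft token counts as accepted iff it equals the base model's argmax at that position. A minimal sketch (names are illustrative):

```python
def greedy_verify(draft_tokens, base_argmax):
    """Accept draft tokens while they match the base model's greedy choice;
    on the first mismatch, substitute the base token and stop the cycle.
    base_argmax carries one extra entry for the bonus token when all drafts pass."""
    out = []
    for i, t in enumerate(draft_tokens):
        if t != base_argmax[i]:
            out.append(base_argmax[i])  # correction token ends the cycle
            return out
        out.append(t)                   # accepted draft token
    out.append(base_argmax[len(draft_tokens)])  # bonus token from last position
    return out
```

With multiple draft steps, a later token is only reachable if every earlier one was accepted, which is why the overall acceptance in the table decays with depth.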
### Why Standalone Speedup Is Limited
In a Python speculative decoding loop, each cycle requires ≥2 base model forward passes (verify + state update). At 35% acceptance, expected tokens per cycle = 1.35, yielding ~0.7x: slower than autoregressive. This is not a flaw in the extraction; it is an inherent limitation of Python-level speculative decoding loops.
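The cycle arithmetic can be checked in a few lines (treating multi-step acceptance as a chain of per-step probabilities is an assumption, but it matches the single-step case above):

```python
def expected_tokens_and_speedup(step_accept, base_passes_per_cycle=2):
    """Expected tokens emitted per cycle (1 guaranteed verified token plus the
    chain of accepted drafts), and throughput relative to 1 token per forward pass."""
    expected, chain = 1.0, 1.0
    for a in step_accept:
        chain *= a          # probability that this draft step is reached and accepted
        expected += chain
    return expected, expected / base_passes_per_cycle
```

At `step_accept=[0.35]` this gives 1.35 tokens per cycle over 2 passes, i.e. ~0.68x.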
Production serving engines (vLLM, SGLang) eliminate this overhead via fused kernels that score draft tokens within the base model's forward pass. The 35% acceptance rate would deliver meaningful speedup there. The 80% top-5 overlap is particularly valuable for tree-based speculative methods that verify multiple candidates per position.
## Extraction Process
- Downloaded `gemma-4-E4B-it.litertlm` (3.5 GB) from `litert-community/gemma-4-E4B-it-litert-lm`
- Scanned the FlatBuffer for TFLite models; found 10 shards, of which the MTP drafter is section 9 (43 MB)
- Parsed 191 tensors from the TFLite FlatBuffer using the `tflite` Python package
- Mapped 42 weight tensors using a pre-built mapping derived from the LiteRT graph JSON via Google Model Explorer
- Dequantized INT8 (per-channel) and INT4 (per-channel, packed nibbles, `low_first`) to float32
- Bug found and fixed: F32 tensors (RMSNorm weights) had spurious quantization metadata in the FlatBuffer, so naive dequantization zeroed them out. Fix: skip the quantization path for dtype=FLOAT32 regardless of metadata.
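The `low_first` nibble unpacking and per-channel dequantization can be sketched as follows (assuming two's-complement signed nibbles; function names are illustrative):

```python
import numpy as np

def unpack_int4_low_first(packed: np.ndarray) -> np.ndarray:
    """Unpack uint8 bytes into signed INT4 values, low nibble first."""
    lo = (packed & 0x0F).astype(np.int16)
    hi = (packed >> 4).astype(np.int16)
    # Sign-extend 4-bit two's complement: values 8..15 map to -8..-1
    lo = np.where(lo > 7, lo - 16, lo)
    hi = np.where(hi > 7, hi - 16, hi)
    out = np.empty(packed.size * 2, dtype=np.int8)
    out[0::2] = lo   # low nibble comes first in the output stream
    out[1::2] = hi
    return out

def dequantize(q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    """Per-channel dequantization: w[c, :] = scale[c] * q[c, :]."""
    return q.astype(np.float32) * scale[:, None]
```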
Architecture reverse-engineered from three sources: the .tflite binary (weights/shapes), graph JSON (operator topology), and the LiteRT-LM C++ runner (external interface and loop semantics).
## Reproduction
```bash
# Extract MTP shard from .litertlm
python extract_mtp_from_litertlm.py --input gemma-4-E4B-it.litertlm --output-dir ./extracted

# Convert to safetensors using the pre-built mapping
python convert_with_prebuilt_mapping.py --tflite ./extracted/section9_mtp_drafter.tflite --output-dir ./drafter

# Validate
python validate_drafter.py ./drafter
```
## Known Limitations
**Quantization ceiling.** The LiteRT drafter was INT4/INT8 quantized for mobile deployment. Dequantization cannot recover the original BF16 training precision. Step-0 acceptance is 35% (top-1) vs the theoretical ~80%+ achievable with full-precision weights.
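To see why this ceiling is irreducible, a symmetric INT4 round trip (an illustrative scheme, not the exact LiteRT one) leaves rounding error of up to half a quantization step that no dequantizer can undo:

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.standard_normal(1024).astype(np.float32)  # stand-in for BF16 weights

# Symmetric per-tensor INT4: map to integer levels in [-8, 7]
scale = np.abs(w).max() / 7.0
q = np.clip(np.round(w / scale), -8, 7)
w_hat = (q * scale).astype(np.float32)  # best possible reconstruction

max_err = np.abs(w - w_hat).max()  # bounded by scale / 2, and lost for good
```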
**E4B only.** Google has not released LiteRT models for Gemma 4 26B MoE or 31B dense. The extraction pipeline generalizes with config changes if they do.

**Standalone Python loop is slower than autoregressive.** The per-cycle overhead of 2+ forward passes makes Python-level speculative decoding net-negative at 35% acceptance. Integration into vLLM/SGLang (which amortize verification into the base forward pass) is needed for real speedup.

**Multi-step degradation.** Steps 1+ use the drafter's own projected activations with stale KV caches, degrading rapidly. The drafter is most effective as a 1-step predictor.
## Roadmap
- **v2: EAGLE3 draft head.** Train a BF16 draft head directly from E4B activations on DGX Spark, bypassing the quantization-noise ceiling. The architecture mapping from this extraction serves as the blueprint.
- **vLLM integration.** PR to add Gemma 4 MTP as a speculative decoding draft model type.
- **26B/31B MTP.** All four Gemma 4 variants were trained with MTP, but the 26B and 31B heads are currently inaccessible: no LiteRT release exists for them (LiteRT targets on-device deployment), and the HuggingFace weights have them stripped. Until Google releases them in some format, training an EAGLE3 draft head from full-precision activations is the realistic path to MTP-style speculative decoding on the larger models.
## Cross-Validation
mirifiuto independently extracted the same drafter and confirmed a working forward pass. Their architecture findings converge with ours, with one correction: layer 3 has 4 heads at head_dim=512 (not 8 heads at head_dim=256 as initially reported).
## Acknowledgments
- shadowlilac for initiating the community extraction effort
- mirifiuto for independent cross-validation
- Architecture reverse-engineering assisted by ChatGPT Pro 5.4; extraction pipeline, validation, and benchmarking by Claude Opus 4.6
## License
Extracted weights are derived from Google's Gemma 4 E4B model; usage is subject to the Gemma Terms of Use. Extraction code and the PyTorch module are Apache 2.0.