Full MTP drafter extraction with working weights and speculative decoding implementation

#3
by SeatownSin - opened

We've completed a full extraction of the Gemma 4 E4B MTP drafter weights from the LiteRT binary and published them as standard PyTorch safetensors:

SeatownSin/gemma-4-E4B-mtp-drafter

What's included

  • mtp_drafter.safetensors: all 42 weight tensors dequantized from INT4/INT8, 78M parameters
  • Clean PyTorch nn.Module reconstruction of the drafter architecture
  • Working speculative decoding loop with the base E4B model
  • Full extraction pipeline scripts so anyone can reproduce from the .litertlm file
  • Detailed architecture documentation and benchmark results
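For readers who haven't implemented a draft-and-verify loop before, here is a minimal greedy sketch of the idea behind the speculative decoding item above. The `draft_model`/`base_model` callables and the function name are placeholders of mine, not the repo's actual API; any model that maps a token sequence to per-position logits fits.

```python
import torch

def speculative_decode_step(base_model, draft_model, tokens, k=4):
    """One greedy draft-and-verify step.

    base_model / draft_model: callables mapping a 1-D LongTensor of
    token ids to logits of shape (seq_len, vocab). Placeholder API.
    Returns the accepted continuation (always >= 1 token).
    """
    # 1. Draft k tokens autoregressively with the cheap drafter.
    drafted, ctx = [], tokens.clone()
    for _ in range(k):
        nxt = draft_model(ctx)[-1].argmax()
        drafted.append(nxt)
        ctx = torch.cat([ctx, nxt.view(1)])

    # 2. Verify all k drafts with a single base-model forward pass.
    logits = base_model(ctx)  # (len(tokens) + k, vocab)
    accepted = []
    for i, d in enumerate(drafted):
        target = logits[len(tokens) - 1 + i].argmax()
        if target == d:
            accepted.append(d)       # draft matches the base model's pick
        else:
            accepted.append(target)  # take the correction, stop here
            break
    else:
        # All k drafts accepted: the verify pass yields one bonus token.
        accepted.append(logits[-1].argmax())
    return torch.stack(accepted)
```

With a perfect drafter every step emits k + 1 tokens for one base-model call; with a useless drafter it still emits one correct token, which is why acceptance rate maps directly to speedup.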

Key findings that may help others working on this

Architecture: The drafter is a 4-layer transformer (256 hidden dim) with Q-only attention over the base model's shared KV banks from layers 22 and 23. It has no K/V projections of its own; it borrows them from the base model's KV-sharing mechanism.
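In module form, Q-only attention means the drafter block holds only a query (and output) projection and reads K/V straight out of the base model's cache. A minimal sketch under the dimensions quoted above; the class name, the cache-passing interface, and the omitted QK-norm are my simplifications, not the repo's actual code:

```python
import torch
import torch.nn as nn

class DrafterQOnlyAttention(nn.Module):
    """Q-only attention: no k_proj/v_proj; K and V are borrowed from
    the base model's shared KV banks (layers 22/23 per the notes)."""

    def __init__(self, hidden=256, num_heads=4, head_dim=256):
        super().__init__()
        self.num_heads, self.head_dim = num_heads, head_dim
        self.q_proj = nn.Linear(hidden, num_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(num_heads * head_dim, hidden, bias=False)

    def forward(self, x, base_k, base_v):
        # x: (seq, hidden); base_k/base_v: (kv_len, num_heads, head_dim),
        # taken from the frozen base model's cache, never recomputed here.
        seq = x.shape[0]
        q = self.q_proj(x).view(seq, self.num_heads, self.head_dim)
        # (heads, seq, kv_len) scores; the real model applies QK-norm
        # with attention_scale=1.0 here (omitted for brevity).
        attn = torch.einsum("shd,khd->hsk", q, base_k).softmax(dim=-1)
        out = torch.einsum("hsk,khd->shd", attn, base_v)
        return self.o_proj(out.reshape(seq, -1))
```

The design choice is what makes a 78M-parameter drafter viable: the expensive K/V state is computed once by the base model and reused for free.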

Heterogeneous head dimensions: Blocks 0-2 use head_dim=256, but block 3 uses head_dim=512. This can be misread as "8 heads" if you assume a uniform head_dim; it is actually 4 heads with head_dim=512, confirmed by the reshape ops in the LiteRT graph.
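The ambiguity is pure arithmetic: a 2048-wide projection factors as 8×256 or 4×512, and both reshapes are legal, so the tensor shapes alone cannot tell you the head count. A quick illustration (the 2048 width is inferred from the head counts above):

```python
import torch

q_flat = torch.randn(10, 2048)   # block 3's flattened query output, seq=10

wrong = q_flat.view(10, 8, 256)  # plausible if you assume uniform head_dim=256
right = q_flat.view(10, 4, 512)  # what the LiteRT reshape ops actually encode

# Both views succeed; only the graph's reshape ops (and the fact that
# per-head math breaks downstream) disambiguate the true head count.
assert wrong.shape == (10, 8, 256) and right.shape == (10, 4, 512)
```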

Attention scaling: Gemma 4 uses QK-norm with attention_scale=1.0. There is no 1/sqrt(d) scaling anywhere in the drafter. We tested adding it explicitly and it made accuracy significantly worse (top-5 overlap dropped from 80% to 20%). The fixed-scale query norms (0.9916 local, 1.0228 global) serve as the scaling mechanism.
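To make the scaling claim concrete: scores are plain q·kᵀ after RMS-normalizing both sides, with the fixed query-norm constants doing all the scaling work. A hedged sketch; `rms_norm` here is a generic implementation and the local/global switch is my framing of the two constants quoted above:

```python
import torch

def rms_norm(x, eps=1e-6):
    # Generic RMSNorm without a learned gain, for illustration only.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

def qk_norm_scores(q, k, is_local=True):
    """Attention logits with QK-norm and attention_scale=1.0.

    Note: there is deliberately no 1/sqrt(head_dim) factor here --
    adding one collapsed top-5 overlap from ~80% to ~20% in our tests.
    The fixed query norms (0.9916 local, 1.0228 global) are the only
    scaling applied.
    """
    qn = rms_norm(q) * (0.9916 if is_local else 1.0228)
    kn = rms_norm(k)
    return qn @ kn.transpose(-2, -1)  # raw dot products, scale = 1.0
```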

BF16 is mandatory: The unscaled dot products in the attention are precision-sensitive. FP16 is reported to cause output degeneration after ~50 tokens.
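The sensitivity is consistent with the two formats' exponent ranges: BF16 keeps FP32's 8-bit exponent (max ≈ 3.4e38), while FP16 tops out at 65504, so unscaled dot products over 256- or 512-dim heads can leave FP16's range entirely. You can see the gap directly:

```python
import torch

# BF16 trades mantissa bits for FP32's full exponent range;
# FP16 keeps more mantissa but overflows past 65504.
print(torch.finfo(torch.float16).max)   # 65504.0
print(torch.finfo(torch.bfloat16).max)  # ~3.39e38

# A plausible unscaled 512-dim dot product of activations around 12.0:
dot = torch.tensor(512 * 12.0 * 12.0)   # 73728.0
assert torch.isinf(dot.half())          # overflows FP16
assert torch.isfinite(dot.bfloat16())   # comfortably within BF16 range
```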

Dequantization bug to watch for: The RMSNorm weight tensors in the TFLite FlatBuffer are stored as F32 but have spurious quantization metadata attached. Naive dequantization pipelines will multiply them by near-zero scale factors, producing all-zero norms that kill the signal. The fix is to skip the quantization path entirely for tensors with dtype=FLOAT32.
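The guard is simple: check the stored dtype before applying scale/zero-point, because FlatBuffer tensors can carry stale quantization params. A sketch in numpy; the function name is mine, and the affine formula is the standard TFLite-style dequantization, not code lifted from the repo:

```python
import numpy as np

def maybe_dequantize(data: np.ndarray, scale: float, zero_point: int):
    """Dequantize a TFLite tensor -- but only if it is actually quantized.

    F32 tensors (e.g. the drafter's RMSNorm weights) can still carry
    quantization metadata in the FlatBuffer; blindly applying it
    multiplies them by a near-zero scale and zeroes out the norms.
    """
    if data.dtype == np.float32:  # already float: skip the quant path
        return data
    return (data.astype(np.float32) - zero_point) * scale

# RMSNorm weight stored as F32 with bogus metadata attached:
w = np.ones(4, dtype=np.float32)
assert np.allclose(maybe_dequantize(w, scale=1e-12, zero_point=0), w)

# A genuine INT8 tensor still goes through affine dequantization:
q = np.array([-128, 0, 127], dtype=np.int8)
assert np.allclose(maybe_dequantize(q, scale=0.5, zero_point=0),
                   [-64.0, 0.0, 63.5])
```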

Benchmark results (DGX Spark GB10, BF16)

| Metric | Value |
| --- | --- |
| Step-0 top-1 acceptance | 35% (greedy, 8 prompts) |
| Step-0 top-5 overlap | ~80% |
| INT4 nibble order | low_first (confirmed) |

The 35% top-1 acceptance appears to be the ceiling for dequantized INT4/INT8 mobile weights; the quantization noise is irreversible. The 80% top-5 overlap is where the real value lies: tree-based speculative decoding methods, or vLLM/SGLang integration (which amortizes verification overhead at the kernel level), can exploit it effectively.
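For anyone reproducing the numbers: top-5 overlap here means the fraction of the drafter's top-5 token ids that also appear in the base model's top-5 at the same position, averaged over positions. A hedged sketch of that metric (my formulation, which may differ in detail from the repo's benchmark script):

```python
import torch

def topk_overlap(draft_logits, base_logits, k=5):
    """Mean fraction of the drafter's top-k token ids that also appear
    in the base model's top-k, averaged over positions.
    Both inputs: (seq, vocab)."""
    d = draft_logits.topk(k, dim=-1).indices  # (seq, k)
    b = base_logits.topk(k, dim=-1).indices   # (seq, k)
    # For each drafted candidate, check membership in the base top-k.
    hit = (d.unsqueeze(-1) == b.unsqueeze(-2)).any(-1)  # (seq, k) bool
    return hit.float().mean(-1).mean().item()
```

Unlike top-1 acceptance, this metric is insensitive to rank order inside the top-k, which is exactly the property tree-based verification exploits.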

Acknowledgments

This builds directly on the effort you initiated, and on @mirifiuto's independent cross-validation. The extraction pipeline, on-device validation, and benchmarking were done with assistance from Claude Opus 4.6 on a DGX Spark. Finally, Gemini 3.1 Pro served as an evaluator model to sanity-check the architectural quirks, validate the BF16 requirements, and refine the documentation.

Hope this helps move things forward, and who knows... Maybe we'll get unquantized weights someday.

I just stumbled across this GitHub project, which has been working with the E2B/E4B models and successfully enabled NPU support (MediaTek & Qualcomm). It might be worthwhile to contact them and see if they can assist or collaborate.
https://github.com/finnff/LLMChatApp
