Gemma 4 E4B MTP Drafter: Extracted from LiteRT

First public extraction of Google Gemma 4's Multi-Token Prediction (MTP) drafter weights from LiteRT format into standard PyTorch safetensors.

Google trained all Gemma 4 variants with MTP heads for built-in speculative decoding, then deliberately stripped them from the public HuggingFace release, retaining them exclusively in LiteRT on-device format. This repository provides the extracted drafter as a standard PyTorch module with a working speculative decoding implementation.

Key Results

Metric                                Value
Step-0 top-1 acceptance (greedy)      35%
Step-0 top-5 overlap with base model  ~80%
Architecture reconstruction           Complete: 42/42 tensors, all shapes verified
Output correctness                    Yes: coherent, accurate text on all test prompts
INT4 nibble order                     low_first (confirmed)
Attention scaling                     None (QK-norm with scale=1.0, confirmed correct)

The 35% top-1 acceptance from dequantized INT4/INT8 mobile weights is the expected ceiling; quantization noise is irreversible. The ~80% top-5 overlap makes these weights immediately useful for:

  • Tree-based speculative decoding (top-k draft candidates verified in parallel)
  • vLLM/SGLang native spec decode (zero-overhead verification eliminates the Python loop bottleneck)
  • Architecture reference for training EAGLE3 draft heads from full BF16 activations
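The top-1 acceptance and top-5 overlap figures above can be measured with a few lines of PyTorch. A minimal sketch (the `topk_overlap` helper is illustrative, not part of the released package): the fraction of positions where the base model's greedy token lands in the drafter's top-k candidates.

```python
import torch

def topk_overlap(draft_logits: torch.Tensor, base_logits: torch.Tensor, k: int = 5) -> float:
    """Fraction of positions where the base model's greedy token appears
    in the drafter's top-k candidates. k=1 gives top-1 acceptance."""
    base_top1 = base_logits.argmax(dim=-1)              # [positions]
    draft_topk = draft_logits.topk(k, dim=-1).indices   # [positions, k]
    hits = (draft_topk == base_top1.unsqueeze(-1)).any(dim=-1)
    return hits.float().mean().item()

# Toy demo with correlated random logits; real use compares drafter vs. base outputs.
torch.manual_seed(0)
draft = torch.randn(128, 1000)
base = draft + 0.5 * torch.randn(128, 1000)
print(f"top-1: {topk_overlap(draft, base, k=1):.0%}  top-5: {topk_overlap(draft, base, k=5):.0%}")
```

Top-5 overlap is always at least top-1 acceptance, which is why tree-based verification can recover value that greedy single-token drafting leaves on the table.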

Architecture

The MTP drafter is a lightweight 4-layer transformer (78M parameters) designed as a one-step recurrent cell:

Input: concat(token_embedding, projected_activation) [5120]
  → Linear pre-projection [5120 → 256]
  → 4 transformer blocks (256 hidden, 2048 intermediate, GeGLU MLP)
      Blocks 0-2: sliding attention (window=512), head_dim=256, 4Q/2KV heads
      Block 3:    full attention,                 head_dim=512, 4Q/2KV heads
  → RMSNorm → parallel heads:
      LM head:     [256 → 262144] with tanh soft-cap at 30.0
      Post-proj:   [256 → 2560]  (recurrent: fed back as next step's activation)
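The diagram above can be written as a shape-level sketch of the recurrent cell. This is illustrative only: module names are hypothetical (not the names used in gemma4_litert_mtp), the four transformer blocks are stubbed with an identity, and RMSNorm is written out by hand.

```python
import torch
import torch.nn as nn

class MtpDrafterCellSketch(nn.Module):
    """Shape-level sketch of the one-step recurrent drafter cell."""
    def __init__(self, hidden=256, model_dim=2560, vocab=262144, soft_cap=30.0):
        super().__init__()
        self.pre_proj = nn.Linear(2 * model_dim, hidden)   # [5120 -> 256]
        self.blocks = nn.Identity()                        # stands in for the 4 transformer blocks
        self.norm_weight = nn.Parameter(torch.ones(hidden))
        self.lm_head = nn.Linear(hidden, vocab)            # [256 -> 262144]
        self.post_proj = nn.Linear(hidden, model_dim)      # [256 -> 2560]
        self.soft_cap = soft_cap

    def forward(self, token_emb, activation):
        # concat(token_embedding, projected_activation) -> [*, 5120]
        h = self.pre_proj(torch.cat([token_emb, activation], dim=-1))
        h = self.blocks(h)
        # RMSNorm
        h = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + 1e-6) * self.norm_weight
        logits = self.lm_head(h)
        logits = self.soft_cap * torch.tanh(logits / self.soft_cap)  # tanh soft-cap at 30.0
        # post_proj output is fed back as the next step's activation
        return logits, self.post_proj(h)
```

The post-projection output is what makes the cell recurrent: at draft step t+1 it replaces the base model's projected activation in the input concatenation.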

Critical architectural details:

  • Q-only attention: the drafter owns no K/V projections. It attends over the base model's shared KV banks from layers 22 (sliding) and 23 (full attention).
  • No 1/sqrt(d) attention scaling: Gemma 4 uses QK-norm with attention_scale=1.0. Adding traditional scaling degrades accuracy (top-5 drops from 80% to 20%). The fixed-scale query norms (local=0.9916, global=1.0228) serve as the scaling mechanism.
  • Heterogeneous head dimensions: blocks 0-2 use head_dim=256, block 3 uses head_dim=512. This is often misidentified as "8 heads" for block 3; it is actually 4 heads at 512-dim, as confirmed by the reshape ops in the LiteRT graph.
  • BF16 precision required: Gemma 4's unscaled attention is precision-sensitive. FP16's narrower dynamic range is reported to cause output degeneration after ~50 tokens due to overflow in the unscaled dot products. Always serve in BF16 with F32 intermediate accumulation.
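The no-scaling detail is easy to get wrong, so here is a minimal sketch of the score computation under the assumptions above: per-head RMS-normalized Q and K, a fixed query scale instead of 1/sqrt(d), and float32 accumulation. `qknorm_attention` is a hypothetical helper that ignores GQA head-sharing, masking, and RoPE.

```python
import torch

def qknorm_attention(q, k, v, q_scale=1.0, eps=1e-6):
    """Unscaled attention: RMS-normalize Q and K per head, then dot with
    scale=1.0 rather than 1/sqrt(head_dim). K/V come from the base model's
    shared KV banks. Accumulates in float32."""
    def rms(x):
        x = x.float()
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    scores = (rms(q) * q_scale) @ rms(k).transpose(-1, -2)  # no 1/sqrt(d) factor
    return torch.softmax(scores, dim=-1) @ v.float()

# e.g. block 3: 4 query heads at head_dim=512, attending over base-model K/V
q = torch.randn(4, 8, 512)
k = torch.randn(4, 16, 512)
v = torch.randn(4, 16, 512)
out = qknorm_attention(q, k, v, q_scale=1.0228)  # fixed global query-norm scale
print(out.shape)  # torch.Size([4, 8, 512])
```

Because the normalized dot products are bounded by the head dimension rather than by 1/sqrt(d), the dynamic range of the scores is what makes BF16 (wider exponent) safer than FP16 here.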

Files

File                      Description
mtp_drafter.safetensors   Extracted weights (297 MB, fp32, 42 tensors, 78M params)
mtp_drafter_config.json   Architecture configuration
gemma4_litert_mtp/        Python package: PyTorch module, HF patches, speculative generation

Quick Start

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from safetensors.torch import load_file
from gemma4_litert_mtp import Gemma4LiteRtMtpDrafter, LiteRtMtpDrafterConfig
from gemma4_litert_mtp.speculative_generate import Gemma4MtpSpeculativeGenerator

# Load base model (BF16 required; FP16 will degenerate)
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-4-E4B-it", torch_dtype=torch.bfloat16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-4-E4B-it")

# Load extracted MTP drafter
config = LiteRtMtpDrafterConfig.from_json_file("mtp_drafter_config.json")
drafter = Gemma4LiteRtMtpDrafter(config)
drafter.load_state_dict(load_file("mtp_drafter.safetensors"), strict=True)
drafter = drafter.to(dtype=torch.bfloat16, device=model.device).eval()

# Generate with speculative decoding
generator = Gemma4MtpSpeculativeGenerator(model, drafter, tokenizer)
result = generator.generate("What is the capital of France?")
print(result.text)  # "The capital of France is **Paris**."
print(f"Acceptance: {result.acceptance_rate:.0%}")

Important: Gemma 4 E4B is instruction-tuned. Always use the chat template (the generator handles this automatically). Raw prompts without the template produce degenerate repetitive output.

Benchmark Results (DGX Spark GB10, BF16)

Acceptance Rate by Draft Steps (Greedy)

Draft Steps  Step-0 Acceptance  Overall Acceptance  Notes
1            35%                35%                 Best standalone config
2            35%                22%                 Steps 1+ degrade (stale KV)
3            35%                14%                 Further degradation

Why Standalone Speedup Is Limited

In a Python speculative decoding loop, each cycle requires at least two base model forward passes (verify + state update). At 35% acceptance, the expected number of tokens per cycle is 1.35, yielding ~0.7x, i.e. slower than plain autoregressive decoding. This is not a flaw in the extraction; it is an inherent limitation of Python-level speculative decoding loops.
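The arithmetic above, as a sketch (`spec_decode_speedup` is a hypothetical helper that counts only base-model forward passes):

```python
def spec_decode_speedup(acceptance: float, draft_steps: int = 1,
                        base_passes_per_cycle: int = 2) -> float:
    """Back-of-envelope speedup of a Python-level speculative loop versus
    plain autoregressive decoding. Expected tokens per cycle = 1 bonus
    token from verification + the accepted draft tokens."""
    tokens_per_cycle = 1 + acceptance * draft_steps
    return tokens_per_cycle / base_passes_per_cycle

print(f"{spec_decode_speedup(0.35):.2f}x")  # 1.35 tokens / 2 passes = 0.68x
```

For the loop to break even at two passes per cycle, the expected tokens per cycle must exceed 2, i.e. acceptance above 100% at one draft step, which is why fused single-pass verification is required for any real gain.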

Production serving engines (vLLM, SGLang) eliminate this overhead via fused kernels that score draft tokens within the base model's forward pass. The 35% acceptance rate would deliver meaningful speedup there. The 80% top-5 overlap is particularly valuable for tree-based speculative methods that verify multiple candidates per position.

Extraction Process

  1. Downloaded gemma-4-E4B-it.litertlm (3.5 GB) from litert-community/gemma-4-E4B-it-litert-lm
  2. Scanned the FlatBuffer for TFLite models and found 10 shards; the MTP drafter is section 9 (43 MB)
  3. Parsed 191 tensors from the TFLite FlatBuffer using the tflite Python package
  4. Mapped 42 weight tensors using a pre-built mapping derived from the LiteRT graph JSON via Google Model Explorer
  5. Dequantized INT8 (per-channel) and INT4 (per-channel, packed nibbles, low_first) to float32
  6. Bug found and fixed: F32 tensors (RMSNorm weights) had spurious quantization metadata in the FlatBuffer; naive dequantization zeroed them out. Fix: skip the quantization path for dtype=FLOAT32 regardless of metadata.
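Steps 5-6 can be sketched as follows. This is a simplified dequantizer under stated assumptions (per-channel scales on axis 0, unsigned byte packing, symmetric quantization without zero-points); the function and argument names are illustrative, not the pipeline's actual API.

```python
import numpy as np

def dequantize_tensor(raw: np.ndarray, scales: np.ndarray, dtype: str,
                      bits: int = 8) -> np.ndarray:
    """Per-channel dequantization with low_first INT4 nibble unpacking,
    plus the F32 fix: float tensors bypass the quantization path entirely,
    even if spurious quantization metadata is present."""
    if dtype == "FLOAT32":                       # the bug fix
        return raw.astype(np.float32)
    if bits == 4:
        lo = (raw & 0x0F).astype(np.int8)        # low nibble = earlier element
        hi = ((raw >> 4) & 0x0F).astype(np.int8)
        vals = np.stack([lo, hi], axis=-1).reshape(raw.shape[0], -1)
        vals = np.where(vals > 7, vals - 16, vals)   # sign-extend 4-bit values
    else:
        vals = raw.astype(np.int32)
    return vals.astype(np.float32) * scales[:, None]   # per-channel scale (axis 0)
```

With low_first packing, the low nibble of each byte decodes to the earlier element, which is the order confirmed for these weights.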

Architecture reverse-engineered from three sources: the .tflite binary (weights/shapes), graph JSON (operator topology), and the LiteRT-LM C++ runner (external interface and loop semantics).

Reproduction

# Extract MTP shard from .litertlm
python extract_mtp_from_litertlm.py --input gemma-4-E4B-it.litertlm --output-dir ./extracted

# Convert to safetensors using pre-built mapping
python convert_with_prebuilt_mapping.py --tflite ./extracted/section9_mtp_drafter.tflite --output-dir ./drafter

# Validate
python validate_drafter.py ./drafter

Known Limitations

  1. Quantization ceiling. The LiteRT drafter was INT4/INT8 quantized for mobile deployment. Dequantization cannot recover the original BF16 training precision. Step-0 acceptance is 35% (top-1) vs the theoretical ~80%+ achievable with full-precision weights.

  2. E4B only. Google has not released LiteRT models for Gemma 4 26B MoE or 31B dense. The extraction pipeline generalizes with config changes if they do.

  3. Standalone Python loop is slower than autoregressive. The per-cycle overhead of 2+ forward passes makes Python-level speculative decoding net-negative at 35% acceptance. Integration into vLLM/SGLang (which amortize verification into the base forward pass) is needed for real speedup.

  4. Multi-step degradation. Steps 1+ use the drafter's own projected activations with stale KV caches, degrading rapidly. The drafter is most effective as a 1-step predictor.

Roadmap

  • v2, EAGLE3 draft head: train a BF16 draft head directly from E4B activations on DGX Spark, bypassing the quantization-noise ceiling. The architecture mapping from this extraction serves as the blueprint.
  • vLLM integration: PR to add Gemma 4 MTP as a speculative decoding draft model type.
  • 26B/31B MTP: all four Gemma 4 variants were trained with MTP, but the 26B and 31B heads are currently inaccessible. No LiteRT release exists for them (LiteRT targets on-device deployment), and the HuggingFace weights have them stripped. Until Google releases them in some format, training an EAGLE3 draft head from full-precision activations is the realistic path to MTP-style speculative decoding on the larger models.

Cross-Validation

mirifiuto independently extracted the same drafter and confirmed a working forward pass. Their architecture findings converge with ours, with one correction: layer 3 has 4 heads at head_dim=512 (not 8 heads at head_dim=256 as initially reported).

Acknowledgments

  • shadowlilac for initiating the community extraction effort
  • mirifiuto for independent cross-validation
  • Architecture reverse-engineering assisted by ChatGPT Pro 5.4; extraction pipeline, validation, and benchmarking by Claude Opus 4.6

License

Extracted weights are derived from Google's Gemma 4 E4B model; usage is subject to the Gemma Terms of Use. Extraction code and the PyTorch module are Apache 2.0.
