Qwen3-8B All-Sparse Indexer
Experimental research artifact: a Dynamic Sparse Attention (DSA) indexer trained at 2K context length. This repository is intended as an exploratory learned sparse-attention index, not a finished production method. The inference code is written in MLX.
A lightweight sparse-attention indexer trained to approximate dense attention behavior in Qwen/Qwen3-8B. Conceptually, this is a DeepSeek-style learned index: a small auxiliary network predicts which key-value positions are worth keeping for attention. It is an independent research artifact and is not affiliated with DeepSeek. Early results suggest the approach can work in some settings, but more research is needed.
Runs natively on Apple Silicon (M1/M2/M3/M4) via MLX. No CUDA or bitsandbytes required.
Model Description
This repository contains a Dynamic Sparse Attention (DSA) indexer checkpoint and the MLX runtime code needed to patch it into a standard Qwen3-8B model. The indexer is a small auxiliary network (~72 M parameters total, ~2 M per layer) that runs alongside the frozen base model. For every layer and every query, it predicts a top-k set of key positions that may be useful for attention, allowing exploratory sparse prefill and fixed-budget decode experiments.
Key properties:
- Base model: `mlx-community/Qwen3-8B-4bit` (4-bit quantized, Apple Silicon)
- Framework: MLX (runs on Apple M-series chips)
- Indexer coverage: All 36 transformer layers (full-model sparse attention)
- Sparse budget: top_k = 2048 positions per query per layer
- Fixed-size decode cache: the KV cache stays at exactly 2048 entries per layer for the entire generation, even for multi-thousand-token outputs
- Current evidence: promising but preliminary; some quality and retrieval checks work, others degrade or fail depending on context length and top_k
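Conceptually, the indexer's selection step reduces to scoring every cached position against the current query and keeping the `top_k` highest-scoring ones. A minimal NumPy sketch of that step follows; the function name and plain dot-product scoring are illustrative only, since the real indexer uses small learned per-layer projections with multiple heads:

```python
import numpy as np

def indexer_topk(query_feat, key_feats, top_k):
    """Score every cached position against the current query and return the
    indices of the top_k highest-scoring ones (hypothetical sketch; the real
    indexer scores with small learned projections, not raw features)."""
    scores = key_feats @ query_feat            # (seq_len,) relevance scores
    if top_k >= len(scores):
        return np.arange(len(scores))
    # argpartition finds the top_k scores without a full sort
    idx = np.argpartition(scores, -top_k)[-top_k:]
    return np.sort(idx)                        # keep positions in causal order

rng = np.random.default_rng(0)
keys = rng.normal(size=(4096, 64))            # proxy for per-position features
q = rng.normal(size=(64,))
kept = indexer_topk(q, keys, 2048)
print(len(kept))                              # 2048 positions survive
```

Attention is then computed only over the `kept` positions, which is what makes sparse prefill and the fixed-budget decode cache possible.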
Benchmarks
Needle-in-a-Haystack (Retrieval)
A short phrase ("The secret code is 7392") was hidden at the midpoint of a context filled with diverse natural-language prose. The model was then asked to retrieve it. These are exploratory single-task measurements, not a comprehensive long-context benchmark.
| Context | top_k | Pass | Retention | Prefill Time | Peak Memory |
|---|---|---|---|---|---|
| 2K | 2048 | ✅ | 100% | 9s | 5.4 GB |
| 4K | 2048 | ✅ | 55% | 22s | 6.1 GB |
| 8K | 2048 | ✅ | 27% | 55s | 7.5 GB |
| 16K | 2048 | ❌ | 14% | 131s | 9.2 GB |
| 16K | 4096 | ✅ | 28% | 180s | 9.2 GB |
| 32K | 2048 | ❌ | 7% | 306s | 13.2 GB |
Interpretation: in these runs, retrieval became much less reliable when `top_k` was too small relative to context length. A larger `top_k` may help on retrieval-style tasks, but this should be treated as an experimental observation rather than a settled rule.
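A minimal harness along these lines can reproduce the setup described above. This is an illustrative sketch: the filler sentences and question wording are placeholders, not the exact text used for the table.

```python
def build_needle_prompt(filler_sentences, needle, n_sentences):
    """Hide a needle at the midpoint of a haystack of filler prose
    (illustrative harness; the measurements above used diverse
    natural-language prose as filler)."""
    haystack = [filler_sentences[i % len(filler_sentences)]
                for i in range(n_sentences)]
    haystack.insert(n_sentences // 2, needle)   # needle at the midpoint
    context = " ".join(haystack)
    question = "What is the secret code mentioned in the text above?"
    return f"{context}\n\n{question}"

filler = ["The weather shifted as the caravan crossed the valley.",
          "Economists debated the quarterly figures for hours."]
prompt = build_needle_prompt(filler, "The secret code is 7392.", 200)
print("7392" in prompt)  # the needle is embedded mid-context
```

Scaling `n_sentences` (or swapping in longer filler) sweeps the context length while keeping the retrieval target fixed.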
Memory (KV Cache)
| Context | Dense KV | Sparse KV (top_k=2048) | Savings |
|---|---|---|---|
| 4K | 604 MB | 302 MB | 50% |
| 8K | 1,208 MB | 302 MB | 75% |
| 16K | 2,416 MB | 302 MB | 87.5% |
| 32K | 4,832 MB | 302 MB | 93.7% |
Measured on Qwen3-8B 4-bit, Apple Silicon, MLX.
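The dense column is consistent with a back-of-envelope calculation from Qwen3-8B's attention shape (36 layers, grouped-query attention with 8 KV heads, head dim 128), assuming an fp16 cache; the sparse column is the same formula at a fixed 2048 entries:

```python
# Back-of-envelope KV-cache size check for the table above.
# Assumed shape: 36 layers, 8 KV heads (GQA), head_dim 128,
# 2 bytes (fp16) for each of K and V.
LAYERS, KV_HEADS, HEAD_DIM, BYTES = 36, 8, 128, 2

def kv_mb(num_tokens):
    per_token = LAYERS * 2 * KV_HEADS * HEAD_DIM * BYTES  # K and V
    return num_tokens * per_token / 1e6                   # decimal MB

for ctx in (4096, 8192, 16384, 32768):
    dense = kv_mb(ctx)
    sparse = kv_mb(2048)        # fixed 2048-entry cache per layer
    print(f"{ctx // 1024}K: dense {dense:.0f} MB, sparse {sparse:.0f} MB, "
          f"savings {100 * (1 - sparse / dense):.1f}%")
```

The printed dense sizes (604, 1208, 2416, 4832 MB) and the fixed 302 MB sparse size match the table.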
Quality
| Benchmark | Dense | Sparse (top_k=2048) |
|---|---|---|
| GSM8K accuracy (4-shot) | 95% | 92% |
| PPL on C4 (seq_len=2048) | 13.526 | 13.533 (+0.058%) |
| PPL on C4 (seq_len=8192) | 15.628 | 15.653 (+0.16%) |
Training Details
| Parameter | Value |
|---|---|
| Base model | Qwen/Qwen3-8B |
| Quantization | 4-bit (MLX) |
| Training dataset | allenai/c4 (English split) |
| Training tokens | 15 000 000 |
| Validation tokens | 1 000 000 |
| Sequence length | 2048 |
| Sparse layers | All 36 (layers 0β35) |
| top_k | 2048 |
| Indexer heads | 6 |
| Projection dim | 69 |
| RoPE dim | 64 |
| Parameters per layer | ~2 003 427 |
| Total indexer parameters | ~72 123 372 |
| Loss aggregation | per-layer |
| Support loss weight | 0.1 |
| LR schedule | warmup-cosine (5% warmup, min LR 1e-6) |
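The schedule in the last row can be written out explicitly. This is a minimal sketch; the peak learning rate here is a placeholder, since the table does not state it:

```python
import math

def warmup_cosine_lr(step, total_steps, peak_lr, warmup_frac=0.05, min_lr=1e-6):
    """Linear warmup over the first 5% of steps, then cosine decay to min_lr.
    peak_lr is a placeholder value; the actual peak LR is not stated above."""
    warmup_steps = max(1, int(total_steps * warmup_frac))
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))

total = 1000
lrs = [warmup_cosine_lr(s, total, peak_lr=3e-4) for s in range(total)]
print(max(lrs))   # peak reached at the end of warmup
print(lrs[-1])    # decayed close to min_lr by the final step
```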
Files
| File | Description |
|---|---|
| `lightning_indexer_best_assembled.safetensors` | Sparse indexer checkpoint (safetensors format) |
| `run_config.json` | Training and sparse-layer configuration |
| `eval_sparse_generate.py` | Sparse patching + MLX runtime + GSM8K evaluation |
| `demo.py` | One-command demo: loads model + indexer, runs sample prompts |
| `requirements.txt` | Runtime dependencies (mlx, mlx-lm, safetensors, numpy, datasets) |
| `ppl_results_assembled.json` | Dense vs. sparse perplexity evaluation summary |
Quick Start
```bash
# Requires Python 3.9+ and an Apple Silicon Mac (M1/M2/M3/M4).
# Clone or download this repository, then enter the repo root.
# If cloning from the Hugging Face Hub, use your actual repo URL.
pip install -r requirements.txt
python demo.py
```
The demo automatically downloads `mlx-community/Qwen3-8B-4bit` (~5 GB), loads the sparse indexer, and runs three sample prompts showing dense vs. sparse output side by side.
More options
```bash
# Single custom prompt
python demo.py --prompt "What causes the northern lights?"

# Interactive chat REPL
python demo.py --interactive

# GSM8K accuracy eval: dense baseline
python eval_sparse_generate.py --limit 100

# GSM8K accuracy eval: sparse (fixed-2K decode cache)
python eval_sparse_generate.py --limit 100 \
    --indexer-path lightning_indexer_best_assembled.safetensors \
    --run-config run_config.json

# Override top-k budget
python eval_sparse_generate.py --top-k 1024 \
    --indexer-path lightning_indexer_best_assembled.safetensors \
    --run-config run_config.json --limit 50
```
Programmatic usage
```python
import json
from pathlib import Path

import mlx.core as mx
from mlx_lm.utils import load as mlx_load

from eval_sparse_generate import load_indexers, patch_sparse_generate

# Load base model (auto-downloads from HF on first run)
model, tokenizer = mlx_load(
    "mlx-community/Qwen3-8B-4bit",
    tokenizer_config={"trust_remote_code": True},
)
mx.eval(model.parameters()); mx.synchronize()

# Load indexers
rc = json.loads(Path("run_config.json").read_text())
dim = int(rc.get("hidden_size", rc.get("metadata", {}).get("hidden_size", 4096)))
indexers = load_indexers(
    "lightning_indexer_best_assembled.safetensors",
    dim=dim,
    proj_dim=rc["proj_dim"],
    n_heads=rc["indexer_heads"],
    rope_dim=rc["rope_dim"],
)
mx.eval([idx.parameters() for idx in indexers.values()])

# Patch: attention is now sparse for ALL steps
clear_fn = patch_sparse_generate(model, indexers, top_k=2048)

# Generate as normal
from mlx_lm.generate import generate_step

prompt = "<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n"
input_ids = mx.array(tokenizer.encode(prompt))
tokens = []
for tok, _ in generate_step(input_ids, model, max_tokens=128,
                            sampler=lambda x: mx.argmax(x, axis=-1),
                            prefill_step_size=int(input_ids.shape[0])):
    t = int(tok.item())
    if t == tokenizer.eos_token_id:
        break
    tokens.append(t)
print(tokenizer.decode(tokens))

# Between prompts, clear indexer state
clear_fn()
```
Note: The indexer monkey-patches `block.__call__` and `model.make_cache` at runtime. It does not use `from_pretrained()`. You need `eval_sparse_generate.py` alongside the checkpoint.
Decode Cache Design
After prefill, each layer's KV cache is pruned to exactly top_k = 2048 entries. During decode, every new token is scored by the indexer against all top_k+1 candidates (top_k cached + 1 new), and the lowest-scoring entry is evicted. The cache stays at exactly 2048 entries regardless of how many tokens are generated. This means:
- O(top_k) per decode step (not O(seq_len))
- Constant memory during generation: the KV cache never grows
- Designed to keep decode memory bounded even for much longer generations, though broader validation is still needed
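The eviction step described above can be sketched as follows. The function name and random scores are illustrative; in the actual runtime, scores come from the learned indexer:

```python
import numpy as np

def decode_step_evict(cache_keys, cache_scores, new_key, new_score):
    """Append the new token's entry, then evict the lowest-scoring of the
    top_k + 1 candidates so the cache stays at exactly top_k entries.
    (Hypothetical sketch; real scores come from the learned indexer.)"""
    keys = np.concatenate([cache_keys, new_key[None]])     # top_k + 1 candidates
    scores = np.concatenate([cache_scores, [new_score]])
    victim = int(np.argmin(scores))                        # lowest-scoring entry
    keep = np.ones(len(scores), dtype=bool)
    keep[victim] = False
    return keys[keep], scores[keep]

rng = np.random.default_rng(0)
K, d = 2048, 64
keys, scores = rng.normal(size=(K, d)), rng.normal(size=K)
keys, scores = decode_step_evict(keys, scores, rng.normal(size=d), 0.5)
print(keys.shape)  # cache size is unchanged: (2048, 64)
```

Because each step scores only `top_k + 1` candidates, the per-token cost and memory are independent of how many tokens have been generated.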
Limitations
- Requires Apple Silicon (M1/M2/M3/M4). Runs on an Intel Mac CPU, but slowly.
- Trained on English C4 data only. Other languages and strongly out-of-distribution domains have not been evaluated.
- Fixed `top_k = 2048` across all layers. Per-layer adaptive budgets may improve results further.
- Tested with `mlx-community/Qwen3-8B-4bit`. Other quantization levels have not been validated.
- Training used sequences of length 2048, so very long contexts (> 32K) are extrapolation.
- Quality and retrieval behavior are still preliminary and can vary materially with context length, prompt type, and `top_k`.
- Some exploratory retrieval tests pass at 2K-8K with `top_k=2048` and at 16K with `top_k=4096`, but this should not be read as a general guarantee.
- The current MLX runtime is a research implementation. Some measured regimes still show worse latency than dense baselines even when KV-cache scaling looks better.
Citation
If you reference this artifact directly, cite the published Hugging Face repository URL for this model card, or the associated paper if one exists.
Because this repository is a custom-code MLX runtime artifact rather than a standard `from_pretrained()` Transformers checkpoint, cite the specific published repository rather than a generic model identifier.