Qwen3-8B All-Sparse Indexer

Experimental research artifact: a Dynamic Sparse Attention (DSA) indexer trained at 2K context length. This repository is intended as an exploratory learned sparse-attention index, not a finished production method. The inference code is written in MLX.

A lightweight sparse-attention indexer trained to approximate dense attention behavior in Qwen/Qwen3-8B. Conceptually, this is a DeepSeek-style learned index in the sense that a small auxiliary network predicts which key-value positions are worth keeping for attention. This is an independent research artifact and is not affiliated with DeepSeek. Early results suggest the approach can work in some settings, but more research is needed.

Runs natively on Apple Silicon (M1/M2/M3/M4) via MLX. No CUDA or bitsandbytes required.

Model Description

This repository contains a Dynamic Sparse Attention (DSA) indexer checkpoint and the MLX runtime code needed to patch it into a standard Qwen3-8B model. The indexer is a small auxiliary network (~72 M parameters total, ~2 M per layer) that runs alongside the frozen base model. For every layer and every query, it predicts a top-k set of key positions that may be useful for attention, allowing exploratory sparse prefill and fixed-budget decode experiments.

Key properties:

  • Base model: mlx-community/Qwen3-8B-4bit (4-bit quantized, Apple Silicon)
  • Framework: MLX; runs on Apple M-series chips
  • Indexer coverage: All 36 transformer layers (full-model sparse attention)
  • Sparse budget: top_k = 2048 positions per query per layer
  • Fixed-size decode cache: the KV cache stays at exactly 2048 entries per layer, no matter how many tokens are generated
  • Current evidence: promising but preliminary; some quality and retrieval checks work, others degrade or fail depending on context length and top_k
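
As a rough sketch of what per-layer indexer scoring can look like (shapes, names, and the exact scoring rule here are illustrative, not the checkpoint's actual layout): the indexer projects hidden states into a few small query heads plus a shared key, and selects the top-k key positions by a weighted sum of rectified per-head dot products.

```python
import numpy as np

def indexer_scores(h, wq, wk, head_w, top_k):
    """Illustrative lightning-indexer-style scoring (all shapes hypothetical).

    h:      (seq, dim)          hidden states entering the layer
    wq:     (heads, dim, proj)  per-head query projections
    wk:     (dim, proj)         shared key projection
    head_w: (heads,)            learned per-head mixing weights
    Returns the top_k key indices for the last query position.
    """
    q = np.einsum("sd,hdp->hsp", h, wq)        # per-head queries, (heads, seq, proj)
    k = h @ wk                                 # shared keys, (seq, proj)
    dots = np.einsum("hp,sp->hs", q[:, -1], k) # last query vs. every key
    scores = head_w @ np.maximum(dots, 0.0)    # weighted sum of ReLU'd dot products
    return np.argsort(-scores)[:top_k]         # keep the top_k highest-scoring keys

rng = np.random.default_rng(0)
seq, dim, heads, proj = 16, 32, 6, 8
idx = indexer_scores(rng.standard_normal((seq, dim)),
                     rng.standard_normal((heads, dim, proj)),
                     rng.standard_normal((dim, proj)),
                     np.abs(rng.standard_normal(heads)),
                     top_k=4)
```

The point of the tiny per-head projection is that scoring every position costs far less than a full attention head, so the frozen base model only ever attends over the selected subset.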

Benchmarks

Needle-in-a-Haystack (Retrieval)

A short phrase ("The secret code is 7392") was hidden at the midpoint of a context filled with diverse natural-language prose. The model was then asked to retrieve it. These are exploratory single-task measurements, not a comprehensive long-context benchmark.

| Context | top_k | Pass | Retention | Prefill Time | Peak Memory |
|---------|-------|------|-----------|--------------|-------------|
| 2K      | 2048  | ✅   | 100%      | 9 s          | 5.4 GB      |
| 4K      | 2048  | ✅   | 55%       | 22 s         | 6.1 GB      |
| 8K      | 2048  | ✅   | 27%       | 55 s         | 7.5 GB      |
| 16K     | 2048  | ❌   | 14%       | 131 s        | 9.2 GB      |
| 16K     | 4096  | ✅   | 28%       | 180 s        | 9.2 GB      |
| 32K     | 2048  | ❌   | 7%        | 306 s        | 13.2 GB     |

Interpretation: in these runs, retrieval became much less reliable when top_k was too small relative to context length. A larger top_k may help on retrieval-style tasks, but this should be treated as an experimental observation rather than a settled rule.

Memory (KV Cache)

| Context | Dense KV | Sparse KV (top_k=2048) | Savings |
|---------|----------|------------------------|---------|
| 4K      | 604 MB   | 302 MB                 | 50%     |
| 8K      | 1,208 MB | 302 MB                 | 75%     |
| 16K     | 2,416 MB | 302 MB                 | 87.5%   |
| 32K     | 4,832 MB | 302 MB                 | 93.7%   |

Measured on Qwen3-8B 4-bit, Apple Silicon, MLX.
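
These dense numbers line up with back-of-the-envelope fp16 KV-cache arithmetic for Qwen3-8B (36 layers, 8 KV heads via GQA, head_dim 128 — architecture figures assumed from the public Qwen3-8B config):

```python
def kv_cache_mb(seq_len, layers=36, kv_heads=8, head_dim=128, bytes_per=2):
    """fp16 KV-cache size in decimal MB: keys + values across all layers."""
    return 2 * layers * kv_heads * head_dim * seq_len * bytes_per / 1e6

print(round(kv_cache_mb(4096)))   # ~604 MB: dense cache at 4K context
print(round(kv_cache_mb(2048)))   # ~302 MB: the fixed top_k=2048 sparse cache
```

Because the sparse cache is pinned at 2048 entries, its size is the 2K dense figure regardless of context length, which is why the savings column approaches 100% as context grows.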

Quality

| Benchmark                 | Dense  | Sparse (top_k=2048) |
|---------------------------|--------|---------------------|
| GSM8K accuracy (4-shot)   | 95%    | 92%                 |
| PPL on C4 (seq_len=2048)  | 13.526 | 13.533 (+0.05%)     |
| PPL on C4 (seq_len=8192)  | 15.628 | 15.653 (+0.16%)     |

Training Details

| Parameter                | Value                                   |
|--------------------------|-----------------------------------------|
| Base model               | Qwen/Qwen3-8B                           |
| Quantization             | 4-bit (MLX)                             |
| Training dataset         | allenai/c4 (English split)              |
| Training tokens          | 15,000,000                              |
| Validation tokens        | 1,000,000                               |
| Sequence length          | 2048                                    |
| Sparse layers            | All 36 (layers 0–35)                    |
| top_k                    | 2048                                    |
| Indexer heads            | 6                                       |
| Projection dim           | 69                                      |
| RoPE dim                 | 64                                      |
| Parameters per layer     | ~2,003,427                              |
| Total indexer parameters | ~72,123,372                             |
| Loss aggregation         | per-layer                               |
| Support loss weight      | 0.1                                     |
| LR schedule              | warmup-cosine (5% warmup, min LR 1e-6)  |
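
The warmup-cosine schedule in the table can be sketched as follows (the peak learning rate is illustrative; the table only pins down the 5% warmup fraction and the 1e-6 floor):

```python
import math

def warmup_cosine_lr(step, total_steps, peak_lr, warmup_frac=0.05, min_lr=1e-6):
    """Linear warmup over the first warmup_frac of training, then cosine decay to min_lr."""
    warmup_steps = max(1, int(total_steps * warmup_frac))
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps   # linear ramp up to peak
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))
```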

Files

| File                                          | Description                                                   |
|-----------------------------------------------|---------------------------------------------------------------|
| lightning_indexer_best_assembled.safetensors  | Sparse indexer checkpoint (safetensors format)                |
| run_config.json                               | Training and sparse-layer configuration                       |
| eval_sparse_generate.py                       | Sparse patching + MLX runtime + GSM8K evaluation              |
| demo.py                                       | One-command demo: loads model + indexer, runs sample prompts  |
| requirements.txt                              | Runtime dependencies (mlx, mlx-lm, safetensors, numpy, datasets) |
| ppl_results_assembled.json                    | Dense vs. sparse perplexity evaluation summary                |

Quick Start

```bash
# Requires Python 3.9+ and an Apple Silicon Mac (M1/M2/M3/M4)
# Clone or download this repository, then enter the repo root.
# If cloning from the Hugging Face Hub, use your actual repo URL.
pip install -r requirements.txt
python demo.py
```

The demo automatically downloads mlx-community/Qwen3-8B-4bit (~5 GB), loads the sparse indexer, and runs three sample prompts showing dense vs. sparse output side by side.

More options

```bash
# Single custom prompt
python demo.py --prompt "What causes the northern lights?"

# Interactive chat REPL
python demo.py --interactive

# GSM8K accuracy eval: dense baseline
python eval_sparse_generate.py --limit 100

# GSM8K accuracy eval: sparse (fixed-2K decode cache)
python eval_sparse_generate.py --limit 100 \
    --indexer-path lightning_indexer_best_assembled.safetensors \
    --run-config run_config.json

# Override top-k budget
python eval_sparse_generate.py --top-k 1024 \
    --indexer-path lightning_indexer_best_assembled.safetensors \
    --run-config run_config.json --limit 50
```

Programmatic usage

```python
import json
from pathlib import Path

import mlx.core as mx
from mlx_lm.utils import load as mlx_load

from eval_sparse_generate import load_indexers, patch_sparse_generate

# Load base model (auto-downloads from HF on first run)
model, tokenizer = mlx_load(
    "mlx-community/Qwen3-8B-4bit",
    tokenizer_config={"trust_remote_code": True},
)
mx.eval(model.parameters())
mx.synchronize()

# Load indexers
rc = json.loads(Path("run_config.json").read_text())
dim = int(rc.get("hidden_size", rc.get("metadata", {}).get("hidden_size", 4096)))
indexers = load_indexers(
    "lightning_indexer_best_assembled.safetensors",
    dim=dim,
    proj_dim=rc["proj_dim"],
    n_heads=rc["indexer_heads"],
    rope_dim=rc["rope_dim"],
)
mx.eval([idx.parameters() for idx in indexers.values()])

# Patch: attention is now sparse for ALL steps
clear_fn = patch_sparse_generate(model, indexers, top_k=2048)

# Generate as normal
from mlx_lm.generate import generate_step

prompt = "<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n"
input_ids = mx.array(tokenizer.encode(prompt))
tokens = []
for tok, _ in generate_step(input_ids, model, max_tokens=128,
                            sampler=lambda x: mx.argmax(x, axis=-1),
                            prefill_step_size=int(input_ids.shape[0])):
    t = int(tok.item())
    if t == tokenizer.eos_token_id:
        break
    tokens.append(t)
print(tokenizer.decode(tokens))

# Between prompts, clear indexer state
clear_fn()
```

Note: The indexer monkey-patches block.__call__ and model.make_cache at runtime. It does not use from_pretrained(). You need eval_sparse_generate.py alongside the checkpoint.
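
The patching pattern itself is ordinary Python attribute replacement. A toy sketch of the idea (this is not the real eval_sparse_generate code; `Block`, `patch_call`, and `pre_fn` are invented stand-ins):

```python
class Block:
    def __call__(self, x):
        return x * 2  # stand-in for a transformer block's forward pass

def patch_call(block, pre_fn):
    """Replace block.__call__ at the instance level; return an undo function.

    Note: plain `block(x)` looks dunder methods up on the type, so this
    instance-level patch only takes effect for runtimes that invoke layers
    via `layer.__call__(x)` or hold explicit references to the callable.
    """
    original = type(block).__call__
    def patched(x):
        return original(block, pre_fn(x))  # run the hook, then the real forward
    block.__call__ = patched
    def undo():
        del block.__call__                 # fall back to the class method
    return undo

b = Block()
undo = patch_call(b, lambda x: x + 1)
patched_out = b.__call__(3)   # hook adds 1, then the block doubles: (3 + 1) * 2
undo()
plain_out = b.__call__(3)     # original forward: 3 * 2
```

Returning an undo closure mirrors the `clear_fn` returned by `patch_sparse_generate`, which restores clean state between prompts.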

Decode Cache Design

After prefill, each layer's KV cache is pruned to exactly top_k = 2048 entries. During decode, every new token is scored by the indexer against all top_k+1 candidates (top_k cached + 1 new), and the lowest-scoring entry is evicted. The cache stays at exactly 2048 entries regardless of how many tokens are generated. This means:

  • O(top_k) per decode step (not O(seq_len))
  • Constant memory during generation β€” the KV cache never grows
  • Designed to keep decode memory bounded even for much longer generations, though broader validation is still needed
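
A toy version of this fixed-budget eviction loop (random numbers stand in for indexer scores; the real runtime scores candidates with the learned indexer per layer):

```python
import numpy as np

def decode_step(cache_scores, cache_ids, new_score, new_id, top_k):
    """Admit the new token, then evict the lowest-scoring of the top_k+1 candidates."""
    scores = np.append(cache_scores, new_score)   # top_k cached + 1 new
    ids = np.append(cache_ids, new_id)
    victim = int(np.argmin(scores))               # lowest-scoring candidate leaves
    keep = np.delete(np.arange(top_k + 1), victim)
    return scores[keep], ids[keep]

rng = np.random.default_rng(0)
top_k = 4
scores = rng.random(top_k)                        # pretend the prefill kept top_k entries
ids = np.arange(top_k)
for t in range(top_k, top_k + 100):               # decode 100 tokens
    scores, ids = decode_step(scores, ids, rng.random(), t, top_k)
# the cache is still exactly top_k entries, no matter how long the loop runs
```

Each step touches only top_k + 1 entries, which is where the O(top_k) decode cost and the flat memory curve come from.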

Limitations

  • Requires Apple Silicon (M1/M2/M3/M4); MLX is built for Apple silicon and Intel Macs are not supported.
  • Trained on English C4 data only. Other languages / strongly out-of-distribution domains not evaluated.
  • Fixed top_k = 2048 across all layers. Per-layer adaptive budgets may improve further.
  • Tested with mlx-community/Qwen3-8B-4bit. Other quantization levels not validated.
  • Training used sequences of length 2048. Very long contexts (> 32K) are extrapolated.
  • Quality and retrieval behavior are still preliminary and can vary materially with context length, prompt type, and top_k.
  • Some exploratory retrieval tests pass at 2K-8K with top_k=2048 and at 16K with top_k=4096, but this should not be read as a general guarantee.
  • The current MLX runtime is a research implementation. Some measured regimes still show worse latency than dense baselines even when KV-cache scaling looks better.

Citation

If you reference this artifact directly, cite the published Hugging Face repository (rp440/Qwen3-8b-DSA-index) or your associated paper. Note that this is a custom-code MLX runtime artifact rather than a standard from_pretrained() Transformers checkpoint, so cite the specific published repository.
