# dflash-mlx-universal/examples/qwen3_4b_demo.py
# tritesh's picture
# Upload folder using huggingface_hub
# 0433390 verified
"""
Example: DFlash speculative decoding with Qwen3-4B on MLX.
This demonstrates using a pre-converted DFlash drafter with the Qwen3-4B
model on Apple Silicon.
Prerequisites:
pip install mlx-lm dflash-mlx-universal
# Convert the drafter (one-time)
python -m dflash_mlx.convert \
--model z-lab/Qwen3-4B-DFlash-b16 \
--output ./Qwen3-4B-DFlash-mlx
"""
from mlx_lm import load
from dflash_mlx import DFlashSpeculativeDecoder
from dflash_mlx.convert import load_mlx_dflash
def main():
    """Run the Qwen3-4B DFlash demo and print a rough speed comparison.

    Steps:
      1. Load the MLX target model and tokenizer.
      2. Load the converted DFlash drafter from ``./Qwen3-4B-DFlash-mlx``.
      3. Build a ``DFlashSpeculativeDecoder`` from the two models.
      4. Generate and print one completion with speculative decoding.
      5. Time a plain ``mlx_lm`` generation and a DFlash generation on the
         same prompt and token budget, then print both timings and the
         resulting speedup.
    """
    # Function-scope imports: only the timing comparison in step 5 needs them.
    import time

    from mlx_lm import generate

    print("=" * 60)
    print("DFlash Speculative Decoding Demo - Qwen3-4B")
    print("=" * 60)

    # 1. Load target model (MLX-converted)
    print("\n[1] Loading Qwen3-4B target model...")
    model, tokenizer = load("Qwen/Qwen3-4B-MLX-4bit")
    print(" ✓ Target model loaded")

    # 2. Load converted DFlash drafter
    print("\n[2] Loading DFlash drafter...")
    draft_model, draft_config = load_mlx_dflash("./Qwen3-4B-DFlash-mlx")
    print(f" ✓ Drafter loaded ({draft_config['num_hidden_layers']} layers)")

    # 3. Create decoder
    print("\n[3] Creating DFlash speculative decoder...")
    decoder = DFlashSpeculativeDecoder(
        target_model=model,
        draft_model=draft_model,
        tokenizer=tokenizer,
        block_size=draft_config.get("block_size", 16),
    )

    # 4. Generate
    print("\n[4] Generating with DFlash speculative decoding...")
    prompt = "Write a Python function to implement quicksort."
    print(f"\nPrompt: {prompt}")
    print("-" * 60)
    output = decoder.generate(
        prompt=prompt,
        max_tokens=1024,
        temperature=0.0,
    )
    print(output)
    print("-" * 60)

    # 5. Compare with baseline
    print("\n[5] Running baseline (no speculative decoding)...")

    # Baseline: models returned by mlx_lm.load() have no .generate() method;
    # plain decoding goes through the module-level mlx_lm.generate() helper.
    # No sampler is passed, so decoding is greedy, matching temperature=0.0
    # used for the DFlash run below.
    start = time.perf_counter()
    generate(
        model,
        tokenizer,
        prompt=prompt,
        max_tokens=512,
    )
    baseline_time = time.perf_counter() - start

    # DFlash: same prompt and token budget as the baseline.
    start = time.perf_counter()
    decoder.generate(
        prompt=prompt,
        max_tokens=512,
        temperature=0.0,
    )
    dflash_time = time.perf_counter() - start

    speedup = baseline_time / dflash_time
    print(f"\nBaseline: {baseline_time:.2f}s")
    print(f"DFlash: {dflash_time:.2f}s")
    print(f"Speedup: {speedup:.2f}x")

    print("\n" + "=" * 60)
    print("Demo complete!")
    print("=" * 60)
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()