File size: 2,495 Bytes
0433390
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""
Example: DFlash speculative decoding with Qwen3-4B on MLX.

This demonstrates using a pre-converted DFlash drafter with the Qwen3-4B
model on Apple Silicon.

Prerequisites:
    pip install mlx-lm dflash-mlx-universal
    
    # Convert the drafter (one-time)
    python -m dflash_mlx.convert \
        --model z-lab/Qwen3-4B-DFlash-b16 \
        --output ./Qwen3-4B-DFlash-mlx
"""

import time

from mlx_lm import generate, load

from dflash_mlx import DFlashSpeculativeDecoder
from dflash_mlx.convert import load_mlx_dflash


def main():
    """Run the DFlash speculative-decoding demo end to end.

    Loads the MLX-converted Qwen3-4B target model and the pre-converted
    DFlash drafter, generates once with speculative decoding, then times
    a plain mlx_lm baseline against DFlash on the same prompt and token
    budget and reports the speedup.
    """
    print("=" * 60)
    print("DFlash Speculative Decoding Demo - Qwen3-4B")
    print("=" * 60)

    # 1. Load target model (MLX-converted, 4-bit quantized)
    print("\n[1] Loading Qwen3-4B target model...")
    model, tokenizer = load("Qwen/Qwen3-4B-MLX-4bit")
    print("    ✓ Target model loaded")

    # 2. Load the converted DFlash drafter (produced by dflash_mlx.convert,
    #    see the module docstring for the one-time conversion command)
    print("\n[2] Loading DFlash drafter...")
    draft_model, draft_config = load_mlx_dflash("./Qwen3-4B-DFlash-mlx")
    print(f"    ✓ Drafter loaded ({draft_config['num_hidden_layers']} layers)")

    # 3. Create decoder; block_size falls back to 16 if absent from config
    print("\n[3] Creating DFlash speculative decoder...")
    decoder = DFlashSpeculativeDecoder(
        target_model=model,
        draft_model=draft_model,
        tokenizer=tokenizer,
        block_size=draft_config.get("block_size", 16),
    )

    # 4. Generate with speculative decoding (temperature=0.0 → greedy)
    print("\n[4] Generating with DFlash speculative decoding...")
    prompt = "Write a Python function to implement quicksort."

    print(f"\nPrompt: {prompt}")
    print("-" * 60)

    output = decoder.generate(
        prompt=prompt,
        max_tokens=1024,
        temperature=0.0,
    )

    print(output)
    print("-" * 60)

    # 5. Compare with baseline
    print("\n[5] Running baseline (no speculative decoding)...")

    # Baseline: the raw MLX model object returned by load() has no
    # .generate() method — plain generation goes through the module-level
    # mlx_lm.generate(model, tokenizer, ...) helper instead.
    # NOTE(review): this relies on mlx_lm's default sampling being greedy,
    # matching temperature=0.0 below — confirm for the installed version.
    # perf_counter() is monotonic, so timings survive wall-clock changes.
    start = time.perf_counter()
    baseline_output = generate(
        model,
        tokenizer,
        prompt=prompt,
        max_tokens=512,
    )
    baseline_time = time.perf_counter() - start

    # DFlash run with the same prompt and budget for a fair comparison
    start = time.perf_counter()
    dflash_output = decoder.generate(
        prompt=prompt,
        max_tokens=512,
        temperature=0.0,
    )
    dflash_time = time.perf_counter() - start

    # Guard against a (pathological) zero elapsed time
    speedup = baseline_time / dflash_time if dflash_time else float("inf")
    print(f"\nBaseline: {baseline_time:.2f}s")
    print(f"DFlash:   {dflash_time:.2f}s")
    print(f"Speedup:  {speedup:.2f}x")

    print("\n" + "=" * 60)
    print("Demo complete!")
    print("=" * 60)


if __name__ == "__main__":
    main()