File size: 3,972 Bytes
0433390
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#!/bin/bash
# Bootstrap script: DFlash on an M2 Pro Max (96GB) machine.
# Usage: chmod +x setup_m2.sh && ./setup_m2.sh

# Stop at the first command that fails.
set -e

# Banner plus the heading for step 1, emitted in one go.
cat << 'EOF'
==========================================
 DFlash MLX Setup for M2 Pro Max (96GB)
==========================================

[1/6] Checking system...
EOF

# Step 1: inspect the machine architecture. Anything other than arm64 only
# produces a warning; the script carries on regardless.
ARCH=$(uname -m)
case "$ARCH" in
    arm64) ;;
    *) printf '%s\n' "Warning: Not running on Apple Silicon (arm64). MLX may not work optimally." ;;
esac

printf '  Architecture: %s\n' "$ARCH"
printf '  Python: %s\n' "$(python3 --version)"

# Step 2: build an isolated virtual environment in ./.venv-dflash
printf '\n%s\n' "[2/6] Creating virtual environment..."
python3 -m venv .venv-dflash
printf '%s\n' "  Created .venv-dflash/"

# Step 3: activate that environment so the pip calls below install into it
printf '\n%s\n' "[3/6] Installing dependencies..."
. .venv-dflash/bin/activate

pip install --upgrade pip
# One pip invocation per package, in the same order as before.
for package in mlx-lm dflash-mlx-universal; do
    pip install "$package"
done

printf '%s\n' "  ✓ MLX-LM installed" "  ✓ DFlash-MLX-Universal installed"

# Step 4: make sure both model directories exist under the home directory
printf '\n%s\n' "[4/6] Setting up model directories..."
# The tildes are unquoted here, so the shell expands them to the home directory.
for model_dir in ~/models/dflash ~/models/target; do
    mkdir -p "$model_dir"
done

# Quoted heredoc: the tildes below are printed literally for the user.
cat << 'EOF'
  Created:
    ~/models/dflash/  (for converted DFlash drafters)
    ~/models/target/  (for target models)
EOF

# Download and convert a drafter
#
# Step 5 (selection): choose the drafter/target pair from the first
# command-line argument (default: qwen3-4b) and set:
#   DRAFTER_ID - repo id of the DFlash drafter to convert
#   TARGET_ID  - repo id of the target model
#   OUTPUT     - directory the converted drafter is written to
# Exits with status 1 on an unrecognised choice.
echo ""
echo "[5/6] Downloading and converting DFlash drafter..."
echo "  This will download ~1GB and take 2-5 minutes."
echo ""

MODEL_CHOICE="${1:-qwen3-4b}"

# NOTE: OUTPUT is built from $HOME rather than "~". A tilde inside double
# quotes is not expanded by the shell, so "~/models/..." would name a literal
# "~" directory relative to the current working directory instead of the
# ~/models/dflash directory created in step 4.
case "$MODEL_CHOICE" in
    qwen3-4b|4b|default)
        DRAFTER_ID="z-lab/Qwen3-4B-DFlash-b16"
        TARGET_ID="Qwen/Qwen3-4B-MLX-4bit"
        OUTPUT="$HOME/models/dflash/Qwen3-4B-DFlash-mlx"
        ;;
    qwen3-8b|8b)
        DRAFTER_ID="z-lab/Qwen3-8B-DFlash-b16"
        TARGET_ID="Qwen/Qwen3-8B-MLX-4bit"
        OUTPUT="$HOME/models/dflash/Qwen3-8B-DFlash-mlx"
        ;;
    *)
        # Diagnostics go to stderr so they are not mixed into normal output.
        echo "Unknown model choice: $MODEL_CHOICE" >&2
        echo "Use: qwen3-4b (default) or qwen3-8b" >&2
        exit 1
        ;;
esac

echo "  Drafter: $DRAFTER_ID"
echo "  Target:  $TARGET_ID"
echo "  Output:  $OUTPUT"
echo ""

# Step 5 (conversion): run the dflash_mlx.convert module on the selected
# drafter, writing the result to $OUTPUT.
convert_args=(--model "$DRAFTER_ID" --output "$OUTPUT")
python3 -m dflash_mlx.convert "${convert_args[@]}"

printf '%s\n' "  ✓ DFlash drafter converted to MLX format"

# Quick test
#
# Step 6: smoke-test the installation by loading the target model and the
# converted drafter, then generating a few tokens.
echo ""
echo "[6/6] Running quick test..."

# Write the test program to a private, uniquely named temp file instead of a
# fixed /tmp path, and remove it when the script exits (success or failure).
TEST_SCRIPT=$(mktemp "${TMPDIR:-/tmp}/dflash_test.XXXXXX")
trap 'rm -f "$TEST_SCRIPT"' EXIT

# The heredoc is quoted, so the shell does not expand anything inside it.
# The model locations are handed over as command-line arguments rather than
# patched in with sed: `sed -i ''` only works with BSD sed, and characters
# such as '&', '|' or '\' in a path would corrupt the substitution.
cat > "$TEST_SCRIPT" << 'EOF'
import sys
sys.path.insert(0, '.')
from mlx_lm import load
from dflash_mlx import DFlashSpeculativeDecoder
from dflash_mlx.convert import load_mlx_dflash

# argv[1] is the target model id, argv[2] the converted drafter path.
target_id, drafter_path = sys.argv[1:3]

print("Loading models...")
model, tokenizer = load(target_id)
draft, _ = load_mlx_dflash(drafter_path)

decoder = DFlashSpeculativeDecoder(
    target_model=model,
    draft_model=draft,
    tokenizer=tokenizer,
    block_size=16,
)

print("\nGenerating test output...")
output = decoder.generate(
    prompt="What is 2 + 2? Answer in one word.",
    max_tokens=10,
    temperature=0.0,
)
print(f"Output: {output}")
print("\n✓ DFlash is working correctly!")
EOF

python3 "$TEST_SCRIPT" "$TARGET_ID" "$OUTPUT"

# Final summary: print usage instructions. The heredoc delimiter is
# unquoted, so $TARGET_ID and $OUTPUT are expanded into the text.
cat << EOF

==========================================
 Setup Complete!
==========================================

To use DFlash in your projects:

  source .venv-dflash/bin/activate

  python3 -c "
  from mlx_lm import load
  from dflash_mlx import DFlashSpeculativeDecoder
  from dflash_mlx.convert import load_mlx_dflash

  model, tokenizer = load('$TARGET_ID')
  draft, _ = load_mlx_dflash('$OUTPUT')

  decoder = DFlashSpeculativeDecoder(
      target_model=model,
      draft_model=draft,
      tokenizer=tokenizer,
      block_size=16,
  )

  output = decoder.generate('Your prompt here')
  print(output)
  "

To benchmark:
  python3 benchmark_m2.py --target $TARGET_ID --draft $OUTPUT

For more info, see M2_PRO_MAX_GUIDE.md
==========================================
EOF