#!/bin/bash
#
# DFlash MLX setup script.
#
# Creates the .venv-dflash virtual environment in the current directory,
# installs mlx-lm and dflash-mlx-universal into it, converts a DFlash drafter
# with dflash_mlx.convert and finishes with a short generation test.
#
# Usage: <this script> [qwen3-4b|4b|default|qwen3-8b|8b]
#        With no argument, qwen3-4b is used.

# Abort as soon as any command fails.
set -e

# Opening banner: a title framed by two identical rules.
banner_rule="=========================================="
printf '%s\n' \
  "$banner_rule" \
  " DFlash MLX Setup for M2 Pro Max (96GB)" \
  "$banner_rule"

# Step 1: report the host architecture and the Python version in use.
echo ""
echo "[1/6] Checking system..."
ARCH=$(uname -m)
if [[ "$ARCH" != "arm64" ]]; then
  # Only a warning: the script continues on other architectures.
  echo "Warning: Not running on Apple Silicon (arm64). MLX may not work optimally."
fi

# Every later step invokes python3. Without this check a missing interpreter
# would go unnoticed here (a failure inside the $(...) of the echo below does
# not stop the script) and only surface later as a bare "command not found".
if ! command -v python3 >/dev/null 2>&1; then
  echo "Error: python3 was not found on PATH; it is required for this setup." >&2
  exit 1
fi

echo " Architecture: $ARCH"
echo " Python: $(python3 --version)"

# Step 2: create the .venv-dflash virtual environment in the current directory.
printf '\n%s\n' "[2/6] Creating virtual environment..."
python3 -m venv .venv-dflash
printf '%s\n' " Created .venv-dflash/"

# Step 3: activate the virtual environment, then install the Python packages.
printf '\n%s\n' "[3/6] Installing dependencies..."
. .venv-dflash/bin/activate

# Upgrade pip first, then install each package with its own pip invocation.
pip install --upgrade pip
for package in mlx-lm dflash-mlx-universal; do
  pip install "$package"
done

printf '%s\n' \
  " ✓ MLX-LM installed" \
  " ✓ DFlash-MLX-Universal installed"

# Step 4: create the directories for drafter and target models under ~/models.
printf '\n%s\n' "[4/6] Setting up model directories..."
for subdir in dflash target; do
  mkdir -p ~/models/"$subdir"
done

# Quoted delimiter: the text below is printed verbatim, "~" included.
cat <<'EOF'
 Created:
 ~/models/dflash/ (for converted DFlash drafters)
 ~/models/target/ (for target models)
EOF

# Step 5 (part 1): pick the drafter/target pair named by the first argument.
echo ""
echo "[5/6] Downloading and converting DFlash drafter..."
echo " This will download ~1GB and take 2-5 minutes."
echo ""

# The optional first argument selects the model; qwen3-4b is the default.
MODEL_CHOICE="${1:-qwen3-4b}"

# Sets DRAFTER_ID, TARGET_ID and OUTPUT for the chosen model.
# OUTPUT is built from $HOME instead of "~": the shell does not expand a tilde
# inside double quotes, so a quoted "~/..." would be handed to later commands
# as a literal directory named "~" below the current directory.
case "$MODEL_CHOICE" in
  qwen3-4b | 4b | default)
    DRAFTER_ID="z-lab/Qwen3-4B-DFlash-b16"
    TARGET_ID="Qwen/Qwen3-4B-MLX-4bit"
    OUTPUT="${HOME}/models/dflash/Qwen3-4B-DFlash-mlx"
    ;;
  qwen3-8b | 8b)
    DRAFTER_ID="z-lab/Qwen3-8B-DFlash-b16"
    TARGET_ID="Qwen/Qwen3-8B-MLX-4bit"
    OUTPUT="${HOME}/models/dflash/Qwen3-8B-DFlash-mlx"
    ;;
  *)
    # Unrecognised choice: explain the accepted values on stderr and stop.
    echo "Unknown model choice: $MODEL_CHOICE" >&2
    echo "Use: qwen3-4b (default) or qwen3-8b" >&2
    exit 1
    ;;
esac

echo " Drafter: $DRAFTER_ID"
echo " Target: $TARGET_ID"
echo " Output: $OUTPUT"
echo ""

# Step 5 (part 2): run the dflash_mlx.convert module with the drafter id and
# the output path selected above.
convert_args=(--model "$DRAFTER_ID" --output "$OUTPUT")
python3 -m dflash_mlx.convert "${convert_args[@]}"

printf '%s\n' " ✓ DFlash drafter converted to MLX format"

# Step 6: smoke test - load the target and drafter, then generate a few tokens.
echo ""
echo "[6/6] Running quick test..."

# Write the test program to a private, unpredictable temp file rather than a
# fixed name in /tmp, and remove it again on every exit path.
test_script="$(mktemp "${TMPDIR:-/tmp}/dflash_test.XXXXXX")" || {
  echo "Error: could not create a temporary file for the test script." >&2
  exit 1
}
trap 'rm -f -- "$test_script"' EXIT

# The heredoc delimiter is quoted, so the Python source is written verbatim.
# The target id and drafter path are passed as command-line arguments instead
# of being spliced into the source with sed: characters such as "|", "&" or
# quotes in either value can no longer corrupt the program, and the BSD-only
# `sed -i ''` form is not needed.
cat > "$test_script" << 'EOF'
import sys
sys.path.insert(0, '.')
from mlx_lm import load
from dflash_mlx import DFlashSpeculativeDecoder
from dflash_mlx.convert import load_mlx_dflash

# argv[1]: target model id, argv[2]: path of the converted drafter.
target_id, draft_path = sys.argv[1], sys.argv[2]

print("Loading models...")
model, tokenizer = load(target_id)
draft, _ = load_mlx_dflash(draft_path)

decoder = DFlashSpeculativeDecoder(
    target_model=model,
    draft_model=draft,
    tokenizer=tokenizer,
    block_size=16,
)

print("\nGenerating test output...")
output = decoder.generate(
    prompt="What is 2 + 2? Answer in one word.",
    max_tokens=10,
    temperature=0.0,
)
print(f"Output: {output}")
print("\n✓ DFlash is working correctly!")
EOF

python3 "$test_script" "$TARGET_ID" "$OUTPUT"

#######################################
# Print the closing summary with copy-and-paste usage instructions.
# Globals:   TARGET_ID (read), OUTPUT (read)
# Arguments: none
# Outputs:   writes the summary to stdout
#
# The Python lines inside the python3 -c "..." example start at column 0 on
# purpose: Python rejects a program whose first statement is indented
# (IndentationError), so an indented example fails when pasted verbatim.
# The heredoc delimiter is unquoted so $TARGET_ID and $OUTPUT are expanded.
#######################################
print_usage_summary() {
  cat <<EOF

==========================================
 Setup Complete!
==========================================

To use DFlash in your projects:

 source .venv-dflash/bin/activate

 python3 -c "
from mlx_lm import load
from dflash_mlx import DFlashSpeculativeDecoder
from dflash_mlx.convert import load_mlx_dflash

model, tokenizer = load('$TARGET_ID')
draft, _ = load_mlx_dflash('$OUTPUT')

decoder = DFlashSpeculativeDecoder(
    target_model=model,
    draft_model=draft,
    tokenizer=tokenizer,
    block_size=16,
)

output = decoder.generate('Your prompt here')
print(output)
"

To benchmark:
 python3 benchmark_m2.py --target $TARGET_ID --draft $OUTPUT

For more info, see M2_PRO_MAX_GUIDE.md
==========================================
EOF
}

print_usage_summary