#!/bin/bash
#
# Setup script for DFlash on M2 Pro Max (96GB).
#
# Usage:
#   chmod +x setup_m2.sh && ./setup_m2.sh [qwen3-4b|qwen3-8b]
#
# Arguments:
#   $1 - model choice: qwen3-4b (default; aliases: 4b, default) or
#        qwen3-8b (alias: 8b).
#
# What it does:
#   1. Reports the architecture and Python version.
#   2. Creates the .venv-dflash virtual environment in the current directory.
#   3. Installs mlx-lm and dflash-mlx-universal into it.
#   4. Creates ~/models/dflash and ~/models/target.
#   5. Converts the selected DFlash drafter via `python3 -m dflash_mlx.convert`.
#   6. Runs a short generation as a smoke test.
#
# Exits non-zero on an unknown model choice or if any step fails.

set -euo pipefail

# Print an error message to stderr and abort.
die() {
  printf '%s\n' "$*" >&2
  exit 1
}

# Resolve the model choice before doing any work, so that a typo is rejected
# immediately instead of after the virtual environment has been built.
MODEL_CHOICE="${1:-qwen3-4b}"
# Absolute path: a "~" inside double quotes is NOT expanded by the shell and
# would otherwise be handed to the tools as a literal directory name.
DRAFTER_ROOT="$HOME/models/dflash"
TARGET_ROOT="$HOME/models/target"

case "$MODEL_CHOICE" in
  qwen3-4b|4b|default)
    DRAFTER_ID="z-lab/Qwen3-4B-DFlash-b16"
    TARGET_ID="Qwen/Qwen3-4B-MLX-4bit"
    OUTPUT="$DRAFTER_ROOT/Qwen3-4B-DFlash-mlx"
    ;;
  qwen3-8b|8b)
    DRAFTER_ID="z-lab/Qwen3-8B-DFlash-b16"
    TARGET_ID="Qwen/Qwen3-8B-MLX-4bit"
    OUTPUT="$DRAFTER_ROOT/Qwen3-8B-DFlash-mlx"
    ;;
  *)
    echo "Unknown model choice: $MODEL_CHOICE" >&2
    die "Use: qwen3-4b (default) or qwen3-8b"
    ;;
esac

echo "=========================================="
echo " DFlash MLX Setup for M2 Pro Max (96GB)"
echo "=========================================="

# Check architecture
echo ""
echo "[1/6] Checking system..."
# A failing command substitution inside an echo argument does not trip
# `set -e`, so check for the interpreter explicitly.
command -v python3 >/dev/null || die "python3 not found in PATH"
ARCH=$(uname -m)
if [[ "$ARCH" != "arm64" ]]; then
  echo "Warning: Not running on Apple Silicon (arm64). MLX may not work optimally." >&2
fi
echo " Architecture: $ARCH"
echo " Python: $(python3 --version)"

# Create virtual environment
echo ""
echo "[2/6] Creating virtual environment..."
python3 -m venv .venv-dflash
echo " Created .venv-dflash/"

# Activate and install
echo ""
echo "[3/6] Installing dependencies..."
# Relax nounset while sourcing: the activate script is third-party code and
# is not guaranteed to be safe under `set -u`.
set +u
# shellcheck disable=SC1091
source .venv-dflash/bin/activate
set -u
python3 -m pip install --upgrade pip
python3 -m pip install mlx-lm
python3 -m pip install dflash-mlx-universal
echo " ✓ MLX-LM installed"
echo " ✓ DFlash-MLX-Universal installed"

# Create models directory
echo ""
echo "[4/6] Setting up model directories..."
mkdir -p -- "$DRAFTER_ROOT" "$TARGET_ROOT"
echo " Created:"
echo " ~/models/dflash/ (for converted DFlash drafters)"
echo " ~/models/target/ (for target models)"

# Download and convert a drafter
echo ""
echo "[5/6] Downloading and converting DFlash drafter..."
echo " This will download ~1GB and take 2-5 minutes."
echo ""
echo " Drafter: $DRAFTER_ID"
echo " Target: $TARGET_ID"
echo " Output: $OUTPUT"
echo ""

python3 -m dflash_mlx.convert \
  --model "$DRAFTER_ID" \
  --output "$OUTPUT"

echo " ✓ DFlash drafter converted to MLX format"

# Quick test
echo ""
echo "[6/6] Running quick test..."

# Unpredictable temp file, removed on every exit path.
TEST_SCRIPT=$(mktemp "${TMPDIR:-/tmp}/dflash_test.XXXXXX") || die "mktemp failed"
trap 'rm -f -- "$TEST_SCRIPT"' EXIT

# The model id and drafter path are passed as arguments rather than being
# spliced into the Python source, so no quoting or escaping is needed.
cat > "$TEST_SCRIPT" << 'EOF'
import sys
sys.path.insert(0, '.')

from mlx_lm import load
from dflash_mlx import DFlashSpeculativeDecoder
from dflash_mlx.convert import load_mlx_dflash

target_id, drafter_path = sys.argv[1], sys.argv[2]

print("Loading models...")
model, tokenizer = load(target_id)
draft, _ = load_mlx_dflash(drafter_path)

decoder = DFlashSpeculativeDecoder(
    target_model=model,
    draft_model=draft,
    tokenizer=tokenizer,
    block_size=16,
)

print("\nGenerating test output...")
output = decoder.generate(
    prompt="What is 2 + 2? Answer in one word.",
    max_tokens=10,
    temperature=0.0,
)
print(f"Output: {output}")
print("\n✓ DFlash is working correctly!")
EOF

python3 "$TEST_SCRIPT" "$TARGET_ID" "$OUTPUT"

# Summary
cat << EOF

==========================================
 Setup Complete!
==========================================

To use DFlash in your projects:

 source .venv-dflash/bin/activate

 python3 -c "
 from mlx_lm import load
 from dflash_mlx import DFlashSpeculativeDecoder
 from dflash_mlx.convert import load_mlx_dflash

 model, tokenizer = load('$TARGET_ID')
 draft, _ = load_mlx_dflash('$OUTPUT')

 decoder = DFlashSpeculativeDecoder(
 target_model=model,
 draft_model=draft,
 tokenizer=tokenizer,
 block_size=16,
 )

 output = decoder.generate('Your prompt here')
 print(output)
 "

To benchmark:
 python3 benchmark_m2.py --target $TARGET_ID --draft $OUTPUT

For more info, see M2_PRO_MAX_GUIDE.md
==========================================
EOF