dflash-mlx-universal / setup_m2.sh
tritesh's picture
Upload folder using huggingface_hub
0433390 verified
raw
history blame
3.97 kB
#!/bin/bash
# Setup script for DFlash on M2 Pro Max (96GB)
# Run: chmod +x setup_m2.sh && ./setup_m2.sh [qwen3-4b|qwen3-8b]
#
# Steps: check the system, create .venv-dflash/ in the current directory,
# install mlx-lm and dflash-mlx-universal, create ~/models/{dflash,target},
# convert a DFlash drafter to MLX format, and run a short generation test.
set -e
echo "=========================================="
echo " DFlash MLX Setup for M2 Pro Max (96GB)"
echo "=========================================="
# Check architecture
echo ""
echo "[1/6] Checking system..."
# Fail fast with a clear message; without this a missing interpreter only
# surfaces later as a bare "command not found" from the venv step.
if ! command -v python3 >/dev/null 2>&1; then
  echo "Error: python3 not found on PATH. Install Python 3 and re-run." >&2
  exit 1
fi
ARCH=$(uname -m)
if [ "$ARCH" != "arm64" ]; then
  # Deliberately non-fatal: warn on stderr and let the user continue.
  echo "Warning: Not running on Apple Silicon (arm64). MLX may not work optimally." >&2
fi
echo " Architecture: $ARCH"
echo " Python: $(python3 --version)"
# Create a virtual environment in the current directory, activate it, and
# install the Python packages this setup depends on into it.
venv_dir=.venv-dflash

printf '\n%s\n' "[2/6] Creating virtual environment..."
python3 -m venv "$venv_dir"
printf '%s\n' " Created .venv-dflash/"

# Activate the venv so the pip calls below install into it.
printf '\n%s\n' "[3/6] Installing dependencies..."
. "$venv_dir/bin/activate"
pip install --upgrade pip
# One pip call per package; under set -e any failure aborts the script.
for pkg in mlx-lm dflash-mlx-universal; do
  pip install "$pkg"
done
printf '%s\n' " ✓ MLX-LM installed" " ✓ DFlash-MLX-Universal installed"
# Create the model directories under the user's home (no-op if present).
printf '\n%s\n' "[4/6] Setting up model directories..."
for subdir in dflash target; do
  mkdir -p ~/models/"$subdir"
done
printf '%s\n' \
  " Created:" \
  " ~/models/dflash/ (for converted DFlash drafters)" \
  " ~/models/target/ (for target models)"
# Download and convert a drafter
echo ""
echo "[5/6] Downloading and converting DFlash drafter..."
echo " This will download ~1GB and take 2-5 minutes."
echo ""

#######################################
# Map a model shorthand to the drafter repo, the matching MLX target repo, and
# the local directory the converted drafter is written to.
# Globals:   DRAFTER_ID, TARGET_ID, OUTPUT (written)
# Arguments: $1 - model choice: qwen3-4b|4b|default or qwen3-8b|8b
# Outputs:   usage hint on stderr for an unknown choice
# Returns:   0 on success, 1 for an unknown choice
#######################################
resolve_model_choice() {
  case "$1" in
    qwen3-4b|4b|default)
      DRAFTER_ID="z-lab/Qwen3-4B-DFlash-b16"
      TARGET_ID="Qwen/Qwen3-4B-MLX-4bit"
      # $HOME rather than "~": a tilde inside double quotes is not expanded,
      # which would yield a literal "~/..." path relative to the cwd.
      OUTPUT="$HOME/models/dflash/Qwen3-4B-DFlash-mlx"
      ;;
    qwen3-8b|8b)
      DRAFTER_ID="z-lab/Qwen3-8B-DFlash-b16"
      TARGET_ID="Qwen/Qwen3-8B-MLX-4bit"
      OUTPUT="$HOME/models/dflash/Qwen3-8B-DFlash-mlx"
      ;;
    *)
      echo "Unknown model choice: $1" >&2
      echo "Use: qwen3-4b (default) or qwen3-8b" >&2
      return 1
      ;;
  esac
}

# The optional first script argument selects the model pair.
MODEL_CHOICE="${1:-qwen3-4b}"
resolve_model_choice "$MODEL_CHOICE" || exit 1
echo " Drafter: $DRAFTER_ID"
echo " Target: $TARGET_ID"
echo " Output: $OUTPUT"
echo ""
python3 -m dflash_mlx.convert \
  --model "$DRAFTER_ID" \
  --output "$OUTPUT"
echo " ✓ DFlash drafter converted to MLX format"
# Quick test
echo ""
echo "[6/6] Running quick test..."
# Write the smoke test to a private temp file (unpredictable name) and remove
# it when the script exits, including when the test fails under set -e.
TEST_SCRIPT="$(mktemp "${TMPDIR:-/tmp}/dflash_test.XXXXXX")"
trap 'rm -f -- "$TEST_SCRIPT"' EXIT
# The heredoc delimiter is quoted, so nothing below is expanded by the shell.
# The target repo and drafter path are passed as argv instead of being
# substituted into the source text, so no platform-specific `sed -i` is needed
# and values containing quotes or sed metacharacters are handled safely.
cat > "$TEST_SCRIPT" << 'EOF'
import sys
# Put the current directory first on sys.path so a dflash_mlx checkout
# there takes precedence over the installed package.
sys.path.insert(0, '.')
from mlx_lm import load
from dflash_mlx import DFlashSpeculativeDecoder
from dflash_mlx.convert import load_mlx_dflash
target_id, draft_path = sys.argv[1], sys.argv[2]
print("Loading models...")
model, tokenizer = load(target_id)
draft, _ = load_mlx_dflash(draft_path)
decoder = DFlashSpeculativeDecoder(
    target_model=model,
    draft_model=draft,
    tokenizer=tokenizer,
    block_size=16,
)
print("\nGenerating test output...")
output = decoder.generate(
    prompt="What is 2 + 2? Answer in one word.",
    max_tokens=10,
    temperature=0.0,
)
print(f"Output: {output}")
print("\n✓ DFlash is working correctly!")
EOF
python3 "$TEST_SCRIPT" "$TARGET_ID" "$OUTPUT"
# Summary
#######################################
# Print the post-install summary: how to activate the venv, a copy-pasteable
# Python usage snippet, and the benchmark command for the chosen models.
# Globals:   TARGET_ID, OUTPUT (read)
# Outputs:   usage instructions on stdout
#######################################
print_summary() {
  # The Python statements inside `python3 -c "..."` start at column 0:
  # any leading indentation makes the pasted command fail with
  # "IndentationError: unexpected indent".
  cat <<EOF

==========================================
 Setup Complete!
==========================================

To use DFlash in your projects:

 source .venv-dflash/bin/activate

 python3 -c "
from mlx_lm import load
from dflash_mlx import DFlashSpeculativeDecoder
from dflash_mlx.convert import load_mlx_dflash

model, tokenizer = load('$TARGET_ID')
draft, _ = load_mlx_dflash('$OUTPUT')

decoder = DFlashSpeculativeDecoder(
    target_model=model,
    draft_model=draft,
    tokenizer=tokenizer,
    block_size=16,
)

output = decoder.generate('Your prompt here')
print(output)
"

To benchmark:
EOF
  # %q shell-quotes each value so the printed command still works when the
  # drafter path contains spaces or other shell metacharacters.
  printf ' python3 benchmark_m2.py --target %q --draft %q\n' \
    "$TARGET_ID" "$OUTPUT"
  cat <<'EOF'

For more info, see M2_PRO_MAX_GUIDE.md
==========================================
EOF
}

print_summary