"""
DFlash-MLX-Universal: Block Diffusion Speculative Decoding for MLX
A universal MLX implementation of DFlash that works with any MLX-converted model.
Optimized for Apple Silicon (M2/M3/M4 Pro/Max/Ultra) and compatible with all
major model families: Qwen3, Qwen3.5, LLaMA, Mistral, Gemma.
Key features:
- Architecture-agnostic adapters for any MLX model family
- KV cache management with proper trim/rewind on rejection
- Streaming generation with real-time text output
- OpenAI-compatible server (via serve.py)
- Training custom drafters on-the-fly (via trainer.py)
- Conversion of PyTorch DFlash drafters to MLX (via convert.py)
- Benchmarking and diagnostics tools
"""
from .adapters import (
MLXTargetAdapter,
Qwen3Adapter,
Qwen35Adapter,
LlamaAdapter,
MistralAdapter,
GemmaAdapter,
LoadedTargetModel,
load_target_model,
adapter_for_model_type,
detect_model_architecture,
)
from .model import DFlashDraftModel, DFlashDenoiser
from .speculative_decode import DFlashSpeculativeDecoder
from .convert import convert_dflash_to_mlx, load_mlx_dflash
from .universal import UniversalDFlashDecoder
__version__ = "0.2.0"  # Package version string; presumably mirrored in the build metadata — TODO confirm.

# Explicit public API: these are the names exposed by `from <package> import *`
# and the supported import surface for downstream users. Grouped to mirror the
# re-export imports above; every imported name is listed here exactly once.
__all__ = [
# Adapters
"MLXTargetAdapter",
"Qwen3Adapter",
"Qwen35Adapter",
"LlamaAdapter",
"MistralAdapter",
"GemmaAdapter",
"LoadedTargetModel",
"load_target_model",
"adapter_for_model_type",
"detect_model_architecture",
# Core model
"DFlashDraftModel",
"DFlashDenoiser",
# Decoding
"DFlashSpeculativeDecoder",
"UniversalDFlashDecoder",
# Conversion
"convert_dflash_to_mlx",
"load_mlx_dflash",
]