# Provenance: uploaded to Hugging Face by user "tritesh" (commit 5fdf20a, verified).
"""
DFlash-MLX-Universal: Block Diffusion Speculative Decoding for MLX
A universal MLX implementation of DFlash that works with any MLX-converted model.
Optimized for Apple Silicon (M2/M3/M4 Pro/Max/Ultra) and compatible with all
major model families: Qwen3, Qwen3.5, LLaMA, Mistral, Gemma.
Key features:
- Architecture-agnostic adapters for any MLX model family
- KV cache management with proper trim/rewind on rejection
- Streaming generation with real-time text output
- OpenAI-compatible server (via serve.py)
- Training custom drafters on-the-fly (via trainer.py)
- Conversion of PyTorch DFlash drafters to MLX (via convert.py)
- Benchmarking and diagnostics tools
"""
from .adapters import (
MLXTargetAdapter,
Qwen3Adapter,
Qwen35Adapter,
LlamaAdapter,
MistralAdapter,
GemmaAdapter,
LoadedTargetModel,
load_target_model,
adapter_for_model_type,
detect_model_architecture,
)
from .model import DFlashDraftModel, DFlashDenoiser
from .speculative_decode import DFlashSpeculativeDecoder
from .convert import convert_dflash_to_mlx, load_mlx_dflash
from .universal import UniversalDFlashDecoder
# Package version string (PEP 440 format).
__version__ = "0.2.0"

# Public API surface, grouped by subsystem. The groups are concatenated into
# ``__all__`` so ``from dflash_mlx import *`` re-exports exactly these names.
_adapter_exports = [
    "MLXTargetAdapter",
    "Qwen3Adapter",
    "Qwen35Adapter",
    "LlamaAdapter",
    "MistralAdapter",
    "GemmaAdapter",
    "LoadedTargetModel",
    "load_target_model",
    "adapter_for_model_type",
    "detect_model_architecture",
]
_model_exports = [
    "DFlashDraftModel",
    "DFlashDenoiser",
]
_decoding_exports = [
    "DFlashSpeculativeDecoder",
    "UniversalDFlashDecoder",
]
_conversion_exports = [
    "convert_dflash_to_mlx",
    "load_mlx_dflash",
]

__all__ = (
    _adapter_exports
    + _model_exports
    + _decoding_exports
    + _conversion_exports
)