tritesh committed on
Commit
5fdf20a
·
verified ·
1 Parent(s): e8cb6a7

Upload dflash_mlx/__init__.py

Browse files
Files changed (1) hide show
  1. dflash_mlx/__init__.py +43 -3
dflash_mlx/__init__.py CHANGED
@@ -2,16 +2,56 @@
2
  DFlash-MLX-Universal: Block Diffusion Speculative Decoding for MLX
3
 
4
  A universal MLX implementation of DFlash that works with any MLX-converted model.
5
- Optimized for Apple Silicon (M2/M3/M4 Pro/Max/Ultra).
 
 
 
 
 
 
 
 
 
 
6
  """
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  from .speculative_decode import DFlashSpeculativeDecoder
 
9
  from .universal import UniversalDFlashDecoder
10
- from .convert import convert_dflash_to_mlx
11
 
12
- __version__ = "0.1.1"
13
  __all__ = [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  "DFlashSpeculativeDecoder",
15
  "UniversalDFlashDecoder",
 
16
  "convert_dflash_to_mlx",
 
17
  ]
 
2
  DFlash-MLX-Universal: Block Diffusion Speculative Decoding for MLX
3
 
4
  A universal MLX implementation of DFlash that works with any MLX-converted model.
5
+ Optimized for Apple Silicon (M2/M3/M4 Pro/Max/Ultra) and compatible with the
6
+ major model families that have a bundled adapter: Qwen3, Qwen3.5, LLaMA, Mistral, Gemma.
7
+
8
+ Key features:
9
+ - Architecture-agnostic adapters for any MLX model family
10
+ - KV cache management with proper trim/rewind on rejection
11
+ - Streaming generation with real-time text output
12
+ - OpenAI-compatible server (via serve.py)
13
+ - Training custom drafters on the fly (via trainer.py)
14
+ - Conversion of PyTorch DFlash drafters to MLX (via convert.py)
15
+ - Benchmarking and diagnostics tools
16
  """
17
 
18
+ from .adapters import (
19
+ MLXTargetAdapter,
20
+ Qwen3Adapter,
21
+ Qwen35Adapter,
22
+ LlamaAdapter,
23
+ MistralAdapter,
24
+ GemmaAdapter,
25
+ LoadedTargetModel,
26
+ load_target_model,
27
+ adapter_for_model_type,
28
+ detect_model_architecture,
29
+ )
30
+ from .model import DFlashDraftModel, DFlashDenoiser
31
  from .speculative_decode import DFlashSpeculativeDecoder
32
+ from .convert import convert_dflash_to_mlx, load_mlx_dflash
33
  from .universal import UniversalDFlashDecoder
 
34
 
35
+ __version__ = "0.2.0"
36
  __all__ = [
37
+ # Adapters
38
+ "MLXTargetAdapter",
39
+ "Qwen3Adapter",
40
+ "Qwen35Adapter",
41
+ "LlamaAdapter",
42
+ "MistralAdapter",
43
+ "GemmaAdapter",
44
+ "LoadedTargetModel",
45
+ "load_target_model",
46
+ "adapter_for_model_type",
47
+ "detect_model_architecture",
48
+ # Core model
49
+ "DFlashDraftModel",
50
+ "DFlashDenoiser",
51
+ # Decoding
52
  "DFlashSpeculativeDecoder",
53
  "UniversalDFlashDecoder",
54
+ # Conversion
55
  "convert_dflash_to_mlx",
56
+ "load_mlx_dflash",
57
  ]