# NOTE: the six lines originally here were hosting-page residue, not code:
# "tritesh's picture", "Upload folder using huggingface_hub",
# "0433390 verified", "raw", "history blame", "9.93 kB".
"""
Universal DFlash decoder for any MLX-converted model.
Provides a high-level interface that works with any mlx_lm model,
including those without pre-built DFlash drafters.
"""
from typing import Optional, List, Dict, Any
import mlx.core as mx
from .model import DFlashDraftModel
from .speculative_decode import DFlashSpeculativeDecoder
class UniversalDFlashDecoder:
    """Universal DFlash decoder that works with any MLX-converted model.

    This class handles:
    1. Loading pre-converted DFlash drafters
    2. Creating generic drafters for unsupported models
    3. Training custom drafters on-the-fly
    """

    # Fallback target-model dimensions, used when the target model exposes no
    # readable configuration or when an individual field is missing.
    _DEFAULT_TARGET_CONFIG = {
        'hidden_size': 4096,
        'num_layers': 32,
        'vocab_size': 151936,
        'intermediate_size': 14336,
        'num_attention_heads': 32,
        'num_key_value_heads': 8,
    }

    # Key in ``self.target_config`` -> field name on the model's config object.
    _TARGET_CONFIG_FIELDS = {
        'hidden_size': 'hidden_size',
        'num_layers': 'num_hidden_layers',
        'vocab_size': 'vocab_size',
        'intermediate_size': 'intermediate_size',
        'num_attention_heads': 'num_attention_heads',
        'num_key_value_heads': 'num_key_value_heads',
    }

    # Where ``train_drafter`` saves when the caller gives no ``output_path``.
    _DEFAULT_DRAFTER_PATH = "./trained_dflash_drafter"

    def __init__(
        self,
        target_model,
        tokenizer,
        draft_model_path: Optional[str] = None,
        draft_layers: int = 5,
        draft_hidden_size: int = 1024,
        block_size: int = 16,
        device: str = "metal",
    ):
        """Initialize the universal decoder.

        Args:
            target_model: Any mlx_lm loaded model
            tokenizer: Tokenizer for the model
            draft_model_path: Optional path to pre-converted DFlash drafter
            draft_layers: Number of draft layers (if creating generic drafter)
            draft_hidden_size: Hidden size for generic drafter
            block_size: Number of tokens per draft block
            device: MLX device
        """
        self.target_model = target_model
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.device = device

        # Determine model type and vocab size
        self.vocab_size = getattr(tokenizer, "vocab_size", 151936)
        self.target_config = self._extract_target_config(target_model)

        # Load or create draft model
        if draft_model_path:
            print(f"[UniversalDFlash] Loading pre-built drafter from {draft_model_path}")
            from .convert import load_mlx_dflash
            self.draft_model, self.draft_config = load_mlx_dflash(draft_model_path)
        else:
            print("[UniversalDFlash] Creating generic drafter for your model...")
            self.draft_model = self._create_generic_drafter(
                draft_layers=draft_layers,
                draft_hidden_size=draft_hidden_size,
            )
            self.draft_config = None

        # Create the speculative decoder
        self.decoder = DFlashSpeculativeDecoder(
            target_model=target_model,
            draft_model=self.draft_model,
            tokenizer=tokenizer,
            block_size=block_size,
            device=device,
        )

    def _extract_target_config(self, target_model) -> Dict[str, Any]:
        """Extract configuration from target model.

        Looks for a configuration object on ``target_model.config`` first and
        then on ``target_model.args``. The configuration may be an object with
        attributes or a plain dict. Any field that is absent or ``None`` keeps
        its value from ``_DEFAULT_TARGET_CONFIG``.

        Args:
            target_model: The target model to inspect.

        Returns:
            A new dict with the keys ``hidden_size``, ``num_layers``,
            ``vocab_size``, ``intermediate_size``, ``num_attention_heads``
            and ``num_key_value_heads``.
        """
        # Copy so callers can mutate the result without touching the defaults.
        config = dict(self._DEFAULT_TARGET_CONFIG)

        model_config = getattr(target_model, 'config', None)
        if model_config is None:
            # NOTE(review): mlx_lm model classes are expected to keep their
            # hyper-parameters on ``.args`` - confirm against the mlx_lm
            # version in use.
            model_config = getattr(target_model, 'args', None)
        if model_config is None:
            return config

        is_mapping = isinstance(model_config, dict)
        for key, field_name in self._TARGET_CONFIG_FIELDS.items():
            if is_mapping:
                value = model_config.get(field_name)
            else:
                value = getattr(model_config, field_name, None)
            if value is not None:
                config[key] = value
        return config

    def _create_generic_drafter(
        self,
        draft_layers: int,
        draft_hidden_size: int,
    ) -> "DFlashDraftModel":
        """Create a generic DFlash drafter compatible with the target model.

        This creates an untrained drafter that can be trained or used
        with pre-trained weights from a similar architecture.

        Args:
            draft_layers: Number of layers in the drafter.
            draft_hidden_size: Hidden size of the drafter.

        Returns:
            A freshly constructed ``DFlashDraftModel``.
        """
        vocab_size = self.target_config.get('vocab_size', 151936)

        # Scale drafter heads from its hidden size (~64 dims per head); never
        # let either head count drop to zero for very small hidden sizes.
        num_heads = max(1, draft_hidden_size // 64)
        num_kv_heads = max(1, num_heads // 4)
        intermediate_size = int(draft_hidden_size * 2.75)

        # Use the tokenizer's mask token when it has a usable one; otherwise
        # keep the previous placeholder id 0.
        mask_token_id = getattr(self.tokenizer, 'mask_token_id', None)
        if (
            not isinstance(mask_token_id, int)
            or isinstance(mask_token_id, bool)
            or not 0 <= mask_token_id < vocab_size
        ):
            mask_token_id = 0

        return DFlashDraftModel(
            vocab_size=vocab_size,
            hidden_size=draft_hidden_size,
            num_layers=draft_layers,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            intermediate_size=intermediate_size,
            max_seq_len=8192,
            block_size=self.block_size,
            mask_token_id=mask_token_id,
            num_target_layers=self.target_config.get('num_layers', 32),
        )

    def train_drafter(
        self,
        dataset: str,
        max_seq_length: int = 3072,
        epochs: int = 6,
        batch_size: int = 32,
        lr: float = 6e-4,
        warmup_ratio: float = 0.04,
        grad_clip: float = 1.0,
        output_path: Optional[str] = None,
    ) -> str:
        """Train a custom DFlash drafter for your target model.

        The trained drafter replaces ``self.draft_model`` (and the decoder's
        draft model) and is always written to disk, so the returned path
        refers to a drafter that was actually saved.

        Args:
            dataset: Path to training dataset or HF dataset name
            max_seq_length: Maximum sequence length for training
            epochs: Number of training epochs
            batch_size: Training batch size
            lr: Learning rate
            warmup_ratio: Warmup ratio for cosine schedule
            grad_clip: Gradient clipping threshold
            output_path: Where to save the trained drafter. Defaults to
                ``"./trained_dflash_drafter"`` when omitted.

        Returns:
            Path to saved drafter
        """
        from .trainer import DFlashTrainer

        print("[UniversalDFlash] Training custom drafter...")
        trainer = DFlashTrainer(
            target_model=self.target_model,
            drafter=self.draft_model,
            tokenizer=self.tokenizer,
        )
        trained_model = trainer.train(
            dataset=dataset,
            max_seq_length=max_seq_length,
            epochs=epochs,
            batch_size=batch_size,
            lr=lr,
            warmup_ratio=warmup_ratio,
            grad_clip=grad_clip,
        )

        # Update the draft model
        self.draft_model = trained_model
        self.decoder.draft_model = trained_model

        # Previously the default path was returned without anything being
        # written there; save unconditionally so the return value is truthful.
        save_path = output_path or self._DEFAULT_DRAFTER_PATH
        self.save_drafter(save_path)
        return save_path

    def save_drafter(self, path: str):
        """Save the current drafter model.

        Writes ``weights.safetensors`` and ``config.json`` into ``path``,
        creating the directory (and parents) if needed.

        Args:
            path: Destination directory.
        """
        import json
        from pathlib import Path
        from mlx.utils import tree_flatten

        out_dir = Path(path)
        out_dir.mkdir(parents=True, exist_ok=True)

        model = self.draft_model

        # Save weights. The safetensors writer takes a flat name -> array
        # mapping, so flatten the parameter tree first (a no-op for a dict
        # that is already flat).
        weights = dict(tree_flatten(model.parameters()))
        mx.save_safetensors(str(out_dir / "weights.safetensors"), weights)

        num_heads = model.num_heads
        # Prefer the value the model carries; otherwise mirror the rule used
        # when the generic drafter is built (never below one KV head).
        num_kv_heads = getattr(model, "num_kv_heads", None)
        if num_kv_heads is None:
            num_kv_heads = max(1, num_heads // 4)

        intermediate_size = getattr(model, "intermediate_size", None)
        if intermediate_size is None:
            gate_weight = getattr(model.layers[0].mlp.gate_proj, "weight", None)
            if gate_weight is None:
                intermediate_size = 2816
            else:
                # NOTE(review): assumes gate_proj stores its weight as
                # (output_dims, input_dims), so axis 0 is the intermediate
                # size - confirm against DFlashDraftModel.
                intermediate_size = int(gate_weight.shape[0])

        # Save config
        config = {
            "vocab_size": model.vocab_size,
            "hidden_size": model.hidden_size,
            "num_hidden_layers": model.num_layers,
            "num_attention_heads": num_heads,
            "num_key_value_heads": num_kv_heads,
            "intermediate_size": intermediate_size,
            "max_position_embeddings": model.max_seq_len,
            "block_size": model.block_size,
        }
        with open(out_dir / "config.json", "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)
        print(f"[UniversalDFlash] Drafter saved to {out_dir}")

    def generate(
        self,
        prompt: str,
        max_tokens: int = 2048,
        temperature: float = 0.0,
        stop_strings: Optional[List[str]] = None,
    ) -> str:
        """Generate text using DFlash speculative decoding.

        Args:
            prompt: Text prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            stop_strings: Optional stop strings

        Returns:
            Generated text
        """
        return self.decoder.generate(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            stop_strings=stop_strings,
        )

    def benchmark(
        self,
        prompt: str = "Write a quicksort in Python.",
        max_tokens: int = 512,
        num_runs: int = 5,
    ) -> Dict[str, float]:
        """Benchmark DFlash speculative decoding.

        Runs one short warmup generation, then times ``num_runs`` generations.
        No non-speculative baseline is run. Throughput is estimated as
        ``max_tokens / avg_time``, i.e. it assumes every run produced
        ``max_tokens`` tokens.

        Args:
            prompt: Test prompt
            max_tokens: Tokens per run
            num_runs: Number of benchmark runs (must be at least 1)

        Returns:
            Dict with ``avg_time_sec``, ``tokens_per_sec`` and ``num_runs``.

        Raises:
            ValueError: If ``num_runs`` is less than 1.
        """
        import time

        if num_runs < 1:
            raise ValueError(f"num_runs must be >= 1, got {num_runs}")

        print(f"[Benchmark] Running {num_runs} generations...")

        # Warmup
        self.generate(prompt, max_tokens=10)

        # DFlash generation, timed with a monotonic high-resolution clock.
        dflash_times = []
        for _ in range(num_runs):
            start = time.perf_counter()
            self.generate(prompt, max_tokens=max_tokens)
            dflash_times.append(time.perf_counter() - start)

        avg_time = sum(dflash_times) / len(dflash_times)
        # Guard against a zero elapsed time on very coarse clocks.
        tokens_per_sec = max_tokens / avg_time if avg_time > 0 else float("inf")

        print(f"[Benchmark] Avg time: {avg_time:.2f}s, Speed: {tokens_per_sec:.1f} tok/s")
        return {
            "avg_time_sec": avg_time,
            "tokens_per_sec": tokens_per_sec,
            "num_runs": num_runs,
        }