""" Training utilities for DFlash drafters on MLX. Implements the training recipe from the DFlash paper: - KV injection with target model features - Random anchor sampling for block construction - Sparse attention masking within blocks - Position-dependent loss decay """ import math from typing import Optional, List, Dict, Any, Tuple import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim from .model import DFlashDraftModel class DFlashTrainer: """Trainer for DFlash draft models on MLX. Trains the drafter to align block-level diffusion predictions with a frozen autoregressive target model's outputs. """ def __init__( self, target_model, drafter: DFlashDraftModel, tokenizer, max_seq_length: int = 3072, ): self.target_model = target_model self.drafter = drafter self.tokenizer = tokenizer self.max_seq_length = max_seq_length self.mask_token_id = drafter.mask_token_id def _prepare_training_sample( self, prompt: str, response: str, block_size: int, ) -> Dict[str, mx.array]: """Prepare a single training sample. Constructs masked blocks with random anchors from target-generated responses, matching the inference-time speculative decoding setting. """ # Tokenize prompt + response prompt_ids = self.tokenizer.encode(prompt) response_ids = self.tokenizer.encode(response) # Truncate if too long total_len = len(prompt_ids) + len(response_ids) if total_len > self.max_seq_length: response_ids = response_ids[:self.max_seq_length - len(prompt_ids)] full_ids = prompt_ids + response_ids full_ids_mx = mx.array(full_ids) # Build target context features with mx.eval_mode(): target_output = self._target_forward(full_ids_mx) target_hidden = self.drafter.extract_context_features( target_output["hidden_states"] ) # Random anchor sampling for blocks num_blocks = max(1, len(response_ids) // block_size) block_starts = mx.random.randint( low=len(prompt_ids), high=len(full_ids) - block_size + 1, shape=(num_blocks,), ) # Create masked sequence masked_ids = mx.array(full_ids) labels = mx.full((len(full_ids),), -100, dtype=mx.int32) # Ignore index for start in block_starts.tolist(): start = int(start) end = min(start + block_size, len(full_ids)) # Anchor is first token (from target model's accepted token) # Mask remaining positions in block masked_ids = masked_ids.at[start + 1:end].set(self.mask_token_id) # Labels for masked positions labels = labels.at[start + 1:end].set(full_ids_mx[start + 1:end]) return { "input_ids": masked_ids, "labels": labels, "target_hidden": target_hidden, "prompt_length": len(prompt_ids), } def _target_forward( self, input_ids: mx.array, ) -> Dict[str, Any]: """Forward pass through target model to get hidden states.""" if hasattr(self.target_model, '__call__'): result = self.target_model(input_ids) logits = result[0] if isinstance(result, tuple) else result else: logits = self.target_model(input_ids) # Extract hidden states layer by layer hidden_states = [] hidden = input_ids if hasattr(self.target_model, 'embed_tokens'): hidden = self.target_model.embed_tokens(hidden) if hasattr(self.target_model, 'layers'): for layer in self.target_model.layers: hidden = layer(hidden, mask=None) hidden_states.append(hidden) else: hidden_states = [hidden] return { "logits": logits, "hidden_states": hidden_states, } def _compute_loss( self, input_ids: mx.array, labels: mx.array, target_hidden: mx.array, ) -> mx.array: """Compute the diffusion training loss with position-dependent decay. Implements the loss decay from the paper where tokens closer to the anchor receive higher weights. """ # Embed tokens (including mask tokens) embeddings = self.drafter.embed_tokens(input_ids) # Build position IDs position_ids = mx.arange(input_ids.shape[0]) # Forward through drafter hidden_states = self.drafter( noise_embedding=embeddings, target_hidden=target_hidden, position_ids=position_ids, ) # Get logits logits = self.drafter.get_logits(hidden_states) # Compute cross-entropy loss for labeled positions valid_mask = labels != -100 if not valid_mask.any(): return mx.array(0.0) valid_logits = logits[valid_mask] valid_labels = labels[valid_mask] # Position-dependent weighting (exponential decay from anchor) # Find anchor positions and compute distances positions = mx.arange(len(labels)) # Simplified: uniform weighting for now # Full implementation would track block boundaries weights = mx.ones_like(valid_labels, dtype=mx.float32) # Cross entropy log_probs = mx.log_softmax(valid_logits, axis=-1) nll = -log_probs[mx.arange(len(valid_labels)), valid_labels] weighted_nll = nll * weights return weighted_nll.mean() def _build_batch( self, samples: List[Dict[str, Any]], ) -> Dict[str, mx.array]: """Batch multiple training samples.""" # Find max length max_len = max(s["input_ids"].shape[0] for s in samples) # Pad sequences batch_input_ids = [] batch_labels = [] batch_target_hidden = [] batch_attention_mask = [] for sample in samples: seq_len = sample["input_ids"].shape[0] pad_len = max_len - seq_len # Pad input_ids with mask token padded_ids = mx.concatenate([ sample["input_ids"], mx.full((pad_len,), self.mask_token_id, dtype=mx.int32) ]) batch_input_ids.append(padded_ids) # Pad labels with -100 (ignore index) padded_labels = mx.concatenate([ sample["labels"], mx.full((pad_len,), -100, dtype=mx.int32) ]) batch_labels.append(padded_labels) # Attention mask (1 for real, 0 for padding) mask = mx.concatenate([ mx.ones((seq_len,), dtype=mx.float32), mx.zeros((pad_len,), dtype=mx.float32) ]) batch_attention_mask.append(mask) # Target hidden (pad with zeros) hidden = sample["target_hidden"] if hidden.shape[1] < max_len: pad = mx.zeros((hidden.shape[0], max_len - hidden.shape[1], hidden.shape[2])) hidden = mx.concatenate([hidden, pad], axis=1) batch_target_hidden.append(hidden) return { "input_ids": mx.stack(batch_input_ids), "labels": mx.stack(batch_labels), "target_hidden": mx.stack(batch_target_hidden), "attention_mask": mx.stack(batch_attention_mask), } def train( self, dataset: str, epochs: int = 6, batch_size: int = 8, lr: float = 6e-4, warmup_ratio: float = 0.04, grad_clip: float = 1.0, save_every: int = 1000, ) -> DFlashDraftModel: """Train the DFlash drafter. Args: dataset: Path to dataset (JSONL with {prompt, response} pairs) or HF dataset name with 'prompt' and 'response' columns epochs: Number of training epochs batch_size: Batch size lr: Learning rate warmup_ratio: Warmup ratio for cosine schedule grad_clip: Gradient clipping threshold save_every: Save checkpoint every N steps Returns: Trained DFlashDraftModel """ # Load dataset samples = self._load_dataset(dataset) print(f"[Trainer] Loaded {len(samples)} training samples") # Setup optimizer optimizer = optim.AdamW(learning_rate=lr) # Cosine schedule with warmup num_steps = (len(samples) // batch_size) * epochs warmup_steps = int(num_steps * warmup_ratio) def lr_schedule(step): if step < warmup_steps: return lr * (step / warmup_steps) progress = (step - warmup_steps) / max(1, num_steps - warmup_steps) return lr * 0.5 * (1 + math.cos(math.pi * progress)) # Training loop step = 0 for epoch in range(epochs): # Shuffle samples import random random.shuffle(samples) epoch_losses = [] for i in range(0, len(samples), batch_size): batch_samples = samples[i:i + batch_size] # Prepare batch batch = self._build_batch(batch_samples) # Forward + backward def loss_fn(params): self.drafter.update(params) loss = self._compute_loss( batch["input_ids"], batch["labels"], batch["target_hidden"], ) return loss # Compute loss and gradients loss, grads = mx.value_and_grad(loss_fn)(self.drafter.parameters()) # Gradient clipping if grad_clip > 0: grad_norm = mx.sqrt(sum(mx.sum(g * g) for g in grads.values())) if grad_norm > grad_clip: scale = grad_clip / grad_norm grads = {k: v * scale for k, v in grads.items()} # Update parameters current_lr = lr_schedule(step) optimizer.learning_rate = current_lr self.drafter = optimizer.apply(grads, self.drafter) loss_val = float(loss) epoch_losses.append(loss_val) if step % 10 == 0: avg_loss = sum(epoch_losses[-10:]) / len(epoch_losses[-10:]) print(f"[Trainer] Epoch {epoch+1}/{epochs} Step {step} | " f"Loss: {loss_val:.4f} | LR: {current_lr:.2e}") step += 1 # Save checkpoint if step % save_every == 0: self._save_checkpoint(f"checkpoint_step_{step}") avg_epoch_loss = sum(epoch_losses) / len(epoch_losses) print(f"[Trainer] Epoch {epoch+1} complete | Avg Loss: {avg_epoch_loss:.4f}") print("[Trainer] Training complete!") return self.drafter def _load_dataset(self, dataset: str) -> List[Dict[str, str]]: """Load dataset from path or HF Hub.""" import json from pathlib import Path # Try local file first dataset_path = Path(dataset) if dataset_path.exists(): samples = [] with open(dataset_path, "r") as f: for line in f: data = json.loads(line) samples.append({ "prompt": data.get("prompt", data.get("input", "")), "response": data.get("response", data.get("output", data.get("completion", ""))), }) return samples # Try Hugging Face dataset try: from datasets import load_dataset ds = load_dataset(dataset, split="train") samples = [] for item in ds: prompt = item.get("prompt", item.get("input", item.get("question", ""))) response = item.get("response", item.get("output", item.get("answer", item.get("completion", "")))) if prompt and response: samples.append({"prompt": prompt, "response": response}) return samples except Exception as e: print(f"[Trainer] Failed to load dataset: {e}") return [] def _save_checkpoint(self, name: str): """Save a training checkpoint.""" import json from pathlib import Path checkpoint_dir = Path("checkpoints") / name checkpoint_dir.mkdir(parents=True, exist_ok=True) weights = dict(self.drafter.parameters()) mx.save_safetensors(str(checkpoint_dir / "weights.safetensors"), weights) print(f"[Trainer] Saved checkpoint to {checkpoint_dir}")