| """ |
| Training utilities for DFlash drafters on MLX. |
| |
| Implements the training recipe from the DFlash paper: |
| - KV injection with target model features |
| - Random anchor sampling for block construction |
| - Sparse attention masking within blocks |
| - Position-dependent loss decay |
| """ |

import math
import random
from typing import Any, Dict, List

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

from .model import DFlashDraftModel


class DFlashTrainer:
    """Trainer for DFlash draft models on MLX.

    Trains the drafter to align block-level diffusion predictions
    with a frozen autoregressive target model's outputs.
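
    Example (a sketch; the model, drafter, tokenizer, and data path are
    placeholders for whatever your own setup provides):

        drafter = DFlashDraftModel(...)
        trainer = DFlashTrainer(target_model, drafter, tokenizer)
        trainer.train("data/train.jsonl", epochs=2, batch_size=4)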
| """ |
|
|
| def __init__( |
| self, |
| target_model, |
| drafter: DFlashDraftModel, |
| tokenizer, |
| max_seq_length: int = 3072, |
| ): |
| self.target_model = target_model |
| self.drafter = drafter |
| self.tokenizer = tokenizer |
| self.max_seq_length = max_seq_length |
| self.mask_token_id = drafter.mask_token_id |

    def _prepare_training_sample(
        self,
        prompt: str,
        response: str,
        block_size: int,
    ) -> Dict[str, mx.array]:
        """Prepare a single training sample.

        Constructs masked blocks with random anchors from target-generated
        responses, matching the inference-time speculative decoding setting.
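
        For example, with block_size = 4 and an anchor sampled at index a
        (M denotes the drafter's mask token), the prepared sample looks like:

            input_ids: ...  x[a]  M       M       M       ...
            labels:    ...  -100  x[a+1]  x[a+2]  x[a+3]  ...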
| """ |
| |
| prompt_ids = self.tokenizer.encode(prompt) |
| response_ids = self.tokenizer.encode(response) |
|
|
| |
| total_len = len(prompt_ids) + len(response_ids) |
| if total_len > self.max_seq_length: |
| response_ids = response_ids[:self.max_seq_length - len(prompt_ids)] |
|
|
| full_ids = prompt_ids + response_ids |
| full_ids_mx = mx.array(full_ids) |
|
|
| |
| with mx.eval_mode(): |
| target_output = self._target_forward(full_ids_mx) |
| target_hidden = self.drafter.extract_context_features( |
| target_output["hidden_states"] |
| ) |
|
|
| |
| num_blocks = max(1, len(response_ids) // block_size) |
| block_starts = mx.random.randint( |
| low=len(prompt_ids), |
| high=len(full_ids) - block_size + 1, |
| shape=(num_blocks,), |
| ) |
|
|
| |
| masked_ids = mx.array(full_ids) |
| labels = mx.full((len(full_ids),), -100, dtype=mx.int32) |
|
|
| for start in block_starts.tolist(): |
| start = int(start) |
| end = min(start + block_size, len(full_ids)) |
| |
| |
| masked_ids = masked_ids.at[start + 1:end].set(self.mask_token_id) |
| |
| labels = labels.at[start + 1:end].set(full_ids_mx[start + 1:end]) |
|
|
| return { |
| "input_ids": masked_ids, |
| "labels": labels, |
| "target_hidden": target_hidden, |
| "prompt_length": len(prompt_ids), |
| } |

    def _target_forward(
        self,
        input_ids: mx.array,
    ) -> Dict[str, Any]:
        """Forward pass through target model to get hidden states."""
        result = self.target_model(input_ids)
        logits = result[0] if isinstance(result, tuple) else result

        # Fallback when the model does not expose per-layer hidden states:
        # re-run the embedding and decoder layers with an additive causal mask
        # (matching the target's autoregressive attention) and collect each
        # layer's output.
        hidden_states = []
        hidden = input_ids
        if hasattr(self.target_model, 'embed_tokens'):
            hidden = self.target_model.embed_tokens(hidden)

        if hasattr(self.target_model, 'layers'):
            causal_mask = nn.MultiHeadAttention.create_additive_causal_mask(
                input_ids.shape[-1]
            ).astype(hidden.dtype)
            for layer in self.target_model.layers:
                hidden = layer(hidden, mask=causal_mask)
                hidden_states.append(hidden)
        else:
            hidden_states = [hidden]

        return {
            "logits": logits,
            "hidden_states": hidden_states,
        }

    def _compute_loss(
        self,
        input_ids: mx.array,
        labels: mx.array,
        target_hidden: mx.array,
    ) -> mx.array:
        """Compute the diffusion training loss over masked positions.

        The per-token ``weights`` tensor is the hook for the paper's
        position-dependent loss decay (tokens closer to the anchor receive
        higher weights); this implementation currently applies uniform weights.
        """
        embeddings = self.drafter.embed_tokens(input_ids)

        # Positions along the sequence axis (inputs may be batched as (B, L)).
        position_ids = mx.arange(input_ids.shape[-1])

        hidden_states = self.drafter(
            noise_embedding=embeddings,
            target_hidden=target_hidden,
            position_ids=position_ids,
        )

        logits = self.drafter.get_logits(hidden_states)

        # Only masked positions carry labels; everything else is -100. MLX does
        # not support boolean-mask indexing, so compute the per-token loss
        # everywhere and zero out the ignored positions instead.
        valid_mask = (labels != -100).astype(mx.float32)
        safe_labels = mx.where(labels == -100, mx.zeros_like(labels), labels)
        nll = nn.losses.cross_entropy(logits, safe_labels, reduction="none")

        weights = mx.ones_like(valid_mask)
        weighted_nll = nll * weights * valid_mask

        # Average over supervised positions only.
        return weighted_nll.sum() / mx.maximum(valid_mask.sum(), 1.0)

    def _build_batch(
        self,
        samples: List[Dict[str, Any]],
    ) -> Dict[str, mx.array]:
        """Batch multiple prepared training samples."""
        max_len = max(s["input_ids"].shape[0] for s in samples)

        batch_input_ids = []
        batch_labels = []
        batch_target_hidden = []
        batch_attention_mask = []

        for sample in samples:
            seq_len = sample["input_ids"].shape[0]
            pad_len = max_len - seq_len

            # Right-pad inputs with the mask token; the matching -100 labels
            # keep padded positions out of the loss.
            padded_ids = mx.concatenate([
                sample["input_ids"],
                mx.full((pad_len,), self.mask_token_id, dtype=mx.int32)
            ])
            batch_input_ids.append(padded_ids)

            padded_labels = mx.concatenate([
                sample["labels"],
                mx.full((pad_len,), -100, dtype=mx.int32)
            ])
            batch_labels.append(padded_labels)

            # 1 for real tokens, 0 for padding.
            mask = mx.concatenate([
                mx.ones((seq_len,), dtype=mx.float32),
                mx.zeros((pad_len,), dtype=mx.float32)
            ])
            batch_attention_mask.append(mask)

            # Target features are assumed to be (num_layers, seq_len, hidden);
            # zero-pad the sequence axis up to max_len.
            hidden = sample["target_hidden"]
            if hidden.shape[1] < max_len:
                pad = mx.zeros((hidden.shape[0], max_len - hidden.shape[1], hidden.shape[2]))
                hidden = mx.concatenate([hidden, pad], axis=1)
            batch_target_hidden.append(hidden)

        return {
            "input_ids": mx.stack(batch_input_ids),
            "labels": mx.stack(batch_labels),
            "target_hidden": mx.stack(batch_target_hidden),
            "attention_mask": mx.stack(batch_attention_mask),
        }

    def train(
        self,
        dataset: str,
        epochs: int = 6,
        batch_size: int = 8,
        lr: float = 6e-4,
        warmup_ratio: float = 0.04,
        grad_clip: float = 1.0,
        save_every: int = 1000,
        block_size: int = 16,
    ) -> DFlashDraftModel:
        """Train the DFlash drafter.

        Args:
            dataset: Path to dataset (JSONL with {prompt, response} pairs)
                or HF dataset name with 'prompt' and 'response' columns
            epochs: Number of training epochs
            batch_size: Batch size
            lr: Learning rate
            warmup_ratio: Warmup ratio for cosine schedule
            grad_clip: Gradient clipping threshold
            save_every: Save checkpoint every N steps
            block_size: Length of each masked draft block; should match the
                block size the drafter uses at inference time

        Returns:
            Trained DFlashDraftModel
        """
        samples = self._load_dataset(dataset)
        print(f"[Trainer] Loaded {len(samples)} training samples")

        optimizer = optim.AdamW(learning_rate=lr)

        # Linear warmup followed by cosine decay to zero.
        num_steps = max(1, (len(samples) // batch_size) * epochs)
        warmup_steps = max(1, int(num_steps * warmup_ratio))

        def lr_schedule(step):
            if step < warmup_steps:
                return lr * (step + 1) / warmup_steps
            progress = (step - warmup_steps) / max(1, num_steps - warmup_steps)
            return lr * 0.5 * (1 + math.cos(math.pi * progress))

        step = 0
        for epoch in range(epochs):
            random.shuffle(samples)

            epoch_losses = []
            for i in range(0, len(samples), batch_size):
                batch_samples = samples[i:i + batch_size]

                # Run the frozen target model and build masked blocks for each
                # raw (prompt, response) pair, then pad them into one batch.
                prepared = [
                    self._prepare_training_sample(
                        s["prompt"], s["response"], block_size
                    )
                    for s in batch_samples
                ]
                batch = self._build_batch(prepared)

                def loss_fn(params):
                    self.drafter.update(params)
                    return self._compute_loss(
                        batch["input_ids"],
                        batch["labels"],
                        batch["target_hidden"],
                    )

                loss, grads = mx.value_and_grad(loss_fn)(self.drafter.parameters())

                # Clip by global norm over the nested gradient tree.
                if grad_clip > 0:
                    grads, _ = optim.clip_grad_norm(grads, grad_clip)

                current_lr = lr_schedule(step)
                optimizer.learning_rate = current_lr
                optimizer.update(self.drafter, grads)
                mx.eval(self.drafter.parameters(), optimizer.state)

                loss_val = float(loss)
                epoch_losses.append(loss_val)

                if step % 10 == 0:
                    recent = epoch_losses[-10:]
                    avg_loss = sum(recent) / len(recent)
                    print(f"[Trainer] Epoch {epoch+1}/{epochs} Step {step} | "
                          f"Loss: {avg_loss:.4f} | LR: {current_lr:.2e}")

                step += 1

                if step % save_every == 0:
                    self._save_checkpoint(f"checkpoint_step_{step}")

            avg_epoch_loss = sum(epoch_losses) / max(1, len(epoch_losses))
            print(f"[Trainer] Epoch {epoch+1} complete | Avg Loss: {avg_epoch_loss:.4f}")

        print("[Trainer] Training complete!")
        return self.drafter

    def _load_dataset(self, dataset: str) -> List[Dict[str, str]]:
        """Load dataset from path or HF Hub."""
        import json
        from pathlib import Path

        # A local JSONL file takes precedence over a Hub dataset name.
        dataset_path = Path(dataset)
        if dataset_path.exists():
            samples = []
            with open(dataset_path, "r") as f:
                for line in f:
                    if not line.strip():
                        continue
                    data = json.loads(line)
                    samples.append({
                        "prompt": data.get("prompt", data.get("input", "")),
                        "response": data.get("response", data.get("output", data.get("completion", ""))),
                    })
            return samples

        # Otherwise try to load it from the Hugging Face Hub.
        try:
            from datasets import load_dataset
            ds = load_dataset(dataset, split="train")
            samples = []
            for item in ds:
                prompt = item.get("prompt", item.get("input", item.get("question", "")))
                response = item.get("response", item.get("output", item.get("answer", item.get("completion", ""))))
                if prompt and response:
                    samples.append({"prompt": prompt, "response": response})
            return samples
        except Exception as e:
            print(f"[Trainer] Failed to load dataset: {e}")
            return []

    def _save_checkpoint(self, name: str):
        """Save a training checkpoint."""
        from pathlib import Path
        from mlx.utils import tree_flatten

        checkpoint_dir = Path("checkpoints") / name
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # save_safetensors expects a flat {name: array} mapping, so flatten the
        # nested parameter tree into dotted keys first.
        weights = dict(tree_flatten(self.drafter.parameters()))
        mx.save_safetensors(str(checkpoint_dir / "weights.safetensors"), weights)

        print(f"[Trainer] Saved checkpoint to {checkpoint_dir}")