"""
Training utilities for DFlash drafters on MLX.

Implements the training recipe from the DFlash paper:
- KV injection with target model features
- Random anchor sampling for block construction
- Sparse attention masking within blocks
- Position-dependent loss decay
"""

import math
import random
from typing import Any, Dict, List, Optional
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from .model import DFlashDraftModel


class DFlashTrainer:
    """Trainer for DFlash draft models on MLX.
    
    Trains the drafter to align block-level diffusion predictions
    with a frozen autoregressive target model's outputs.
    """

    def __init__(
        self,
        target_model,
        drafter: DFlashDraftModel,
        tokenizer,
        max_seq_length: int = 3072,
    ):
        self.target_model = target_model
        self.drafter = drafter
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.mask_token_id = drafter.mask_token_id

    def _prepare_training_sample(
        self,
        prompt: str,
        response: str,
        block_size: int,
    ) -> Dict[str, mx.array]:
        """Prepare a single training sample.
        
        Constructs masked blocks with random anchors from target-generated
        responses, matching the inference-time speculative decoding setting.
        """
        # Tokenize prompt + response
        prompt_ids = self.tokenizer.encode(prompt)
        response_ids = self.tokenizer.encode(response)

        # Truncate the response if the combined sequence is too long
        total_len = len(prompt_ids) + len(response_ids)
        if total_len > self.max_seq_length:
            keep = max(0, self.max_seq_length - len(prompt_ids))
            response_ids = response_ids[:keep]

        full_ids = prompt_ids + response_ids
        full_ids_mx = mx.array(full_ids)

        # Build target context features. MLX has no eval/no-grad context manager;
        # the target stays frozen because gradients are only ever taken with
        # respect to the drafter's parameters.
        target_output = self._target_forward(full_ids_mx)
        target_hidden = self.drafter.extract_context_features(
            target_output["hidden_states"]
        )

        # Random anchor sampling for blocks. Anchors are drawn from the response
        # region, leaving room for a full block before the end of the sequence.
        num_blocks = max(1, len(response_ids) // block_size)
        anchor_high = max(len(prompt_ids) + 1, len(full_ids) - block_size + 1)
        block_starts = mx.random.randint(
            low=len(prompt_ids),
            high=anchor_high,
            shape=(num_blocks,),
        )

        # Create the masked sequence and labels (-100 marks ignored positions)
        masked_ids = mx.array(full_ids)
        labels = mx.full((len(full_ids),), -100, dtype=mx.int32)

        for start in block_starts.tolist():
            start = int(start)
            end = min(start + block_size, len(full_ids))
            # The anchor (first token of the block) stays visible, mirroring the
            # token accepted from the target model at inference time; the rest
            # of the block is masked and supervised.
            masked_ids[start + 1:end] = self.mask_token_id
            labels[start + 1:end] = full_ids_mx[start + 1:end]

        return {
            "input_ids": masked_ids,
            "labels": labels,
            "target_hidden": target_hidden,
            "prompt_length": len(prompt_ids),
        }

    def _target_forward(
        self,
        input_ids: mx.array,
    ) -> Dict[str, Any]:
        """Forward pass through target model to get hidden states."""
        if hasattr(self.target_model, '__call__'):
            result = self.target_model(input_ids)
            logits = result[0] if isinstance(result, tuple) else result
        else:
            logits = self.target_model(input_ids)

        # Extract hidden states layer by layer
        hidden_states = []
        hidden = input_ids
        if hasattr(self.target_model, 'embed_tokens'):
            hidden = self.target_model.embed_tokens(hidden)

        if hasattr(self.target_model, 'layers'):
            for layer in self.target_model.layers:
                hidden = layer(hidden, mask=None)
                hidden_states.append(hidden)
        else:
            hidden_states = [hidden]

        return {
            "logits": logits,
            "hidden_states": hidden_states,
        }

    def _compute_loss(
        self,
        input_ids: mx.array,
        labels: mx.array,
        target_hidden: mx.array,
    ) -> mx.array:
        """Compute the diffusion training loss with position-dependent decay.
        
        Implements the loss decay from the paper where tokens closer to
        the anchor receive higher weights.
        """
        # Embed tokens (including mask tokens)
        embeddings = self.drafter.embed_tokens(input_ids)

        # Build position IDs over the sequence dimension (input may be batched)
        position_ids = mx.arange(input_ids.shape[-1])

        # Forward through drafter
        hidden_states = self.drafter(
            noise_embedding=embeddings,
            target_hidden=target_hidden,
            position_ids=position_ids,
        )

        # Get logits
        logits = self.drafter.get_logits(hidden_states)

        # Per-token cross entropy restricted to labeled positions. Boolean-mask
        # indexing is avoided: the loss is computed everywhere and zeroed out at
        # ignore-index (-100) positions.
        valid_mask = (labels != -100).astype(mx.float32)
        num_valid = mx.maximum(valid_mask.sum(), 1.0)

        # Clamp ignore-index labels to 0 so cross_entropy sees valid targets;
        # those positions contribute nothing after masking.
        safe_labels = mx.maximum(labels, 0)
        nll = nn.losses.cross_entropy(logits, safe_labels, reduction="none")

        # Position-dependent weighting (exponential decay from the anchor).
        # Simplified to uniform weights here; a full implementation would track
        # block boundaries and down-weight tokens farther from their anchor.
        weights = mx.ones_like(nll)

        weighted_nll = nll * weights * valid_mask
        return weighted_nll.sum() / num_valid

    def _build_batch(
        self,
        samples: List[Dict[str, Any]],
    ) -> Dict[str, mx.array]:
        """Batch multiple training samples."""
        # Find max length
        max_len = max(s["input_ids"].shape[0] for s in samples)

        # Pad sequences
        batch_input_ids = []
        batch_labels = []
        batch_target_hidden = []
        batch_attention_mask = []

        for sample in samples:
            seq_len = sample["input_ids"].shape[0]
            pad_len = max_len - seq_len

            # Pad input_ids with mask token
            padded_ids = mx.concatenate([
                sample["input_ids"],
                mx.full((pad_len,), self.mask_token_id, dtype=mx.int32)
            ])
            batch_input_ids.append(padded_ids)

            # Pad labels with -100 (ignore index)
            padded_labels = mx.concatenate([
                sample["labels"],
                mx.full((pad_len,), -100, dtype=mx.int32)
            ])
            batch_labels.append(padded_labels)

            # Attention mask (1 for real, 0 for padding)
            mask = mx.concatenate([
                mx.ones((seq_len,), dtype=mx.float32),
                mx.zeros((pad_len,), dtype=mx.float32)
            ])
            batch_attention_mask.append(mask)

            # Target hidden (pad with zeros)
            hidden = sample["target_hidden"]
            if hidden.shape[1] < max_len:
                pad = mx.zeros((hidden.shape[0], max_len - hidden.shape[1], hidden.shape[2]))
                hidden = mx.concatenate([hidden, pad], axis=1)
            batch_target_hidden.append(hidden)

        return {
            "input_ids": mx.stack(batch_input_ids),
            "labels": mx.stack(batch_labels),
            "target_hidden": mx.stack(batch_target_hidden),
            "attention_mask": mx.stack(batch_attention_mask),
        }

    def train(
        self,
        dataset: str,
        epochs: int = 6,
        batch_size: int = 8,
        lr: float = 6e-4,
        warmup_ratio: float = 0.04,
        grad_clip: float = 1.0,
        save_every: int = 1000,
        block_size: Optional[int] = None,
    ) -> DFlashDraftModel:
        """Train the DFlash drafter.
        
        Args:
            dataset: Path to dataset (JSONL with {prompt, response} pairs)
                      or HF dataset name with 'prompt' and 'response' columns
            epochs: Number of training epochs
            batch_size: Batch size
            lr: Learning rate
            warmup_ratio: Warmup ratio for cosine schedule
            grad_clip: Gradient clipping threshold
            save_every: Save checkpoint every N steps
            block_size: Draft block size used to build masked training blocks;
                if None, taken from the drafter (with a fallback of 8)
        
        Returns:
            Trained DFlashDraftModel
        """
        # Load dataset
        samples = self._load_dataset(dataset)
        print(f"[Trainer] Loaded {len(samples)} training samples")

        # Resolve the draft block size; the fallback of 8 is an assumption and
        # should match the drafter's configured block length
        if block_size is None:
            block_size = getattr(self.drafter, "block_size", 8)

        # Setup optimizer
        optimizer = optim.AdamW(learning_rate=lr)

        # Cosine schedule with linear warmup
        num_steps = max(1, (len(samples) // max(1, batch_size)) * epochs)
        warmup_steps = max(1, int(num_steps * warmup_ratio))

        def lr_schedule(step):
            if step < warmup_steps:
                return lr * (step / warmup_steps)
            progress = (step - warmup_steps) / max(1, num_steps - warmup_steps)
            return lr * 0.5 * (1 + math.cos(math.pi * progress))
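        # For example, with 8,000 samples, batch_size=8, and epochs=6 this gives
        # num_steps = 6,000 and warmup_steps = 240: the learning rate climbs
        # linearly to `lr` over the first 240 steps, then follows a cosine decay
        # toward 0 over the remaining 5,760 steps.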

        # Training loop
        step = 0
        for epoch in range(epochs):
            # Shuffle samples between epochs
            random.shuffle(samples)

            epoch_losses = []
            for i in range(0, len(samples), batch_size):
                batch_samples = samples[i:i + batch_size]

                # Tokenize each (prompt, response) pair, build masked blocks,
                # and pad everything into a single batch
                prepared = [
                    self._prepare_training_sample(
                        s["prompt"], s["response"], block_size
                    )
                    for s in batch_samples
                ]
                batch = self._build_batch(prepared)

                # Forward + backward
                def loss_fn(params):
                    self.drafter.update(params)
                    loss = self._compute_loss(
                        batch["input_ids"],
                        batch["labels"],
                        batch["target_hidden"],
                    )
                    return loss

                # Compute loss and gradients
                loss, grads = mx.value_and_grad(loss_fn)(self.drafter.parameters())

                # Gradient clipping over the whole (nested) gradient tree
                if grad_clip > 0:
                    grads, _grad_norm = optim.clip_grad_norm(grads, grad_clip)

                # Update parameters and force evaluation of the lazy graph
                current_lr = lr_schedule(step)
                optimizer.learning_rate = current_lr
                optimizer.update(self.drafter, grads)
                mx.eval(self.drafter.parameters(), optimizer.state)

                loss_val = float(loss)
                epoch_losses.append(loss_val)

                if step % 10 == 0:
                    avg_loss = sum(epoch_losses[-10:]) / len(epoch_losses[-10:])
                    print(f"[Trainer] Epoch {epoch+1}/{epochs} Step {step} | "
                          f"Loss: {loss_val:.4f} | Avg(10): {avg_loss:.4f} | "
                          f"LR: {current_lr:.2e}")

                step += 1

                # Save checkpoint
                if step % save_every == 0:
                    self._save_checkpoint(f"checkpoint_step_{step}")

            avg_epoch_loss = sum(epoch_losses) / len(epoch_losses)
            print(f"[Trainer] Epoch {epoch+1} complete | Avg Loss: {avg_epoch_loss:.4f}")

        print("[Trainer] Training complete!")
        return self.drafter

    def _load_dataset(self, dataset: str) -> List[Dict[str, str]]:
        """Load dataset from path or HF Hub."""
        import json
        from pathlib import Path

        # Try local file first
        dataset_path = Path(dataset)
        if dataset_path.exists():
            samples = []
            with open(dataset_path, "r") as f:
                for line in f:
                    data = json.loads(line)
                    samples.append({
                        "prompt": data.get("prompt", data.get("input", "")),
                        "response": data.get("response", data.get("output", data.get("completion", ""))),
                    })
            return samples

        # Try Hugging Face dataset
        try:
            from datasets import load_dataset
            ds = load_dataset(dataset, split="train")
            samples = []
            for item in ds:
                prompt = item.get("prompt", item.get("input", item.get("question", "")))
                response = item.get("response", item.get("output", item.get("answer", item.get("completion", ""))))
                if prompt and response:
                    samples.append({"prompt": prompt, "response": response})
            return samples
        except Exception as e:
            print(f"[Trainer] Failed to load dataset: {e}")
            return []

    def _save_checkpoint(self, name: str):
        """Save a training checkpoint."""
        from pathlib import Path
        from mlx.utils import tree_flatten

        checkpoint_dir = Path("checkpoints") / name
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # save_safetensors expects a flat {name: array} mapping
        weights = dict(tree_flatten(self.drafter.parameters()))
        mx.save_safetensors(str(checkpoint_dir / "weights.safetensors"), weights)

        print(f"[Trainer] Saved checkpoint to {checkpoint_dir}")