""" Data generation utilities for DFlash training. Generates training data by running the target model on prompts, creating {prompt, response} pairs for drafter training. """ import json from pathlib import Path from typing import Optional, List, Dict, Any import mlx.core as mx def generate_training_data( target_model, tokenizer, prompts_dataset: str, output_path: str, max_new_tokens: int = 2048, temperature: float = 0.0, num_samples: Optional[int] = None, system_prompt: Optional[str] = None, ) -> str: """Generate training data by running target model on prompts. This creates the supervised data that DFlash drafters need: pairs of (prompt, target_model_response). Args: target_model: MLX target model tokenizer: Tokenizer prompts_dataset: HF dataset name or path to prompts file output_path: Output JSONL file path max_new_tokens: Max tokens per response temperature: Generation temperature (0 for greedy) num_samples: Max number of samples to generate (None = all) system_prompt: Optional system prompt Returns: Path to output file """ output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) # Load prompts prompts = _load_prompts(prompts_dataset) if num_samples: prompts = prompts[:num_samples] print(f"[DataGen] Generating {len(prompts)} responses...") with open(output_path, "w") as f: for i, prompt in enumerate(prompts): print(f"[DataGen] Sample {i+1}/{len(prompts)}...") # Generate response with target model response = _generate_with_model( model=target_model, tokenizer=tokenizer, prompt=prompt, max_new_tokens=max_new_tokens, temperature=temperature, system_prompt=system_prompt, ) # Save sample sample = { "prompt": prompt, "response": response, "model": getattr(target_model, "config", {}).get("_name_or_path", "unknown"), } f.write(json.dumps(sample) + "\n") print(f"[DataGen] Done! 


def _load_prompts(dataset: str) -> List[str]:
    """Load prompts from a local JSONL file or a Hugging Face dataset."""
    path = Path(dataset)
    if path.exists():
        # Local JSONL file: accept any of the common prompt field names
        prompts = []
        with open(path, "r") as f:
            for line in f:
                data = json.loads(line)
                prompt = data.get("prompt", data.get("input", data.get("question", "")))
                if prompt:
                    prompts.append(prompt)
        return prompts

    # Otherwise, try a Hugging Face dataset
    try:
        from datasets import load_dataset

        ds = load_dataset(dataset, split="train")
        prompts = []
        for item in ds:
            prompt = item.get("prompt", item.get("input", item.get("question", item.get("text", ""))))
            if prompt:
                prompts.append(str(prompt))
        return prompts
    except Exception as e:
        print(f"[DataGen] Failed to load dataset: {e}")
        return []


def _generate_with_model(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens: int,
    temperature: float = 0.0,
    system_prompt: Optional[str] = None,
) -> str:
    """Generate text with an MLX model (greedy or temperature sampling)."""
    # Build the prompt, applying the chat template when the tokenizer has one
    if hasattr(tokenizer, "apply_chat_template"):
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    else:
        text = prompt

    # Tokenize
    input_ids = mx.array(tokenizer.encode(text)).reshape(1, -1)

    # Generate token by token. Note: this re-runs the model on the full
    # sequence each step (no KV cache), so it is O(n^2) in sequence length.
    generated = []
    for _ in range(max_new_tokens):
        result = model(input_ids)
        logits = result[0] if isinstance(result, tuple) else result

        # Sample the next token from the last position
        next_logits = logits[:, -1, :]
        if temperature < 1e-5:
            next_token = mx.argmax(next_logits, axis=-1)
        else:
            # mx.random.categorical samples from softmax(logits), so scaling
            # the raw logits by 1/temperature is sufficient
            next_token = mx.random.categorical(next_logits / temperature)

        token_id = int(next_token[0])

        # Stop at EOS without including it in the decoded response
        if hasattr(tokenizer, "eos_token_id") and token_id == tokenizer.eos_token_id:
            break

        generated.append(token_id)
        input_ids = mx.concatenate([input_ids, next_token.reshape(1, 1)], axis=1)

    # Decode
    return tokenizer.decode(generated)
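

# Hedged alternative, assuming an mlx_lm-loaded model. The loop in
# _generate_with_model() re-encodes the whole sequence every step; if the
# target model and tokenizer came from `mlx_lm.load`, its KV-cached generate
# loop is much faster. Treating it as a drop-in here is an assumption.
def _generate_with_mlx_lm(model, tokenizer, prompt: str, max_new_tokens: int) -> str:
    """KV-cached generation via mlx_lm (sketch; requires an mlx_lm model)."""
    from mlx_lm import generate

    return generate(model, tokenizer, prompt=prompt, max_tokens=max_new_tokens)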


def create_mixed_training_data(
    output_path: str,
    math_ratio: float = 0.30,
    code_ratio: float = 0.20,
    chat_ratio: float = 0.50,
    total_samples: int = 100000,
) -> str:
    """Create a mixed training dataset from public sources.

    This replicates the paper's data mixture recipe:
    - 50% instruction/chat (UltraChat, ShareGPT)
    - 30% math/reasoning (GSM8K, MATH)
    - 20% code (HumanEval, MBPP)

    Args:
        output_path: Output JSONL path
        math_ratio: Fraction of math samples
        code_ratio: Fraction of code samples
        chat_ratio: Fraction of chat samples
        total_samples: Total number of samples

    Returns:
        Path to the output file
    """
    from datasets import load_dataset

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    samples = []

    # Chat data
    chat_count = int(total_samples * chat_ratio)
    try:
        print("[DataGen] Loading UltraChat...")
        ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
        for i, item in enumerate(ds):
            if i >= chat_count:
                break
            messages = item.get("messages", [])
            if len(messages) >= 2:
                prompt = messages[-2].get("content", "")
                response = messages[-1].get("content", "")
                if prompt and response:
                    samples.append({"prompt": prompt, "response": response, "category": "chat"})
    except Exception as e:
        print(f"[DataGen] UltraChat failed: {e}")

    # Math data
    math_count = int(total_samples * math_ratio)
    try:
        print("[DataGen] Loading GSM8K...")
        ds = load_dataset("openai/gsm8k", "main", split="train")
        for i, item in enumerate(ds):
            if i >= math_count:
                break
            prompt = item.get("question", "")
            response = item.get("answer", "")
            if prompt and response:
                samples.append({"prompt": prompt, "response": response, "category": "math"})
    except Exception as e:
        print(f"[DataGen] GSM8K failed: {e}")

    # Code data
    code_count = int(total_samples * code_ratio)
    try:
        print("[DataGen] Loading MBPP...")
        ds = load_dataset("mbpp", split="train")
        for i, item in enumerate(ds):
            if i >= code_count:
                break
            prompt = item.get("text", item.get("prompt", ""))
            response = item.get("code", item.get("canonical_solution", ""))
            if prompt and response:
                samples.append({"prompt": prompt, "response": response, "category": "code"})
    except Exception as e:
        print(f"[DataGen] MBPP failed: {e}")

    # Save
    with open(output_path, "w") as f:
        for sample in samples:
            f.write(json.dumps(sample) + "\n")

    print(f"[DataGen] Created {len(samples)} mixed samples at {output_path}")
    return str(output_path)
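

# Minimal smoke test (a sketch; the path and sample count are illustrative,
# not part of the original module). Builds a small mixed dataset so the
# loaders above can be exercised without a target model. Requires the
# `datasets` package and network access to the Hugging Face Hub.
if __name__ == "__main__":
    create_mixed_training_data(
        output_path="data/mixed_train.jsonl",
        total_samples=1000,
    )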