Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Create dual-task training dataset by mixing question-generation and solution-generation examples. | |
| This script: | |
| 1. Loads existing solution data (GSM8K format) | |
| 2. Loads question-generation data (synthetic) | |
| 3. Adds task prefixes to distinguish tasks | |
| 4. Mixes datasets according to specified ratio | |
| 5. Shuffles and splits into train/validation | |
| Usage: | |
| python scripts/create_dual_task_dataset.py \ | |
| --solution-data data/sft/gsm8k_sft.jsonl \ | |
| --question-data data/sft/question_generation.jsonl \ | |
| --output-train data/sft/dual_task_train.jsonl \ | |
| --output-val data/sft/dual_task_val.jsonl \ | |
| --mix-ratio 0.8 \ | |
| --val-split 0.1 | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import random | |
| import sys | |
| from pathlib import Path | |
| from typing import Any | |
| ROOT = Path(__file__).resolve().parents[1] | |
| sys.path.insert(0, str(ROOT)) | |
| from src.config.prompts import SOLVE_TASK_PREFIX, GENERATE_TASK_PREFIX | |
def load_jsonl(path: Path) -> list[dict[str, Any]]:
    """Load a JSONL file into a list of records.

    Whitespace-only lines are skipped; every other line is parsed with
    ``json.loads`` after surrounding whitespace is stripped.

    Args:
        path: Location of the JSONL file; it is read as UTF-8.

    Returns:
        The parsed value of each non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a line is not valid JSON. The message is
            prefixed with ``<path>:<line number>`` so the offending line can
            be located in large files.
    """
    records: list[dict[str, Any]] = []
    with path.open(encoding="utf-8") as f:
        for lineno, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as err:
                # Re-raise the same exception type (callers may catch it)
                # but say which file and line were bad.
                raise json.JSONDecodeError(
                    f"{path}:{lineno}: {err.msg}", err.doc, err.pos
                ) from err
    return records
# Marker that opens a user turn inside the flattened "text" field.
_USER_TAG = "<|user|>"


def _prefix_user_turn(segment: str, prefix: str) -> str:
    """Insert *prefix* before the first non-whitespace character of *segment*.

    *segment* is the text that follows one ``<|user|>`` marker. Its leading
    whitespace is kept in front of the prefix so the prefix sits directly
    before the user content, exactly as it does in ``messages``. A segment
    whose content already starts with *prefix* is returned unchanged.
    """
    body = segment.lstrip()
    if body.startswith(prefix):
        return segment
    leading = segment[: len(segment) - len(body)]
    return f"{leading}{prefix}{body}"


def _apply_task_prefix(
    record: dict[str, Any], prefix: str, task_type: str
) -> dict[str, Any]:
    """Return a copy of *record* with *prefix* on every user turn.

    The record dict and each message dict are shallow-copied, so the input
    record and its messages are not mutated. The prefix is only added where
    it is not already present, so re-running on processed data is a no-op.

    Args:
        record: A record with a ``messages`` list of dicts carrying ``role``
            and ``content`` keys, and optionally a flattened ``text`` string.
        prefix: Task prefix to put in front of user content.
        task_type: Value stored under the ``task_type`` key of the result.

    Raises:
        KeyError: If ``messages`` is missing, or a message lacks ``role``
            (or ``content`` for a user message).
    """
    tagged = record.copy()

    messages = []
    for msg in record["messages"]:
        new_msg = msg.copy()
        if msg["role"] == "user" and not msg["content"].startswith(prefix):
            new_msg["content"] = prefix + msg["content"]
        messages.append(new_msg)
    tagged["messages"] = messages

    # Keep the flattened text in sync with messages: every user turn gets the
    # prefix, not just the first one.
    text = tagged.get("text")
    if isinstance(text, str) and _USER_TAG in text:
        head, *user_turns = text.split(_USER_TAG)
        tagged["text"] = _USER_TAG.join(
            [head, *(_prefix_user_turn(turn, prefix) for turn in user_turns)]
        )

    tagged["task_type"] = task_type
    return tagged


def add_solve_prefix(record: dict[str, Any]) -> dict[str, Any]:
    """
    Add the solve-task prefix (SOLVE_TASK_PREFIX) to the user turns of a record.

    The prefix is added to each user message in ``messages`` and, when a
    ``text`` field is present, to each ``<|user|>`` section of it. The
    returned copy is marked with ``task_type = "solve"``; the input record
    is left unmodified.
    """
    return _apply_task_prefix(record, SOLVE_TASK_PREFIX, "solve")


def verify_question_prefix(record: dict[str, Any]) -> dict[str, Any]:
    """
    Ensure a question-generation record carries GENERATE_TASK_PREFIX.

    Records that already have the prefix keep their content unchanged; any
    user turn lacking it gets it added, in both ``messages`` and ``text``.
    The returned copy is marked with ``task_type = "generate"``; the input
    record is left unmodified.
    """
    return _apply_task_prefix(record, GENERATE_TASK_PREFIX, "generate")
def sample_with_ratio(
    solution_records: list[dict[str, Any]],
    question_records: list[dict[str, Any]],
    mix_ratio: float,
    target_total: int | None = None,
) -> list[dict[str, Any]]:
    """
    Sample and mix datasets according to specified ratio.

    Sampling is without replacement via ``random.sample``, so results are
    reproducible when the caller seeds ``random``. If one source has fewer
    records than requested, a warning is printed and that source's count is
    capped; the other count is left as computed, so the achieved ratio
    (printed afterwards) can differ from ``mix_ratio``.

    Args:
        solution_records: Solution examples
        question_records: Question generation examples
        mix_ratio: Fraction of solutions in final dataset (0.8 = 80% solutions, 20% questions)
        target_total: Target total examples (None = use all available data)

    Returns:
        Mixed dataset: the selected solutions followed by the selected
        questions (not shuffled). Empty if nothing could be selected.

    Raises:
        ValueError: If ``mix_ratio`` is outside [0, 1] or ``target_total``
            is negative.
    """
    if not 0 <= mix_ratio <= 1:
        raise ValueError(f"mix_ratio must be within [0, 1], got {mix_ratio}")
    if target_total is not None and target_total < 0:
        raise ValueError(f"target_total must be non-negative, got {target_total}")

    n_solutions = len(solution_records)
    n_questions = len(question_records)
    if target_total is None:
        # Use all available data
        target_total = n_solutions + n_questions

    # round() rather than int(): float error such as
    # 0.29 * 100 == 28.999999999999996 must not drop an example.
    n_sol_target = round(target_total * mix_ratio)
    n_q_target = target_total - n_sol_target

    # Check availability
    if n_sol_target > n_solutions:
        print(f"Warning: Requested {n_sol_target} solutions but only {n_solutions} available.")
        n_sol_target = n_solutions
    if n_q_target > n_questions:
        print(f"Warning: Requested {n_q_target} questions but only {n_questions} available.")
        n_q_target = n_questions

    # Sample
    selected_solutions = random.sample(solution_records, n_sol_target)
    selected_questions = random.sample(question_records, n_q_target)
    print(f"Sampled {n_sol_target} solutions and {n_q_target} questions")

    n_selected = n_sol_target + n_q_target
    if n_selected:
        print(f"Actual ratio: {n_sol_target/n_selected:.2%} solutions, "
              f"{n_q_target/n_selected:.2%} questions")
    else:
        # Nothing selected: the ratio is undefined, so skip the division.
        print("Warning: No examples selected; mixed dataset is empty.")
    return selected_solutions + selected_questions
def write_jsonl(records: list[dict[str, Any]], path: Path) -> None:
    """Serialize each record as one JSON line of the UTF-8 file at *path*.

    Missing parent directories are created first and an existing file is
    overwritten. Non-ASCII characters are written as-is, not escaped.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as sink:
        sink.writelines(
            json.dumps(item, ensure_ascii=False) + "\n" for item in records
        )
def _build_arg_parser() -> argparse.ArgumentParser:
    """Define the command-line interface of the script."""
    parser = argparse.ArgumentParser(
        description="Create dual-task training dataset from solution and question-generation examples."
    )
    parser.add_argument(
        "--solution-data",
        type=Path,
        required=True,
        help="Path to solution training data (GSM8K format)",
    )
    parser.add_argument(
        "--question-data",
        type=Path,
        required=True,
        help="Path to question-generation training data",
    )
    parser.add_argument(
        "--output-train",
        type=Path,
        required=True,
        help="Output path for training split",
    )
    parser.add_argument(
        "--output-val",
        type=Path,
        required=True,
        help="Output path for validation split",
    )
    parser.add_argument(
        "--mix-ratio",
        type=float,
        default=0.8,
        help="Fraction of solutions in mixed dataset (default: 0.8 = 80%% solutions)",
    )
    parser.add_argument(
        "--val-split",
        type=float,
        default=0.1,
        help="Fraction of data to use for validation (default: 0.1 = 10%%)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducibility",
    )
    parser.add_argument(
        "--max-total",
        type=int,
        default=None,
        help="Maximum total examples to include (None = use all available)",
    )
    return parser


def _fraction(part: int, whole: int) -> float:
    """Return ``part / whole``, or 0.0 when *whole* is zero (empty split)."""
    return part / whole if whole else 0.0


def _print_composition(heading: str, records: list[dict[str, Any]]) -> None:
    """Print how many records of each ``task_type`` a split contains."""
    n_solve = sum(1 for r in records if r.get("task_type") == "solve")
    n_gen = sum(1 for r in records if r.get("task_type") == "generate")
    print(heading)
    print(f" Solve: {n_solve} ({_fraction(n_solve, len(records)):.1%})")
    print(f" Generate: {n_gen} ({_fraction(n_gen, len(records)):.1%})")


def main() -> None:
    """Build the dual-task train/validation files from the command line.

    Steps: parse and validate arguments, seed ``random``, load both JSONL
    inputs, add task prefixes, mix by ``--mix-ratio``, shuffle, split off
    the last ``int(n * --val-split)`` records as validation, print the
    composition of each split, and write both output files.

    Raises:
        SystemExit: If an input file is missing, an argument is out of
            range, or no examples remain after mixing.
    """
    args = _build_arg_parser().parse_args()

    # Validate inputs
    if not args.solution_data.exists():
        raise SystemExit(f"Error: Solution data not found at {args.solution_data}")
    if not args.question_data.exists():
        raise SystemExit(f"Error: Question data not found at {args.question_data}")
    if not (0 < args.mix_ratio < 1):
        raise SystemExit("Error: --mix-ratio must be between 0 and 1")
    if not (0 < args.val_split < 1):
        raise SystemExit("Error: --val-split must be between 0 and 1")
    if args.max_total is not None and args.max_total <= 0:
        raise SystemExit("Error: --max-total must be a positive integer")

    # Seed once so that sampling and shuffling below are reproducible.
    random.seed(args.seed)

    print("=" * 60)
    print("Dual-Task Dataset Creation")
    print("=" * 60)

    # Load data
    print("\n1. Loading data...")
    print(f" Solution data: {args.solution_data}")
    solution_records = load_jsonl(args.solution_data)
    print(f" Loaded {len(solution_records)} solution examples")
    print(f" Question data: {args.question_data}")
    question_records = load_jsonl(args.question_data)
    print(f" Loaded {len(question_records)} question-generation examples")

    # Add task prefixes
    print("\n2. Adding task prefixes...")
    print(" Adding 'Solve Problem' prefix to solution examples...")
    solution_records = [add_solve_prefix(r) for r in solution_records]
    print(" Verifying 'Generate Question' prefix on question examples...")
    question_records = [verify_question_prefix(r) for r in question_records]

    # Mix datasets
    print(f"\n3. Mixing datasets (ratio: {args.mix_ratio:.0%} solutions, {1-args.mix_ratio:.0%} questions)...")
    mixed_records = sample_with_ratio(
        solution_records=solution_records,
        question_records=question_records,
        mix_ratio=args.mix_ratio,
        target_total=args.max_total,
    )
    if not mixed_records:
        # Nothing to split or write; the percentage report would divide by zero.
        raise SystemExit("Error: No examples available after mixing; check the input files")

    # Shuffle
    print(f"\n4. Shuffling {len(mixed_records)} total examples...")
    random.shuffle(mixed_records)

    # Split train/val: validation is the tail of the shuffled list.
    n_val = int(len(mixed_records) * args.val_split)
    n_train = len(mixed_records) - n_val
    train_records = mixed_records[:n_train]
    val_records = mixed_records[n_train:]
    print("\n5. Splitting data:")
    print(f" Training: {len(train_records)} examples ({len(train_records)/len(mixed_records):.1%})")
    print(f" Validation: {len(val_records)} examples ({len(val_records)/len(mixed_records):.1%})")
    if not val_records:
        # int() truncation yields zero validation records on small datasets.
        print(" Warning: Validation split is empty (dataset too small for --val-split)")

    # Verify split composition
    _print_composition("\n Train composition:", train_records)
    _print_composition(" Val composition:", val_records)

    # Write outputs
    print("\n6. Writing output files...")
    print(f" Training data: {args.output_train}")
    write_jsonl(train_records, args.output_train)
    print(f" Validation data: {args.output_val}")
    write_jsonl(val_records, args.output_val)

    print("\n" + "=" * 60)
    print("Dual-task dataset creation complete!")
    print("=" * 60)
    print("\nOutput files:")
    print(f" Train: {args.output_train} ({len(train_records)} examples)")
    print(f" Val: {args.output_val} ({len(val_records)} examples)")
    print("\nNext step: Train dual-task model using these files")


if __name__ == "__main__":
    main()