AxiomForgeAI / scripts /create_dual_task_dataset.py
jampuramprem's picture
Initial Space deployment
ec4ae03
#!/usr/bin/env python3
"""
Create a dual-task training dataset by mixing question-generation and solution-generation examples.
This script:
1. Loads existing solution data (GSM8K format)
2. Loads question-generation data (synthetic)
3. Adds task prefixes to distinguish tasks
4. Mixes datasets according to specified ratio
5. Shuffles and splits into train/validation
Usage:
python scripts/create_dual_task_dataset.py \
--solution-data data/sft/gsm8k_sft.jsonl \
--question-data data/sft/question_generation.jsonl \
--output-train data/sft/dual_task_train.jsonl \
--output-val data/sft/dual_task_val.jsonl \
--mix-ratio 0.8 \
--val-split 0.1
"""
from __future__ import annotations
import argparse
import json
import random
import sys
from pathlib import Path
from typing import Any
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
from src.config.prompts import SOLVE_TASK_PREFIX, GENERATE_TASK_PREFIX
def load_jsonl(path: Path) -> list[dict[str, Any]]:
    """Read a JSON Lines file and return its decoded records.

    Every line is stripped of surrounding whitespace; lines that are empty
    after stripping are skipped, and each remaining line is decoded with
    ``json.loads``.

    Args:
        path: Location of the UTF-8 encoded JSONL file.

    Returns:
        The decoded records, in file order.
    """
    with path.open(encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(entry) for entry in stripped_lines if entry]
# Marker that introduces a user turn inside the pre-rendered "text" field.
_USER_MARKER = "<|user|>"


def _prefix_user_message(message: dict[str, Any], prefix: str) -> dict[str, Any]:
    """Return a copy of *message*, prefixing its content when it is a user turn."""
    updated = message.copy()
    if message["role"] == "user" and not message["content"].startswith(prefix):
        updated["content"] = prefix + message["content"]
    return updated


def _prefix_user_section(section: str, prefix: str) -> str:
    """Insert *prefix* right before the first non-whitespace character of *section*."""
    body = section.lstrip()
    if body.startswith(prefix):
        return section
    # Keep whatever whitespace followed the marker so the rendered layout is
    # unchanged and the content reads ``prefix + content`` exactly like the
    # corresponding entry in "messages".
    leading = section[: len(section) - len(body)]
    return f"{leading}{prefix}{body}"


def _apply_task_prefix(
    record: dict[str, Any], prefix: str, task_type: str
) -> dict[str, Any]:
    """Return a tagged copy of *record* whose user turns start with *prefix*."""
    tagged = record.copy()
    tagged["messages"] = [
        _prefix_user_message(message, prefix) for message in record["messages"]
    ]

    text = tagged.get("text")
    if isinstance(text, str) and _USER_MARKER in text:
        # Everything after each marker is a user section (up to the next marker),
        # so every element but the first one gets the prefix.
        head, *user_sections = text.split(_USER_MARKER)
        tagged["text"] = _USER_MARKER.join(
            [head, *(_prefix_user_section(part, prefix) for part in user_sections)]
        )

    tagged["task_type"] = task_type
    return tagged


def add_solve_prefix(
    record: dict[str, Any], *, prefix: str | None = None
) -> dict[str, Any]:
    """
    Add the 'Solve Problem' task prefix to the user turns of a record.

    The prefix is prepended to the content of every ``"user"`` entry in
    ``record["messages"]`` and, when the record carries a rendered ``"text"``
    field, to every ``<|user|>`` section of that text. Turns that already start
    with the prefix are left alone, so the function is idempotent. The input
    record and its message dicts are not modified.

    Args:
        record: Example with a ``"messages"`` list of ``{"role", "content"}``
            dicts and optionally a ``"text"`` string.
        prefix: Prefix to use instead of ``SOLVE_TASK_PREFIX`` (keyword-only).

    Returns:
        A copy of the record with prefixed user turns and ``task_type`` set to
        ``"solve"``.

    Raises:
        KeyError: If the record has no ``"messages"`` key, a message has no
            ``"role"``, or a user message has no ``"content"``.
    """
    chosen = SOLVE_TASK_PREFIX if prefix is None else prefix
    return _apply_task_prefix(record, chosen, "solve")


def verify_question_prefix(
    record: dict[str, Any], *, prefix: str | None = None
) -> dict[str, Any]:
    """
    Ensure a question-generation record carries the 'Generate' task prefix.

    Applies the same treatment as ``add_solve_prefix``: every user entry in
    ``record["messages"]`` and every ``<|user|>`` section of an optional
    ``"text"`` field gets the prefix unless it already starts with it. The input
    record is not modified.

    Args:
        record: Example with a ``"messages"`` list of ``{"role", "content"}``
            dicts and optionally a ``"text"`` string.
        prefix: Prefix to use instead of ``GENERATE_TASK_PREFIX`` (keyword-only).

    Returns:
        A copy of the record with prefixed user turns and ``task_type`` set to
        ``"generate"``.

    Raises:
        KeyError: If the record has no ``"messages"`` key, a message has no
            ``"role"``, or a user message has no ``"content"``.
    """
    chosen = GENERATE_TASK_PREFIX if prefix is None else prefix
    return _apply_task_prefix(record, chosen, "generate")
def sample_with_ratio(
    solution_records: list[dict[str, Any]],
    question_records: list[dict[str, Any]],
    mix_ratio: float,
    target_total: int | None = None,
) -> list[dict[str, Any]]:
    """
    Sample and mix datasets according to specified ratio.

    The requested number of solutions is ``int(target_total * mix_ratio)`` and
    the remainder of ``target_total`` is requested as questions. Each request is
    capped at the number of records actually available (a warning is printed
    when that happens). The other side is NOT rescaled after capping, so the
    realised ratio can differ from ``mix_ratio``; the realised ratio is printed.

    Sampling uses the module-level ``random`` state (solutions first, then
    questions), so seed it beforehand for reproducible output.

    Args:
        solution_records: Solution examples
        question_records: Question generation examples
        mix_ratio: Fraction of solutions in final dataset (0.8 = 80% solutions, 20% questions)
        target_total: Target total examples (None = size the request from the
            combined number of available records)

    Returns:
        Mixed dataset: the sampled solutions followed by the sampled questions
        (not shuffled). Empty when nothing could be selected.

    Raises:
        ValueError: If ``mix_ratio`` is outside ``[0, 1]`` or ``target_total``
            is negative.
    """
    # Fail early with a clear message; otherwise these inputs produce negative
    # sample sizes and an opaque ValueError from random.sample().
    if not 0 <= mix_ratio <= 1:
        raise ValueError(f"mix_ratio must be within [0, 1], got {mix_ratio!r}")
    if target_total is not None and target_total < 0:
        raise ValueError(f"target_total must be non-negative, got {target_total!r}")

    n_solutions = len(solution_records)
    n_questions = len(question_records)
    if target_total is None:
        # Use all available data
        target_total = n_solutions + n_questions

    # Calculate target counts
    n_sol_target = int(target_total * mix_ratio)
    n_q_target = target_total - n_sol_target

    # Check availability
    if n_sol_target > n_solutions:
        print(f"Warning: Requested {n_sol_target} solutions but only {n_solutions} available.")
        n_sol_target = n_solutions
    if n_q_target > n_questions:
        print(f"Warning: Requested {n_q_target} questions but only {n_questions} available.")
        n_q_target = n_questions

    # Sample
    selected_solutions = random.sample(solution_records, n_sol_target)
    selected_questions = random.sample(question_records, n_q_target)
    print(f"Sampled {n_sol_target} solutions and {n_q_target} questions")

    n_selected = n_sol_target + n_q_target
    if n_selected:
        print(f"Actual ratio: {n_sol_target/n_selected:.2%} solutions, "
              f"{n_q_target/n_selected:.2%} questions")
    else:
        # Nothing selected (empty inputs or target_total == 0): a ratio is
        # undefined, so report that instead of dividing by zero.
        print("Actual ratio: n/a (no examples selected)")
    return selected_solutions + selected_questions
def write_jsonl(records: list[dict[str, Any]], path: Path) -> None:
    """Serialise *records* to *path* as JSON Lines, one record per line.

    Missing parent directories are created first. The file is opened in write
    mode with UTF-8 encoding, and non-ASCII characters are written as-is
    rather than escaped.

    Args:
        records: Records to serialise, written in the given order.
        path: Destination file.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    serialized = (json.dumps(item, ensure_ascii=False) + "\n" for item in records)
    with path.open("w", encoding="utf-8") as sink:
        sink.writelines(serialized)
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the dataset creation script."""
    parser = argparse.ArgumentParser(
        description="Create dual-task training dataset from solution and question-generation examples."
    )
    parser.add_argument(
        "--solution-data",
        type=Path,
        required=True,
        help="Path to solution training data (GSM8K format)",
    )
    parser.add_argument(
        "--question-data",
        type=Path,
        required=True,
        help="Path to question-generation training data",
    )
    parser.add_argument(
        "--output-train",
        type=Path,
        required=True,
        help="Output path for training split",
    )
    parser.add_argument(
        "--output-val",
        type=Path,
        required=True,
        help="Output path for validation split",
    )
    parser.add_argument(
        "--mix-ratio",
        type=float,
        default=0.8,
        help="Fraction of solutions in mixed dataset (default: 0.8 = 80%% solutions)",
    )
    parser.add_argument(
        "--val-split",
        type=float,
        default=0.1,
        help="Fraction of data to use for validation (default: 0.1 = 10%%)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducibility",
    )
    parser.add_argument(
        "--max-total",
        type=int,
        default=None,
        help="Maximum total examples to include (None = use all available)",
    )
    return parser


def _format_share(count: int, total: int) -> str:
    """Format ``count / total`` as a percentage, or ``n/a`` when *total* is zero."""
    return f"{count / total:.1%}" if total else "n/a"


def _print_composition(header: str, records: list[dict[str, Any]]) -> None:
    """Print how many records of each task type a split contains."""
    n_solve = sum(1 for r in records if r.get("task_type") == "solve")
    n_generate = sum(1 for r in records if r.get("task_type") == "generate")
    print(header)
    print(f" Solve: {n_solve} ({_format_share(n_solve, len(records))})")
    print(f" Generate: {n_generate} ({_format_share(n_generate, len(records))})")


def main() -> None:
    """Run the dual-task dataset pipeline from the command line.

    Steps: parse and validate arguments, seed ``random``, load both JSONL
    inputs, add task prefixes, mix with ``sample_with_ratio``, shuffle, split
    into train/validation, report the composition and write both output files.

    Raises:
        SystemExit: If an input file is missing, an option is out of range, or
            no examples are left after mixing.
    """
    args = _build_arg_parser().parse_args()

    # Validate inputs
    if not args.solution_data.exists():
        raise SystemExit(f"Error: Solution data not found at {args.solution_data}")
    if not args.question_data.exists():
        raise SystemExit(f"Error: Question data not found at {args.question_data}")
    if not (0 < args.mix_ratio < 1):
        raise SystemExit("Error: --mix-ratio must be between 0 and 1")
    if not (0 < args.val_split < 1):
        raise SystemExit("Error: --val-split must be between 0 and 1")
    if args.max_total is not None and args.max_total <= 0:
        raise SystemExit("Error: --max-total must be a positive integer")

    # Set random seed
    random.seed(args.seed)

    print("=" * 60)
    print("Dual-Task Dataset Creation")
    print("=" * 60)

    # Load data
    print("\n1. Loading data...")
    print(f" Solution data: {args.solution_data}")
    solution_records = load_jsonl(args.solution_data)
    print(f" Loaded {len(solution_records)} solution examples")
    print(f" Question data: {args.question_data}")
    question_records = load_jsonl(args.question_data)
    print(f" Loaded {len(question_records)} question-generation examples")

    # Add task prefixes
    print("\n2. Adding task prefixes...")
    print(" Adding 'Solve Problem' prefix to solution examples...")
    solution_records = [add_solve_prefix(r) for r in solution_records]
    print(" Verifying 'Generate Question' prefix on question examples...")
    question_records = [verify_question_prefix(r) for r in question_records]

    # Mix datasets
    print(f"\n3. Mixing datasets (ratio: {args.mix_ratio:.0%} solutions, {1-args.mix_ratio:.0%} questions)...")
    mixed_records = sample_with_ratio(
        solution_records=solution_records,
        question_records=question_records,
        mix_ratio=args.mix_ratio,
        target_total=args.max_total,
    )
    if not mixed_records:
        # Nothing to split or write; stop here instead of failing later on a
        # division by zero in the summary.
        raise SystemExit("Error: No examples were selected; check the input files")

    # Shuffle
    print(f"\n4. Shuffling {len(mixed_records)} total examples...")
    random.shuffle(mixed_records)

    # Split train/val
    n_total = len(mixed_records)
    n_val = int(n_total * args.val_split)
    n_train = n_total - n_val
    train_records = mixed_records[:n_train]
    val_records = mixed_records[n_train:]
    print("\n5. Splitting data:")
    print(f" Training: {len(train_records)} examples ({_format_share(len(train_records), n_total)})")
    print(f" Validation: {len(val_records)} examples ({_format_share(len(val_records), n_total)})")
    if not val_records:
        # int() truncation yields zero validation examples for small datasets.
        print(" Warning: validation split is empty; use more data or a larger --val-split")

    # Verify split composition
    _print_composition("\n Train composition:", train_records)
    _print_composition(" Val composition:", val_records)

    # Write outputs
    print("\n6. Writing output files...")
    print(f" Training data: {args.output_train}")
    write_jsonl(train_records, args.output_train)
    print(f" Validation data: {args.output_val}")
    write_jsonl(val_records, args.output_val)

    print("\n" + "=" * 60)
    print("Dual-task dataset creation complete!")
    print("=" * 60)
    print("\nOutput files:")
    print(f" Train: {args.output_train} ({len(train_records)} examples)")
    print(f" Val: {args.output_val} ({len(val_records)} examples)")
    print("\nNext step: Train dual-task model using these files")


if __name__ == "__main__":
    main()