#!/usr/bin/env python3
"""
Convert OpenAI GSM8K to SFT JSONL aligned with the MathAgent solver format:

    Step 1: ...
    Step 2: ...
    ...
    Final Answer: <answer>

Each record uses a chat messages list for Qwen-style fine-tuning.

Usage
-----
  # From Hugging Face (default; same data as in test.ipynb)
  python scripts/convert_gsm8k_to_sft.py \\
      --output data/sft/gsm8k_sft.jsonl \\
      --splits train test

  # From a saved JSONL with columns "question" and "answer" (GSM8K schema)
  python scripts/convert_gsm8k_to_sft.py \\
      --source jsonl \\
      --input path/to/file.jsonl \\
      --output data/sft/gsm8k_sft.jsonl

Requires: pip install datasets (and datasets will pull pyarrow as needed)
"""

from __future__ import annotations

import argparse
import json
import re
from pathlib import Path
from typing import Any, Iterator

# Keep in sync with src.agent.math_agent.SOLVER_SYSTEM_PROMPT
SOLVER_SYSTEM_PROMPT = (
    "You are a step-by-step math solver. "
    "Solve the given problem one step at a time. "
    "Each step must be on its own line, starting with 'Step N:'. "
    "End with a line starting with 'Final Answer:'. "
    "Write every mathematical expression in Python/SymPy syntax "
    "so it can be verified programmatically."
)

USER_WRAPPER = (
    "Solve the following problem. Show your reasoning as numbered steps, "
    "then give the final numeric answer on the last line.\n\nProblem:\n{question}"
)

# Separator between the worked solution and the final answer ("#### 42").
_FINAL_MARKER_RE = re.compile(r"\s*####\s*")
# Thousands separators and stray whitespace inside the final answer.
_FINAL_NOISE_RE = re.compile(r"[,\s]+")
# Signed number with an optional fractional part, so "12.5" is not cut to "12".
_FINAL_NUMBER_RE = re.compile(r"-?\d+(?:\.\d+)?")


def parse_gsm8k_answer(raw_answer: str) -> tuple[str, str]:
    """Split a GSM8K ``answer`` field into reasoning text and final answer.

    GSM8K ends solutions with a line like ``#### 42``. Everything before the
    first ``####`` is the reasoning; what follows is the final answer.

    Args:
        raw_answer: The raw ``answer`` string of one example.

    Returns:
        ``(reasoning, final)``. ``final`` has commas and whitespace removed
        and is reduced to the first signed number found in it (integer or
        decimal). If no number is found, the comma/whitespace-free text is
        returned as is; if there is no ``####`` marker, ``final`` is ``""``.
    """
    parts = _FINAL_MARKER_RE.split(raw_answer.strip(), maxsplit=1)
    reasoning = parts[0].strip()
    final = parts[1].strip() if len(parts) > 1 else ""

    # Normalize "1,234" / "1 234" before extracting the number.
    final = _FINAL_NOISE_RE.sub("", final)
    number = _FINAL_NUMBER_RE.search(final)
    return reasoning, (number.group(0) if number else final)


def reasoning_to_step_lines(reasoning: str) -> list[str]:
    """Return the non-empty, stripped lines of *reasoning*.

    Each returned line later becomes one ``Step N:`` entry. Reasoning that
    contains no newline is returned as a single line; blank reasoning yields
    an empty list.
    """
    stripped = (raw.strip() for raw in reasoning.splitlines())
    return [line for line in stripped if line]


def build_assistant_content(reasoning: str, final_answer: str) -> str:
    """Render the assistant turn in the solver's ``Step N:`` format.

    Args:
        reasoning: Solution text; each non-empty line becomes one step.
        final_answer: Final answer string; when empty, no ``Final Answer:``
            line is emitted.

    Returns:
        The numbered steps joined by newlines, followed by a
        ``Final Answer:`` line when *final_answer* is non-empty. May be an
        empty string when both inputs are empty.
    """
    # Prefer SymPy-friendly numerics: ** not ^.
    out_lines = [
        f"Step {i}: {line.replace('^', '**')}"
        for i, line in enumerate(reasoning_to_step_lines(reasoning), start=1)
    ]
    if final_answer:
        out_lines.append(f"Final Answer: {final_answer}")
    return "\n".join(out_lines)


def row_to_record(
    question: str,
    answer: str,
    example_id: str,
    split: str,
) -> dict[str, Any] | None:
    """Convert one GSM8K row into an SFT record.

    Args:
        question: Problem statement; surrounding whitespace is stripped.
        answer: Raw GSM8K solution including the ``#### <answer>`` marker.
        example_id: Identifier appended to ``gsm8k_`` to form the record id.
        split: Split name stored verbatim in the record.

    Returns:
        A dict with ``id``, ``skill_id``, ``source``, ``split``, a
        three-message chat list (``messages``) and a flattened ``text``
        field, or ``None`` when the row is skipped: the answer has no
        ``####`` marker, or the rendered assistant content is empty.
    """
    reasoning, final_answer = parse_gsm8k_answer(answer)
    # Rows without the "####" marker do not follow the GSM8K schema.
    if not final_answer and "####" not in answer:
        return None

    assistant = build_assistant_content(reasoning, final_answer)
    if not assistant.strip():
        return None

    user_content = USER_WRAPPER.format(question=question.strip())
    return {
        "id": f"gsm8k_{example_id}",
        "skill_id": "gsm8k_grade_school",
        "source": "openai/gsm8k",
        "split": split,
        "messages": [
            {"role": "system", "content": SOLVER_SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": assistant},
        ],
        # Convenience for non-chat trainers
        "text": f"<|system|>\n{SOLVER_SYSTEM_PROMPT}\n<|user|>\n{user_content}\n<|assistant|>\n{assistant}",
    }


def iter_hf_rows(dataset_name: str, config: str, splits: list[str]) -> Iterator[tuple[str, str, dict]]:
    """Yield ``(example_id, split, row)`` for each row of the requested splits.

    ``example_id`` is ``"<split>_<index>"``.

    Raises:
        KeyError: If a requested split is not present in the loaded dataset.
    """
    # Imported lazily so the JSONL path works without `datasets` installed.
    from datasets import load_dataset

    ds = load_dataset(dataset_name, config)
    for split in splits:
        if split not in ds:
            raise KeyError(f"Split {split!r} not in dataset. Available: {list(ds.keys())}")
        for i, row in enumerate(ds[split]):
            yield f"{split}_{i}", split, row


def iter_jsonl_rows(path: Path) -> Iterator[tuple[str, str, dict]]:
    """Yield ``(example_id, "jsonl", row)`` for each non-blank line of *path*.

    ``example_id`` is the zero-based physical line index as a string, so
    blank lines still advance the numbering. The file is closed once the
    generator is exhausted or closed.
    """
    with path.open(encoding="utf-8") as in_f:
        for i, line in enumerate(in_f):
            line = line.strip()
            if not line:
                continue
            yield str(i), "jsonl", json.loads(line)


def main(argv: list[str] | None = None) -> None:
    """Run the converter from the command line.

    Args:
        argv: Argument list to parse; ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Raises:
        SystemExit: If ``--source jsonl`` is given without ``--input``.
    """
    p = argparse.ArgumentParser(description="Convert GSM8K to SFT JSONL (chat messages).")
    p.add_argument(
        "--source",
        choices=("hf", "jsonl"),
        default="hf",
        help="Load from Hugging Face dataset or a local JSONL file.",
    )
    p.add_argument("--dataset", default="openai/gsm8k", help="HF dataset id when --source hf.")
    p.add_argument("--config", default="main", help="HF config name when --source hf.")
    p.add_argument("--splits", nargs="+", default=["train", "test"], help="HF splits to export.")
    p.add_argument("--input", type=Path, help="Local JSONL path when --source jsonl.")
    p.add_argument(
        "--output",
        type=Path,
        default=Path("data/sft/gsm8k_sft.jsonl"),
        help="Output JSONL path.",
    )
    args = p.parse_args(argv)

    if args.source == "jsonl" and not args.input:
        raise SystemExit("--input is required when --source jsonl")

    args.output.parent.mkdir(parents=True, exist_ok=True)

    # Both sources are lazy generators: nothing is read until iteration starts.
    if args.source == "hf":
        rows = iter_hf_rows(args.dataset, args.config, args.splits)
    else:
        rows = iter_jsonl_rows(args.input)

    n_ok, n_skip = 0, 0
    with args.output.open("w", encoding="utf-8") as out_f:
        for example_id, split, row in rows:
            rec = row_to_record(
                row.get("question", ""),
                row.get("answer", ""),
                example_id,
                split,
            )
            if rec is None:
                n_skip += 1
                continue
            out_f.write(json.dumps(rec, ensure_ascii=False) + "\n")
            n_ok += 1

    print(f"Wrote {n_ok} examples to {args.output} ({n_skip} skipped).")


if __name__ == "__main__":
    main()