AxiomForgeAI / scripts /gsm8k_sft_pipeline.py
jampuramprem's picture
Initial Space deployment
ec4ae03
#!/usr/bin/env python3
"""
End-to-end GSM8K pipeline: prepare JSONL → QLoRA SFT → save adapter → inference.
The trained model follows ``Step N:`` / ``Final Answer:`` formatting with SymPy-friendly
expressions (see ``src.agent.math_agent.SOLVER_SYSTEM_PROMPT``).
Examples
--------
# 1) Only build training JSONL from Hugging Face GSM8K
python scripts/gsm8k_sft_pipeline.py prepare --output data/sft/gsm8k_sft.jsonl
# 2) Fine-tune (GPU strongly recommended)
python scripts/gsm8k_sft_pipeline.py train \\
--data data/sft/gsm8k_sft.jsonl \\
--output-dir checkpoints/gsm8k_sft
# 3) Run inference with saved adapter
python scripts/gsm8k_sft_pipeline.py infer \\
--adapter checkpoints/gsm8k_sft \\
--problem \"Janet has 16 eggs. She eats 3. How many are left?\"
# Full chain
python scripts/gsm8k_sft_pipeline.py all --output-dir checkpoints/gsm8k_sft
Dependencies: torch, transformers, peft, datasets, accelerate, bitsandbytes, trl, sympy
Tip: if downloads fail with XET / "Background writer channel closed", export ``HF_HUB_DISABLE_XET=1``
before running (this script sets it by default unless already set).
"""
from __future__ import annotations
import os
# Default to classic HTTP downloads: hf-xet can error or segfault on
# interrupted/large shards. A value already exported by the user is kept.
# This runs before any Hugging Face library is imported (they are imported
# lazily inside the sub-commands), presumably because the setting is read at import time.
os.environ.setdefault("HF_HUB_DISABLE_XET", "1")
import argparse
import json
import math
import subprocess
import sys
from pathlib import Path
# Project root: this file lives in <root>/scripts/, so the root is the parent of
# that directory. The prepare step locates the converter as
# ROOT / "scripts" / "convert_gsm8k_to_sft.py" and uses ROOT as its working directory.
ROOT = Path(__file__).resolve().parents[1]
def cmd_prepare(args: argparse.Namespace) -> None:
    """Build the SFT JSONL by running ``scripts/convert_gsm8k_to_sft.py``.

    ``args`` must provide ``output``, ``source`` (``"hf"`` or ``"jsonl"``),
    ``input`` (required when ``source == "jsonl"``), ``splits`` and
    ``strip_scratchpads``.

    The converter runs with ``cwd=ROOT``. Relative ``output``/``input`` paths
    are therefore resolved against the *caller's* working directory first.
    This keeps the converter, the optional scratchpad-stripping pass and a
    later ``train`` step (which opens the data path directly) pointing at the
    same file.

    Raises:
        SystemExit: if ``source == "jsonl"`` but no ``input`` path is given.
        subprocess.CalledProcessError: if the converter exits non-zero.
    """
    if args.source == "jsonl" and not args.input:
        # Without this guard the literal string "None" is handed to the converter.
        raise SystemExit("--input is required when the prepare source is 'jsonl'")

    output = Path(args.output).resolve()
    cmd = [
        sys.executable,
        str(ROOT / "scripts" / "convert_gsm8k_to_sft.py"),
        "--output",
        str(output),
        "--splits",
        *args.splits,
    ]
    if args.source == "jsonl":
        cmd.extend(["--source", "jsonl", "--input", str(Path(args.input).resolve())])
    print("Running:", " ".join(cmd))
    subprocess.check_call(cmd, cwd=str(ROOT))
    if args.strip_scratchpads:
        _rewrite_jsonl_strip_scratchpads(output)
def _rewrite_jsonl_strip_scratchpads(jsonl_path: Path, strip_fn=None) -> None:
    """Remove GSM8K ``<<...>>`` calculator traces from a JSONL file, in place.

    Every assistant message's ``content`` is passed through *strip_fn*.
    If a record also carries a flattened ``text`` field, that field is rebuilt
    from the updated system/user/assistant messages so both views stay in sync.
    Blank lines are skipped.

    Records are written to ``<name>.jsonl.tmp``, which replaces *jsonl_path*
    only after every record was written. On any failure the temp file is
    deleted and the original file is left untouched.

    Args:
        jsonl_path: JSONL file produced by the prepare step.
        strip_fn: ``str -> str`` cleaner; defaults to
            ``src.sft.solution_format.strip_gsm8k_scratchpads``.

    Raises:
        ValueError: if a record with ``text`` lacks a system, user or
            assistant message.
    """
    if strip_fn is None:
        from src.sft.solution_format import strip_gsm8k_scratchpads as strip_fn

    def _first_content(messages, role, lineno):
        # Content of the first message with *role*; the flattened text needs all three roles.
        for msg in messages:
            if msg.get("role") == role:
                return msg["content"]
        raise ValueError(f"{jsonl_path}:{lineno}: record has 'text' but no {role!r} message")

    tmp = jsonl_path.with_suffix(".jsonl.tmp")
    n = 0
    try:
        with jsonl_path.open(encoding="utf-8") as fin, tmp.open("w", encoding="utf-8") as fout:
            for lineno, line in enumerate(fin, start=1):
                if not line.strip():
                    # Tolerate blank/trailing lines instead of failing in json.loads.
                    continue
                record = json.loads(line)
                messages = record.get("messages", [])
                for msg in messages:
                    if msg.get("role") == "assistant":
                        msg["content"] = strip_fn(msg["content"])
                if "text" in record:
                    sys_p = _first_content(messages, "system", lineno)
                    usr = _first_content(messages, "user", lineno)
                    asst = _first_content(messages, "assistant", lineno)
                    record["text"] = (
                        f"<|system|>\n{sys_p}\n<|user|>\n{usr}\n<|assistant|>\n{asst}"
                    )
                fout.write(json.dumps(record, ensure_ascii=False) + "\n")
                n += 1
    except BaseException:
        # Never leave a half-written temp file next to the (untouched) original.
        tmp.unlink(missing_ok=True)
        raise
    tmp.replace(jsonl_path)
    print(f"Stripped <<>> scratchpads in {n} records → {jsonl_path}")
def _warmup_steps_from_ratio(
    num_examples: int,
    per_device_train_batch_size: int,
    gradient_accumulation_steps: int,
    num_train_epochs: float,
    warmup_ratio: float,
) -> int:
    """Convert a legacy ``warmup_ratio`` into an absolute ``warmup_steps`` count.

    Approximates the HF Trainer optimizer-step count for one device (multi-GPU
    sharding and any dataset filtering done by the trainer are not accounted for):

    * batches per epoch = ``ceil(num_examples / batch_size)`` (at least 1)
    * updates per epoch = ``batches // gradient_accumulation_steps`` (at least 1)
    * total steps = ``ceil(num_train_epochs * updates_per_epoch)`` (at least 1)

    The warmup is ``ceil(total * warmup_ratio)``, the same rounding
    ``TrainingArguments.get_warmup_steps`` applies to a ratio, capped at the total.
    Non-positive batch-size / accumulation values are treated as 1 so the
    helper never divides by zero.

    Returns:
        0 when ``warmup_ratio <= 0``; otherwise a value in ``[1, total_steps]``.
    """
    if warmup_ratio <= 0:
        return 0
    batch_size = max(1, per_device_train_batch_size)
    accum = max(1, gradient_accumulation_steps)
    # Integer ceiling division: the last partial batch still counts as a batch.
    num_batches = max(1, (num_examples + batch_size - 1) // batch_size)
    num_update_steps_per_epoch = max(1, num_batches // accum)
    total_optimizer_steps = max(1, math.ceil(num_train_epochs * num_update_steps_per_epoch))
    return min(total_optimizer_steps, math.ceil(total_optimizer_steps * warmup_ratio))
def cmd_train(args: argparse.Namespace) -> None:
    """QLoRA-fine-tune ``args.model`` on the chat-formatted JSONL at ``args.data``.

    The base model is loaded in 4-bit NF4 with double quantization and compute
    dtype ``args.bnb_compute_dtype``. It is wrapped with a LoRA adapter on
    ``args.target_modules`` and trained with TRL's ``SFTTrainer``. Each
    record's ``messages`` list is rendered through the tokenizer's chat template.

    Writes the adapter, the tokenizer and ``pipeline_meta.json`` (base model
    plus the main training hyper-parameters) to ``args.output_dir``.

    Raises:
        SystemExit: on missing training dependencies, a missing data file, an
            unknown ``--bnb-compute-dtype`` or an empty ``--target-modules`` list.
    """
    try:
        import torch
        from datasets import load_dataset
        from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
        from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
        from trl import SFTConfig, SFTTrainer
    except ImportError as e:
        raise SystemExit(
            "Missing dependency for training. Install:\n"
            " pip install torch transformers peft datasets accelerate bitsandbytes trl sympy\n"
            f"Original error: {e}"
        ) from e

    data_path = Path(args.data)
    if not data_path.is_file():
        raise SystemExit(f"Data file not found: {data_path}")

    # A typo would otherwise surface as an opaque AttributeError (or a non-dtype attribute).
    compute_dtype = getattr(torch, args.bnb_compute_dtype, None)
    if not isinstance(compute_dtype, torch.dtype):
        raise SystemExit(
            "--bnb-compute-dtype must name a torch dtype (e.g. bfloat16, float16), "
            f"got {args.bnb_compute_dtype!r}"
        )

    # Accept "q_proj, v_proj" and stray commas; PEFT matches module names exactly.
    target_modules = [name.strip() for name in args.target_modules.split(",") if name.strip()]
    if not target_modules:
        raise SystemExit("--target-modules must list at least one module name")

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    print(f"Loading model {args.model} …")
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        dtype=compute_dtype,
    )
    model = prepare_model_for_kbit_training(model)
    lora_config = LoraConfig(
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=target_modules,
    )
    model = get_peft_model(model, lora_config)
    # The KV cache is unused in training and incompatible with gradient checkpointing.
    model.config.use_cache = False
    model.print_trainable_parameters()

    ds = load_dataset("json", data_files=str(data_path), split="train")
    if args.max_samples and args.max_samples > 0:
        ds = ds.select(range(min(args.max_samples, len(ds))))

    def formatting_func(example):
        # Render the record's chat messages with the model's own chat template.
        return tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
            add_generation_prompt=False,
        )

    if args.warmup_steps is not None:
        warmup_steps = max(0, args.warmup_steps)
    else:
        warmup_steps = _warmup_steps_from_ratio(
            len(ds),
            args.batch_size,
            args.grad_accum,
            args.epochs,
            args.warmup_ratio,
        )

    # bf16 is on by default. TrainingArguments refuses bf16=True when torch reports
    # no bf16 support, so only enable it where supported and let --fp16 apply otherwise.
    has_cuda = torch.cuda.is_available()
    bf16_supported = has_cuda and torch.cuda.is_bf16_supported()
    if args.bf16 and has_cuda and not bf16_supported:
        print("Note: this GPU does not support bf16; bf16 disabled (use --fp16 for mixed precision).")
    use_bf16 = bool(args.bf16 and bf16_supported)
    use_fp16 = bool(args.fp16 and has_cuda and not use_bf16)
    if args.fp16 and use_bf16:
        print("Note: --fp16 ignored because bf16 is enabled (pass --no-bf16 to train in fp16).")

    sft_args = SFTConfig(
        output_dir=str(out_dir),
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.learning_rate,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        save_total_limit=3,
        bf16=use_bf16,
        fp16=use_fp16,
        max_length=args.max_seq_length,
        warmup_steps=warmup_steps,
        lr_scheduler_type="cosine",
        report_to="none",
        gradient_checkpointing=True,
    )
    trainer = SFTTrainer(
        model=model,
        args=sft_args,
        train_dataset=ds,
        processing_class=tokenizer,
        formatting_func=formatting_func,
    )
    trainer.train()
    trainer.save_model(str(out_dir))
    tokenizer.save_pretrained(str(out_dir))

    # "base_model" is what cmd_infer reads back; the rest is for reproducibility.
    meta = {
        "base_model": args.model,
        "data": str(data_path),
        "lora_rank": args.lora_rank,
        "lora_alpha": args.lora_alpha,
        "lora_dropout": args.lora_dropout,
        "target_modules": target_modules,
        "epochs": args.epochs,
        "learning_rate": args.learning_rate,
        "max_seq_length": args.max_seq_length,
    }
    with (out_dir / "pipeline_meta.json").open("w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)
    print(f"Saved adapter and tokenizer to {out_dir}")
def cmd_infer(args: argparse.Namespace) -> None:
    """Generate a solution for ``args.problem`` with a trained LoRA adapter.

    The base model name comes from ``pipeline_meta.json`` in the adapter
    directory when present, falling back to ``args.base_model``. It is loaded
    in 4-bit NF4 with the same dtype handling as training, then combined with
    the adapter. Prints the generated text and the fields of the result of
    ``validate_sympy_solution_format``.

    Raises:
        SystemExit: if ``--bnb-compute-dtype`` does not name a torch dtype.
    """
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    from src.agent.math_agent import SOLVER_SYSTEM_PROMPT
    # Imported up front so a broken project import fails before the slow model load.
    from src.sft.solution_format import validate_sympy_solution_format

    adapter = Path(args.adapter)
    base_model = args.base_model
    meta_path = adapter / "pipeline_meta.json"
    if meta_path.is_file():
        meta = json.loads(meta_path.read_text(encoding="utf-8"))
        base_model = meta.get("base_model", base_model)

    compute_dtype = getattr(torch, args.bnb_compute_dtype, None)
    if not isinstance(compute_dtype, torch.dtype):
        raise SystemExit(
            "--bnb-compute-dtype must name a torch dtype (e.g. bfloat16, float16), "
            f"got {args.bnb_compute_dtype!r}"
        )
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )

    # A local adapter directory without tokenizer files falls back to the base model's tokenizer.
    tokenizer_src = adapter
    if adapter.is_dir() and not (adapter / "tokenizer_config.json").is_file():
        tokenizer_src = base_model
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_src, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading base {base_model} + adapter {adapter} …")
    base = AutoModelForCausalLM.from_pretrained(
        base_model,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        # Match the dtype the training step loads the base model with.
        dtype=compute_dtype,
    )
    model = PeftModel.from_pretrained(base, str(adapter))
    model.eval()

    user_content = (
        "Solve the following problem. Show your reasoning as numbered steps, "
        "then give the final numeric answer on the last line.\n\n"
        f"Problem:\n{args.problem.strip()}"
    )
    messages = [
        {"role": "system", "content": SOLVER_SYSTEM_PROMPT},
        {"role": "user", "content": user_content},
    ]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    gen_kwargs = {
        "max_new_tokens": args.max_new_tokens,
        "pad_token_id": tokenizer.pad_token_id,
    }
    if args.greedy:
        # temperature/top_p are ignored by greedy decoding and only trigger warnings.
        gen_kwargs["do_sample"] = False
    else:
        gen_kwargs.update(do_sample=True, temperature=args.temperature, top_p=args.top_p)
    with torch.no_grad():
        out = model.generate(**inputs, **gen_kwargs)

    # Keep only the newly generated tokens (drop the echoed prompt).
    gen_ids = out[0, inputs["input_ids"].shape[1] :]
    text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
    print("\n--- Generated ---\n")
    print(text)
    print("\n--- Format check ---")
    r = validate_sympy_solution_format(text)
    print(json.dumps(r.__dict__, indent=2))
def cmd_all(args: argparse.Namespace) -> None:
    """Run ``prepare`` then ``train``, followed by ``infer`` when ``--problem`` is non-empty.

    The JSONL path defaults to ``ROOT/data/sft/gsm8k_sft.jsonl`` unless
    ``--data`` is given. Inference uses the freshly trained adapter in
    ``--output-dir`` with ``--model`` as the fallback base model.
    """
    if args.data:
        out_jsonl = Path(args.data)
    else:
        out_jsonl = ROOT / "data" / "sft" / "gsm8k_sft.jsonl"

    cmd_prepare(
        argparse.Namespace(
            output=out_jsonl,
            source=args.prepare_source,
            input=args.input,
            splits=args.splits,
            strip_scratchpads=args.strip_scratchpads,
        )
    )

    # Options forwarded unchanged from the ``all`` sub-command to ``train``.
    train_fields = (
        "output_dir",
        "model",
        "epochs",
        "batch_size",
        "grad_accum",
        "learning_rate",
        "max_samples",
        "lora_rank",
        "lora_alpha",
        "lora_dropout",
        "target_modules",
        "max_seq_length",
        "save_steps",
        "logging_steps",
        "warmup_ratio",
        "warmup_steps",
        "bf16",
        "fp16",
        "bnb_compute_dtype",
    )
    forwarded = {name: getattr(args, name) for name in train_fields}
    cmd_train(argparse.Namespace(data=str(out_jsonl), **forwarded))

    if not args.problem:
        return

    # Options forwarded unchanged from the ``all`` sub-command to ``infer``.
    infer_fields = (
        "problem",
        "max_new_tokens",
        "temperature",
        "top_p",
        "greedy",
        "bnb_compute_dtype",
    )
    cmd_infer(
        argparse.Namespace(
            adapter=Path(args.output_dir),
            base_model=args.model,
            **{name: getattr(args, name) for name in infer_fields},
        )
    )
_DEFAULT_BASE_MODEL = "Qwen/Qwen2.5-Math-1.5B-Instruct"


def _add_conversion_args(parser: argparse.ArgumentParser) -> None:
    """Register the GSM8K conversion options shared by ``prepare`` and ``all``."""
    parser.add_argument(
        "--splits",
        nargs="+",
        default=["train", "test"],
        help=(
            "Splits passed to the converter (default: train test). Training on "
            "'test' contaminates any later GSM8K test-set evaluation."
        ),
    )
    parser.add_argument(
        "--strip-scratchpads",
        action="store_true",
        help="Remove GSM8K <<...>> traces from assistant text after conversion.",
    )


def _add_train_args(parser: argparse.ArgumentParser) -> None:
    """Register the model / LoRA / optimisation options shared by ``train`` and ``all``."""
    parser.add_argument("--model", type=str, default=_DEFAULT_BASE_MODEL)
    parser.add_argument("--epochs", type=float, default=1.0)
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--grad-accum", type=int, default=8)
    parser.add_argument("--learning-rate", type=float, default=2e-4)
    parser.add_argument("--max-samples", type=int, default=0, help="0 = use full dataset")
    parser.add_argument("--lora-rank", type=int, default=16)
    parser.add_argument("--lora-alpha", type=int, default=32)
    parser.add_argument("--lora-dropout", type=float, default=0.05)
    parser.add_argument(
        "--target-modules",
        type=str,
        default="q_proj,v_proj,o_proj,gate_proj",
        help="Comma-separated module names that receive LoRA adapters.",
    )
    parser.add_argument("--max-seq-length", type=int, default=2048)
    parser.add_argument("--save-steps", type=int, default=200)
    parser.add_argument("--logging-steps", type=int, default=10)
    parser.add_argument(
        "--warmup-ratio",
        type=float,
        default=0.03,
        help="Used only if --warmup-steps is not set; converted to warmup_steps.",
    )
    parser.add_argument(
        "--warmup-steps",
        type=int,
        default=None,
        help="LR warmup steps; if set, overrides --warmup-ratio.",
    )
    # bf16 defaults to on, so --bf16 is effectively a no-op; --no-bf16 turns it off.
    parser.add_argument(
        "--bf16",
        action="store_true",
        default=True,
        help="bf16 mixed precision on CUDA (default: on).",
    )
    parser.add_argument(
        "--no-bf16", dest="bf16", action="store_false", help="Disable bf16 mixed precision."
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="fp16 mixed precision on CUDA; not used while bf16 is enabled.",
    )
    parser.add_argument(
        "--bnb-compute-dtype",
        type=str,
        default="bfloat16",
        help="torch dtype name for 4-bit compute (e.g. bfloat16, float16).",
    )


def _add_generation_args(parser: argparse.ArgumentParser) -> None:
    """Register the decoding options shared by ``infer`` and ``all``."""
    parser.add_argument("--max-new-tokens", type=int, default=1024)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top-p", type=float, default=0.95)
    parser.add_argument("--greedy", action="store_true", help="Greedy decoding instead of sampling.")


def build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser; each sub-command stores its handler in ``func``.

    Options common to several sub-commands are registered by the ``_add_*``
    helpers so their defaults cannot drift apart.
    """
    p = argparse.ArgumentParser(description="GSM8K SFT pipeline (prepare / train / infer / all)")
    sub = p.add_subparsers(dest="command", required=True)

    pr = sub.add_parser("prepare", help="Run convert_gsm8k_to_sft.py")
    pr.add_argument("--output", type=str, default=str(ROOT / "data" / "sft" / "gsm8k_sft.jsonl"))
    pr.add_argument("--source", choices=("hf", "jsonl"), default="hf")
    pr.add_argument("--input", type=str, help="JSONL path for --source jsonl")
    _add_conversion_args(pr)
    pr.set_defaults(func=cmd_prepare)

    tr = sub.add_parser("train", help="QLoRA SFT on JSONL with messages field")
    tr.add_argument("--data", type=str, required=True, help="JSONL from prepare step")
    tr.add_argument("--output-dir", type=str, required=True)
    _add_train_args(tr)
    tr.set_defaults(func=cmd_train)

    inf = sub.add_parser("infer", help="Generate with saved adapter")
    inf.add_argument("--adapter", type=str, required=True, help="Directory from train step")
    inf.add_argument(
        "--base-model",
        type=str,
        default=_DEFAULT_BASE_MODEL,
        help="Must match base used in training if no pipeline_meta.json",
    )
    inf.add_argument("--problem", type=str, required=True)
    _add_generation_args(inf)
    inf.add_argument("--bnb-compute-dtype", type=str, default="bfloat16")
    inf.set_defaults(func=cmd_infer)

    al = sub.add_parser("all", help="prepare + train [+ infer if --problem]")
    al.add_argument(
        "--data",
        type=str,
        default=None,
        help="Output JSONL path (default data/sft/gsm8k_sft.jsonl)",
    )
    al.add_argument("--prepare-source", choices=("hf", "jsonl"), default="hf")
    al.add_argument("--input", type=str, help="JSONL path for --prepare-source jsonl")
    _add_conversion_args(al)
    al.add_argument("--output-dir", type=str, required=True)
    _add_train_args(al)
    al.add_argument("--problem", type=str, default="", help="If set, run infer after train")
    _add_generation_args(al)
    al.set_defaults(func=cmd_all)
    return p
def main() -> None:
    """CLI entry point: parse ``sys.argv`` and run the selected sub-command."""
    cli_args = build_parser().parse_args()
    # Sub-commands import ``src.*`` lazily, so the project root must be importable first.
    project_root = str(ROOT)
    if project_root not in sys.path:
        sys.path.insert(0, project_root)
    handler = cli_args.func
    handler(cli_args)


if __name__ == "__main__":
    main()