#!/usr/bin/env python3
"""
End-to-end GSM8K pipeline: prepare JSONL → QLoRA SFT → save adapter → inference.

The trained model follows ``Step N:`` / ``Final Answer:`` formatting with SymPy-friendly
expressions (see ``src.agent.math_agent.SOLVER_SYSTEM_PROMPT``).

Examples
--------
# 1) Only build the training JSONL from Hugging Face GSM8K
python scripts/gsm8k_sft_pipeline.py prepare --output data/sft/gsm8k_sft.jsonl

# 2) Fine-tune (a CUDA GPU is strongly recommended)
python scripts/gsm8k_sft_pipeline.py train \\
    --data data/sft/gsm8k_sft.jsonl \\
    --output-dir checkpoints/gsm8k_sft

# 3) Run inference with the saved adapter
python scripts/gsm8k_sft_pipeline.py infer \\
    --adapter checkpoints/gsm8k_sft \\
    --problem "Janet has 16 eggs. She eats 3. How many are left?"
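# Expected completion shape (illustrative only, not verbatim model output):
#   Step 1: 16 - 3 = 13
#   Final Answer: 13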

# Full chain
python scripts/gsm8k_sft_pipeline.py all --output-dir checkpoints/gsm8k_sft

Dependencies: torch, transformers, peft, datasets, accelerate, bitsandbytes, trl, sympy

Tip: if downloads fail with XET / "Background writer channel closed", export
``HF_HUB_DISABLE_XET=1`` before running (this script sets it by default unless
already set).
"""
from __future__ import annotations

import os

# hf-xet can error or segfault on interrupted/large shards; classic HTTP download is more robust.
if "HF_HUB_DISABLE_XET" not in os.environ:
    os.environ["HF_HUB_DISABLE_XET"] = "1"

import argparse
import json
import math
import subprocess
import sys
from pathlib import Path

# Project root (…/Maths_LLM)
ROOT = Path(__file__).resolve().parents[1]
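
# Each record emitted by the prepare step is assumed to look roughly like this
# (shape inferred from how the rest of this script consumes it, not a verified spec):
#   {"messages": [{"role": "system", "content": "..."},
#                 {"role": "user", "content": "Problem: ..."},
#                 {"role": "assistant", "content": "Step 1: ...\nFinal Answer: 42"}],
#    "text": "<|system|>\n...\n<|user|>\n...\n<|assistant|>\n..."}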


def cmd_prepare(args: argparse.Namespace) -> None:
    cmd = [
        sys.executable,
        str(ROOT / "scripts" / "convert_gsm8k_to_sft.py"),
        "--output",
        str(Path(args.output)),
        "--splits",
        *args.splits,
    ]
    if args.source == "jsonl":
        cmd.extend(["--source", "jsonl", "--input", str(args.input)])
    print("Running:", " ".join(cmd))
    subprocess.check_call(cmd, cwd=str(ROOT))
    if args.strip_scratchpads:
        _rewrite_jsonl_strip_scratchpads(Path(args.output))


def _rewrite_jsonl_strip_scratchpads(jsonl_path: Path) -> None:
    from src.sft.solution_format import strip_gsm8k_scratchpads

    tmp = jsonl_path.with_suffix(".jsonl.tmp")
    n = 0
    with jsonl_path.open(encoding="utf-8") as fin, tmp.open("w", encoding="utf-8") as fout:
        for line in fin:
            o = json.loads(line)
            for m in o.get("messages", []):
                if m.get("role") == "assistant":
                    m["content"] = strip_gsm8k_scratchpads(m["content"])
            if "text" in o:
                # Keep the flat "text" rendering in sync with the cleaned messages.
                sys_p = next(x["content"] for x in o["messages"] if x["role"] == "system")
                usr = next(x["content"] for x in o["messages"] if x["role"] == "user")
                asst = next(x["content"] for x in o["messages"] if x["role"] == "assistant")
                o["text"] = f"<|system|>\n{sys_p}\n<|user|>\n{usr}\n<|assistant|>\n{asst}"
            fout.write(json.dumps(o, ensure_ascii=False) + "\n")
            n += 1
    tmp.replace(jsonl_path)
    print(f"Stripped <<...>> scratchpads in {n} records → {jsonl_path}")
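
# GSM8K solutions embed calculator annotations, e.g. "16 - 3 = <<16-3=13>>13";
# stripping is assumed to reduce this to plain "16 - 3 = 13".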


def _warmup_steps_from_ratio(
    num_examples: int,
    per_device_train_batch_size: int,
    gradient_accumulation_steps: int,
    num_train_epochs: float,
    warmup_ratio: float,
) -> int:
    """Approximate HF Trainer optimizer steps; used to map legacy warmup_ratio → warmup_steps."""
    if warmup_ratio <= 0:
        return 0
    num_batches = max(
        1,
        (num_examples + per_device_train_batch_size - 1) // per_device_train_batch_size,
    )
    num_update_steps_per_epoch = max(1, num_batches // gradient_accumulation_steps)
    total_optimizer_steps = max(1, math.ceil(num_train_epochs * num_update_steps_per_epoch))
    # Ceil (rather than truncation) mirrors HF Trainer's own ratio → steps conversion.
    return min(total_optimizer_steps, math.ceil(total_optimizer_steps * warmup_ratio))
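
# Worked example (hypothetical numbers): 8,000 examples with per-device batch
# size 1 and grad accumulation 8 give 1,000 optimizer steps per epoch; with
# num_train_epochs=1 and warmup_ratio=0.03 this yields ceil(1000 * 0.03) = 30
# warmup steps.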


def cmd_train(args: argparse.Namespace) -> None:
    try:
        import torch
        from datasets import load_dataset
        from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
        from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
        from trl import SFTConfig, SFTTrainer
    except ImportError as e:
        raise SystemExit(
            "Missing dependency for training. Install:\n"
            "  pip install torch transformers peft datasets accelerate bitsandbytes trl sympy\n"
            f"Original error: {e}"
        ) from e

    data_path = Path(args.data)
    if not data_path.is_file():
        raise SystemExit(f"Data file not found: {data_path}")
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
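
    # QLoRA setup: the frozen base weights are loaded as 4-bit NF4 with double
    # quantization (the quantization constants are themselves quantized), while
    # matmuls run in the chosen compute dtype; only the LoRA adapters attached
    # below receive gradients.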
    compute_dtype = getattr(torch, args.bnb_compute_dtype)
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Many causal LMs ship without a pad token; reuse EOS so padding works.
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    print(f"Loading model {args.model} …")
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        dtype=compute_dtype,
    )
    # Casts selected modules (e.g. layer norms) for stable k-bit training and
    # enables input gradients for checkpointing.
    model = prepare_model_for_kbit_training(model)

    peft = LoraConfig(
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=args.target_modules.split(","),
    )
    model = get_peft_model(model, peft)
    model.config.use_cache = False  # KV cache is incompatible with gradient checkpointing
    model.print_trainable_parameters()
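
    # Note: PEFT scales the LoRA update by lora_alpha / r, so the defaults here
    # (alpha=32, r=16) apply a 2x multiplier to the adapter contribution.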

    ds = load_dataset("json", data_files=str(data_path), split="train")
    if args.max_samples and args.max_samples > 0:
        ds = ds.select(range(min(args.max_samples, len(ds))))

    def formatting_func(example):
        # Render the chat turns with the model's own template; TRL tokenizes the
        # returned string.
        return tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
            add_generation_prompt=False,
        )

    if args.warmup_steps is not None:
        warmup_steps = max(0, args.warmup_steps)
    else:
        warmup_steps = _warmup_steps_from_ratio(
            len(ds),
            args.batch_size,
            args.grad_accum,
            args.epochs,
            args.warmup_ratio,
        )

    sft_args = SFTConfig(
        output_dir=str(out_dir),
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.learning_rate,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        save_total_limit=3,
        bf16=args.bf16 and torch.cuda.is_available(),
        fp16=args.fp16 and torch.cuda.is_available() and not args.bf16,
        max_length=args.max_seq_length,
        warmup_steps=warmup_steps,
        lr_scheduler_type="cosine",
        report_to="none",
        gradient_checkpointing=True,
    )
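
    # Effective batch size per optimizer step is batch_size × grad_accum
    # (× number of devices); the defaults give 1 × 8 = 8 sequences.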

    trainer = SFTTrainer(
        model=model,
        args=sft_args,
        train_dataset=ds,
        processing_class=tokenizer,
        formatting_func=formatting_func,
    )
    trainer.train()
    trainer.save_model(str(out_dir))
    tokenizer.save_pretrained(str(out_dir))

    # Record the base model and key hyperparameters so `infer` can reload the
    # right base checkpoint without extra flags.
    with (out_dir / "pipeline_meta.json").open("w", encoding="utf-8") as f:
        json.dump(
            {
                "base_model": args.model,
                "data": str(data_path),
                "lora_rank": args.lora_rank,
                "epochs": args.epochs,
            },
            f,
            indent=2,
        )
    print(f"Saved adapter and tokenizer to {out_dir}")


def cmd_infer(args: argparse.Namespace) -> None:
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    from src.agent.math_agent import SOLVER_SYSTEM_PROMPT

    adapter = Path(args.adapter)
    meta_path = adapter / "pipeline_meta.json"
    # Prefer the base model recorded at train time over the CLI default.
    base_model = args.base_model
    if meta_path.is_file():
        meta = json.loads(meta_path.read_text(encoding="utf-8"))
        base_model = meta.get("base_model", base_model)

    compute_dtype = getattr(torch, args.bnb_compute_dtype)
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(adapter, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading base {base_model} + adapter {adapter} …")
    base = AutoModelForCausalLM.from_pretrained(
        base_model,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    model = PeftModel.from_pretrained(base, str(adapter))
    model.eval()

    user_content = (
        "Solve the following problem. Show your reasoning as numbered steps, "
        "then give the final numeric answer on the last line.\n\n"
        f"Problem:\n{args.problem.strip()}"
    )
    messages = [
        {"role": "system", "content": SOLVER_SYSTEM_PROMPT},
        {"role": "user", "content": user_content},
    ]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Only pass sampling knobs when sampling; they are ignored (with a warning)
    # under greedy decoding.
    gen_kwargs = dict(
        max_new_tokens=args.max_new_tokens,
        do_sample=not args.greedy,
        pad_token_id=tokenizer.pad_token_id,
    )
    if not args.greedy:
        gen_kwargs.update(temperature=args.temperature, top_p=args.top_p)
    with torch.no_grad():
        out = model.generate(**inputs, **gen_kwargs)

    # Decode only the continuation, not the echoed prompt tokens.
    gen_ids = out[0, inputs["input_ids"].shape[1] :]
    text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
    print("\n--- Generated ---\n")
    print(text)

    print("\n--- Format check ---")
    from src.sft.solution_format import validate_sympy_solution_format

    r = validate_sympy_solution_format(text)
    print(json.dumps(r.__dict__, indent=2))


def cmd_all(args: argparse.Namespace) -> None:
    out_jsonl = Path(args.data) if args.data else ROOT / "data" / "sft" / "gsm8k_sft.jsonl"
    ns = argparse.Namespace(
        output=out_jsonl,
        source=args.prepare_source,
        input=args.input,
        splits=args.splits,
        strip_scratchpads=args.strip_scratchpads,
    )
    cmd_prepare(ns)

    train_ns = argparse.Namespace(
        data=str(out_jsonl),
        output_dir=args.output_dir,
        model=args.model,
        epochs=args.epochs,
        batch_size=args.batch_size,
        grad_accum=args.grad_accum,
        learning_rate=args.learning_rate,
        max_samples=args.max_samples,
        lora_rank=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=args.target_modules,
        max_seq_length=args.max_seq_length,
        save_steps=args.save_steps,
        logging_steps=args.logging_steps,
        warmup_ratio=args.warmup_ratio,
        warmup_steps=args.warmup_steps,
        bf16=args.bf16,
        fp16=args.fp16,
        bnb_compute_dtype=args.bnb_compute_dtype,
    )
    cmd_train(train_ns)

    if args.problem:
        infer_ns = argparse.Namespace(
            adapter=Path(args.output_dir),
            base_model=args.model,
            problem=args.problem,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            greedy=args.greedy,
            bnb_compute_dtype=args.bnb_compute_dtype,
        )
        cmd_infer(infer_ns)


def build_parser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(description="GSM8K SFT pipeline (prepare / train / infer / all)")
    sub = p.add_subparsers(dest="command", required=True)

    pr = sub.add_parser("prepare", help="Run convert_gsm8k_to_sft.py")
    pr.add_argument("--output", type=str, default=str(ROOT / "data" / "sft" / "gsm8k_sft.jsonl"))
    pr.add_argument("--source", choices=("hf", "jsonl"), default="hf")
    pr.add_argument("--input", type=str, help="JSONL path for --source jsonl")
    pr.add_argument("--splits", nargs="+", default=["train", "test"])
    pr.add_argument(
        "--strip-scratchpads",
        action="store_true",
        help="Remove GSM8K <<...>> traces from assistant text after conversion.",
    )
    pr.set_defaults(func=cmd_prepare)

    tr = sub.add_parser("train", help="QLoRA SFT on JSONL with messages field")
    tr.add_argument("--data", type=str, required=True, help="JSONL from prepare step")
    tr.add_argument("--output-dir", type=str, required=True)
    tr.add_argument("--model", type=str, default="Qwen/Qwen2.5-Math-1.5B-Instruct")
    tr.add_argument("--epochs", type=float, default=1.0)
    tr.add_argument("--batch-size", type=int, default=1)
    tr.add_argument("--grad-accum", type=int, default=8)
    tr.add_argument("--learning-rate", type=float, default=2e-4)
    tr.add_argument("--max-samples", type=int, default=0, help="0 = use full dataset")
    tr.add_argument("--lora-rank", type=int, default=16)
    tr.add_argument("--lora-alpha", type=int, default=32)
    tr.add_argument("--lora-dropout", type=float, default=0.05)
    tr.add_argument(
        "--target-modules",
        type=str,
        default="q_proj,v_proj,o_proj,gate_proj",
    )
    tr.add_argument("--max-seq-length", type=int, default=2048)
    tr.add_argument("--save-steps", type=int, default=200)
    tr.add_argument("--logging-steps", type=int, default=10)
    tr.add_argument(
        "--warmup-ratio",
        type=float,
        default=0.03,
        help="Used only if --warmup-steps is not set; converted to warmup_steps.",
    )
    tr.add_argument(
        "--warmup-steps",
        type=int,
        default=None,
        help="LR warmup steps; if set, overrides --warmup-ratio.",
    )
    tr.add_argument("--bf16", action="store_true", default=True)
    tr.add_argument("--no-bf16", dest="bf16", action="store_false")
    tr.add_argument("--fp16", action="store_true")
    tr.add_argument("--bnb-compute-dtype", type=str, default="bfloat16")
    tr.set_defaults(func=cmd_train)

    inf = sub.add_parser("infer", help="Generate with saved adapter")
    inf.add_argument("--adapter", type=str, required=True, help="Directory from train step")
    inf.add_argument(
        "--base-model",
        type=str,
        default="Qwen/Qwen2.5-Math-1.5B-Instruct",
        help="Must match base used in training if no pipeline_meta.json",
    )
    inf.add_argument("--problem", type=str, required=True)
    inf.add_argument("--max-new-tokens", type=int, default=1024)
    inf.add_argument("--temperature", type=float, default=0.7)
    inf.add_argument("--top-p", type=float, default=0.95)
    inf.add_argument("--greedy", action="store_true")
    inf.add_argument("--bnb-compute-dtype", type=str, default="bfloat16")
    inf.set_defaults(func=cmd_infer)

    al = sub.add_parser("all", help="prepare + train [+ infer if --problem]")
    al.add_argument(
        "--data",
        type=str,
        default=None,
        help="Output JSONL path (default data/sft/gsm8k_sft.jsonl)",
    )
    al.add_argument("--prepare-source", choices=("hf", "jsonl"), default="hf")
    al.add_argument("--input", type=str, help="For jsonl prepare")
    al.add_argument("--splits", nargs="+", default=["train", "test"])
    al.add_argument("--strip-scratchpads", action="store_true")
    al.add_argument("--output-dir", type=str, required=True)
    al.add_argument("--model", type=str, default="Qwen/Qwen2.5-Math-1.5B-Instruct")
    al.add_argument("--epochs", type=float, default=1.0)
    al.add_argument("--batch-size", type=int, default=1)
    al.add_argument("--grad-accum", type=int, default=8)
    al.add_argument("--learning-rate", type=float, default=2e-4)
    al.add_argument("--max-samples", type=int, default=0)
    al.add_argument("--lora-rank", type=int, default=16)
    al.add_argument("--lora-alpha", type=int, default=32)
    al.add_argument("--lora-dropout", type=float, default=0.05)
    al.add_argument("--target-modules", type=str, default="q_proj,v_proj,o_proj,gate_proj")
    al.add_argument("--max-seq-length", type=int, default=2048)
    al.add_argument("--save-steps", type=int, default=200)
    al.add_argument("--logging-steps", type=int, default=10)
    al.add_argument(
        "--warmup-ratio",
        type=float,
        default=0.03,
        help="Used only if --warmup-steps is not set; converted to warmup_steps.",
    )
    al.add_argument(
        "--warmup-steps",
        type=int,
        default=None,
        help="LR warmup steps; if set, overrides --warmup-ratio.",
    )
    al.add_argument("--bf16", action="store_true", default=True)
    al.add_argument("--no-bf16", dest="bf16", action="store_false")
    al.add_argument("--fp16", action="store_true")
    al.add_argument("--bnb-compute-dtype", type=str, default="bfloat16")
    al.add_argument("--problem", type=str, default="", help="If set, run infer after train")
    al.add_argument("--max-new-tokens", type=int, default=1024)
    al.add_argument("--temperature", type=float, default=0.7)
    al.add_argument("--top-p", type=float, default=0.95)
    al.add_argument("--greedy", action="store_true")
    al.set_defaults(func=cmd_all)
    return p


def main() -> None:
    parser = build_parser()
    args = parser.parse_args()
    if str(ROOT) not in sys.path:
        # Make `src.*` imports inside the subcommands resolvable when run as a script.
        sys.path.insert(0, str(ROOT))
    args.func(args)


if __name__ == "__main__":
    main()