Spaces:
Sleeping
Sleeping
Commit ·
ec4ae03
0
Parent(s):
Initial Space deployment
Browse files — This view is limited to 50 files because it contains too many changes. See raw diff
- .gitignore +73 -0
- Dockerfile +130 -0
- README.md +132 -0
- __init__.py +16 -0
- blog.md +94 -0
- client.py +76 -0
- docs/environment-overview.puml +69 -0
- docs/reward-system.puml +51 -0
- docs/training-phases.puml +27 -0
- images/axiomforgeai_scenes/scene_01.svg +52 -0
- images/axiomforgeai_scenes/scene_02.svg +72 -0
- images/axiomforgeai_scenes/scene_03.svg +67 -0
- images/axiomforgeai_scenes/scene_04.svg +78 -0
- images/axiomforgeai_scenes/scene_05.svg +66 -0
- images/axiomforgeai_scenes/scene_06.svg +79 -0
- images/axiomforgeai_scenes/scene_07.svg +66 -0
- images/axiomforgeai_scenes/scene_08.svg +74 -0
- images/axiomforgeai_scenes/scene_09.svg +61 -0
- images/axiomforgeai_scenes/scene_10.svg +86 -0
- images/blog_flow/architecture.svg +50 -0
- images/blog_flow/grading.svg +45 -0
- images/blog_flow/grpo-loop.svg +44 -0
- images/blog_flow/task-sources.svg +35 -0
- images/environment_overview.svg +0 -0
- images/training_phases.svg +1 -0
- logs/grpo/grpo_20260426_024029.log +44 -0
- logs/grpo/grpo_20260426_032827.log +0 -0
- logs/grpo/grpo_20260426_032827/config.json +44 -0
- logs/grpo/grpo_20260426_032827/console_output.log +0 -0
- logs/grpo/grpo_20260426_032827/metrics.csv +31 -0
- logs/metrics.jsonl +31 -0
- models.py +67 -0
- openenv.yaml +7 -0
- pyproject.toml +55 -0
- requirements.txt +160 -0
- scripts/__init__.py +1 -0
- scripts/convert_gsm8k_to_sft.py +193 -0
- scripts/create_dual_task_dataset.py +321 -0
- scripts/demo_before_after.py +591 -0
- scripts/dual_task_sft_pipeline.py +390 -0
- scripts/eval_sft_inference.py +565 -0
- scripts/gsm8k_sft_pipeline.py +475 -0
- scripts/launch_grpo.sh +127 -0
- scripts/plot_grpo_run.py +425 -0
- scripts/plot_training_results.py +521 -0
- scripts/precompute_extraction_cache.py +174 -0
- scripts/prepare_aqua_dataset.py +265 -0
- scripts/prepare_combined_dataset.py +711 -0
- scripts/run_grpo_training.py +0 -0
- scripts/run_inference.py +502 -0
.gitignore
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
dist/
|
| 13 |
+
downloads/
|
| 14 |
+
eggs/
|
| 15 |
+
.eggs/
|
| 16 |
+
lib/
|
| 17 |
+
lib64/
|
| 18 |
+
parts/
|
| 19 |
+
sdist/
|
| 20 |
+
var/
|
| 21 |
+
wheels/
|
| 22 |
+
share/python-wheels/
|
| 23 |
+
*.egg-info/
|
| 24 |
+
.installed.cfg
|
| 25 |
+
*.egg
|
| 26 |
+
MANIFEST
|
| 27 |
+
|
| 28 |
+
# Installer logs
|
| 29 |
+
pip-log.txt
|
| 30 |
+
pip-delete-this-directory.txt
|
| 31 |
+
|
| 32 |
+
# Unit test / coverage reports
|
| 33 |
+
.pytest_cache/
|
| 34 |
+
.coverage
|
| 35 |
+
.coverage.*
|
| 36 |
+
htmlcov/
|
| 37 |
+
.tox/
|
| 38 |
+
.nox/
|
| 39 |
+
coverage.xml
|
| 40 |
+
*.cover
|
| 41 |
+
*.py,cover
|
| 42 |
+
|
| 43 |
+
# Type checkers / static analyzers
|
| 44 |
+
.mypy_cache/
|
| 45 |
+
.pyre/
|
| 46 |
+
.ruff_cache/
|
| 47 |
+
.pytype/
|
| 48 |
+
|
| 49 |
+
# Virtual environments
|
| 50 |
+
.venv/
|
| 51 |
+
venv/
|
| 52 |
+
env/
|
| 53 |
+
ENV/
|
| 54 |
+
|
| 55 |
+
# Local environment files
|
| 56 |
+
.env
|
| 57 |
+
.env.*
|
| 58 |
+
*.local
|
| 59 |
+
|
| 60 |
+
# IDE / editor files
|
| 61 |
+
.vscode/
|
| 62 |
+
.idea/
|
| 63 |
+
*.swp
|
| 64 |
+
*.swo
|
| 65 |
+
*~
|
| 66 |
+
|
| 67 |
+
# OS files
|
| 68 |
+
.DS_Store
|
| 69 |
+
Thumbs.db
|
| 70 |
+
data/
|
| 71 |
+
|
| 72 |
+
*/ui
|
| 73 |
+
images/
|
Dockerfile
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# AxiomForgeAI β GRPO Training Image
|
| 2 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 3 |
+
# Hardware target : 1× A100 PCIE 80 GB | AMD EPYC 7V13 | NVMe 300 GB
|
| 4 |
+
#
|
| 5 |
+
# CUDA driver : >= 13.0 (enforced at container start via entrypoint)
|
| 6 |
+
# CUDA toolkit : 12.4.1 (backward-compatible with driver 13.x)
|
| 7 |
+
# PyTorch : 2.5.1+cu124 (pinned in requirements.txt)
|
| 8 |
+
# Flash-Attn : 2.8.3 (pinned in requirements.txt)
|
| 9 |
+
#
|
| 10 |
+
# All Python package versions are taken exclusively from requirements.txt.
|
| 11 |
+
# No versions are hard-coded in this file.
|
| 12 |
+
#
|
| 13 |
+
# ββ Build βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 14 |
+
# docker build -t axiomforgeai-train:latest .
|
| 15 |
+
#
|
| 16 |
+
# ββ Interactive shell βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 17 |
+
# docker run --gpus all --ipc=host --ulimit memlock=-1 \
|
| 18 |
+
# -v $(pwd)/data:/workspace/data \
|
| 19 |
+
# -v $(pwd)/checkpoints:/workspace/checkpoints \
|
| 20 |
+
# -v $(pwd)/logs:/workspace/logs \
|
| 21 |
+
# -it axiomforgeai-train:latest bash
|
| 22 |
+
#
|
| 23 |
+
# ββ GRPO training (one-shot) ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 24 |
+
# docker run --gpus all --ipc=host --ulimit memlock=-1 \
|
| 25 |
+
# -v $(pwd)/data:/workspace/data \
|
| 26 |
+
# -v $(pwd)/checkpoints:/workspace/checkpoints \
|
| 27 |
+
# -v $(pwd)/logs:/workspace/logs \
|
| 28 |
+
# axiomforgeai-train:latest \
|
| 29 |
+
# python scripts/run_grpo_training.py \
|
| 30 |
+
# --base-model checkpoints/dual_task_v1 \
|
| 31 |
+
# --gsm8k-data data/sft/gsm8k_sft.jsonl \
|
| 32 |
+
# --num-iterations 30 --group-size 8 --questions-per-iter 16
|
| 33 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 34 |
+
|
| 35 |
+
# CUDA toolkit 12.4.1 β matches the cu124 wheels in requirements.txt and is
|
| 36 |
+
# fully compatible with the A100's CUDA 13.2 driver (driver is always β₯ toolkit).
|
| 37 |
+
FROM nvidia/cuda:12.4.1-devel-ubuntu22.04
|
| 38 |
+
|
| 39 |
+
LABEL org.opencontainers.image.title="AxiomForgeAI Training" \
|
| 40 |
+
cuda.driver.minimum="13.0" \
|
| 41 |
+
cuda.toolkit="12.4.1" \
|
| 42 |
+
torch.version="2.5.1+cu124" \
|
| 43 |
+
flash_attn.version="2.8.3"
|
| 44 |
+
|
| 45 |
+
# ββ System packages ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 46 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
| 47 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 48 |
+
python3.11 \
|
| 49 |
+
python3.11-dev \
|
| 50 |
+
python3-pip \
|
| 51 |
+
python3.11-venv \
|
| 52 |
+
git \
|
| 53 |
+
git-lfs \
|
| 54 |
+
curl \
|
| 55 |
+
wget \
|
| 56 |
+
build-essential \
|
| 57 |
+
ninja-build \
|
| 58 |
+
pkg-config \
|
| 59 |
+
libssl-dev \
|
| 60 |
+
libffi-dev \
|
| 61 |
+
ca-certificates \
|
| 62 |
+
&& ln -sf /usr/bin/python3.11 /usr/bin/python3 \
|
| 63 |
+
&& ln -sf /usr/bin/python3 /usr/bin/python \
|
| 64 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 65 |
+
|
| 66 |
+
# ββ Upgrade pip + build tooling βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 67 |
+
RUN python -m pip install --upgrade --no-cache-dir pip setuptools wheel
|
| 68 |
+
|
| 69 |
+
# ββ PyTorch (CUDA 12.4 wheels) ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 70 |
+
# Must be installed before flash-attn because flash-attn runs a torch version
|
| 71 |
+
# check at install time. The cu124 index is also used for all CUDA-linked wheels.
|
| 72 |
+
# Version is taken from requirements.txt β the --constraint flag keeps pip from
|
| 73 |
+
# re-resolving to a different version when requirements.txt is processed next.
|
| 74 |
+
RUN pip install --no-cache-dir \
|
| 75 |
+
--extra-index-url https://download.pytorch.org/whl/cu124 \
|
| 76 |
+
"torch==2.5.1" "torchvision==0.20.1" "torchaudio==2.5.1"
|
| 77 |
+
|
| 78 |
+
# ββ All remaining pinned requirements (from requirements.txt) βββββββββββββββββ
|
| 79 |
+
# flash-attn, xformers, vllm, triton, bitsandbytes, transformers, accelerate,
|
| 80 |
+
# peft, ray, sympy, scipy, numpy, openenv-core, fastapi, uvicorn, β¦ are all
|
| 81 |
+
# installed here at the exact versions pinned in requirements.txt.
|
| 82 |
+
# The cu124 index is provided so CUDA-linked wheels resolve correctly.
|
| 83 |
+
COPY requirements.txt /tmp/requirements.txt
|
| 84 |
+
RUN pip install --no-cache-dir \
|
| 85 |
+
--extra-index-url https://download.pytorch.org/whl/cu124 \
|
| 86 |
+
-r /tmp/requirements.txt
|
| 87 |
+
|
| 88 |
+
# ββ Project source ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 89 |
+
WORKDIR /workspace
|
| 90 |
+
COPY . /workspace/
|
| 91 |
+
|
| 92 |
+
# ββ Environment variables βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 93 |
+
# Repo root on PYTHONPATH so `from src.rl.X import Y` works without editable install
|
| 94 |
+
ENV PYTHONPATH="/workspace:$PYTHONPATH"
|
| 95 |
+
|
| 96 |
+
# HuggingFace model cache β mount a host path here to persist model downloads:
|
| 97 |
+
# -v /host/hf_cache:/workspace/.hf_cache
|
| 98 |
+
ENV HF_HOME="/workspace/.hf_cache"
|
| 99 |
+
ENV TRANSFORMERS_CACHE="/workspace/.hf_cache"
|
| 100 |
+
|
| 101 |
+
# A100 CUDA / NCCL tuning
|
| 102 |
+
ENV CUDA_DEVICE_MAX_CONNECTIONS=1
|
| 103 |
+
ENV NCCL_P2P_DISABLE=0
|
| 104 |
+
ENV NCCL_IB_DISABLE=0
|
| 105 |
+
# Required for Flash-Attn 2 with bfloat16 on Ampere
|
| 106 |
+
ENV TORCH_CUDNN_V8_API_ENABLED=1
|
| 107 |
+
|
| 108 |
+
# ββ Runtime entrypoint: enforce CUDA driver >= 13.0 ββββββββββββββββββββββββββ
|
| 109 |
+
# nvidia-smi is injected at runtime via --gpus, so this check runs when the
|
| 110 |
+
# container starts, not at build time.
|
| 111 |
+
RUN printf '%s\n' \
|
| 112 |
+
'#!/bin/sh' \
|
| 113 |
+
'if command -v nvidia-smi >/dev/null 2>&1; then' \
|
| 114 |
+
' CUDA_VER=$(nvidia-smi 2>/dev/null | grep -oP "CUDA Version: \K[0-9.]+" || echo "0.0")' \
|
| 115 |
+
' MAJOR=$(echo "$CUDA_VER" | cut -d. -f1)' \
|
| 116 |
+
' echo "[AxiomForgeAI] CUDA driver reports toolkit: $CUDA_VER"' \
|
| 117 |
+
' if [ "${MAJOR:-0}" -lt 13 ] 2>/dev/null; then' \
|
| 118 |
+
' echo "[ERROR] CUDA driver >= 13.0 required; detected $CUDA_VER. Upgrade your NVIDIA driver."' \
|
| 119 |
+
' exit 1' \
|
| 120 |
+
' fi' \
|
| 121 |
+
' echo "[AxiomForgeAI] CUDA $CUDA_VER >= 13.0 β OK"' \
|
| 122 |
+
'else' \
|
| 123 |
+
' echo "[WARNING] nvidia-smi not found β CUDA driver version check skipped."' \
|
| 124 |
+
'fi' \
|
| 125 |
+
'exec "$@"' \
|
| 126 |
+
> /usr/local/bin/entrypoint.sh \
|
| 127 |
+
&& chmod +x /usr/local/bin/entrypoint.sh
|
| 128 |
+
|
| 129 |
+
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
|
| 130 |
+
CMD ["bash"]
|
README.md
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: AxiomForgeAI Environment Server
|
| 3 |
+
emoji: π
|
| 4 |
+
colorFrom: indigo
|
| 5 |
+
colorTo: pink
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
app_port: 8000
|
| 9 |
+
base_path: /web
|
| 10 |
+
tags:
|
| 11 |
+
- openenv
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# AxiomForgeAI
|
| 15 |
+
|
| 16 |
+
[](https://github.com/meta-pytorch/OpenEnv)
|
| 17 |
+
|
| 18 |
+
*A self-improving math environment where a model practices on verified problems, generates new challenges when ready, and learns from solution attempts whose reasoning steps and final answers agree.*
|
| 19 |
+
|
| 20 |
+
## The Problem
|
| 21 |
+
|
| 22 |
+
Math reasoning models can fail in two different ways. Sometimes the setup, arithmetic, and algebraic steps look reasonable, but the final answer is wrong. Sometimes the final answer is right, but the reasoning that produced it is incomplete, inconsistent, or hard to trust.
|
| 23 |
+
|
| 24 |
+
For a math user, both failures matter. Checking only the final answer misses where the solution went off track. Checking only the steps misses whether the work actually reaches the right result. The useful signal is the agreement between the reasoning path and the final answer.
|
| 25 |
+
|
| 26 |
+
This project builds a practice loop around that signal. The model first works on problems with known answers, gets feedback on both the chain of reasoning and the final result, and only then starts generating new challenges for itself. The constraint is intentionally small: a 1.5B math model.
|
| 27 |
+
|
| 28 |
+
## The Environment
|
| 29 |
+
|
| 30 |
+
The environment is a practice loop for math reasoning. Each training group starts with one problem, asks the model for multiple solution attempts, scores those attempts from several angles, and uses GRPO to reinforce the attempts that are stronger than the rest of the group.
|
| 31 |
+
|
| 32 |
+

|
| 33 |
+
|
| 34 |
+
The environment has two task sources:
|
| 35 |
+
|
| 36 |
+
- **Grounded source:** A dataset problem from GSM8K / MATH comes with a known final answer. This gives the environment a reliable anchor for checking whether the model actually reached the right result.
|
| 37 |
+
- **Self-play source:** The curriculum selects a target skill and difficulty. The model writes a new question, then samples multiple solutions to that question. This adds practice beyond static datasets, but only after the grounded signal is stable enough.
|
| 38 |
+
|
| 39 |
+
Both sources feed the same scoring and update loop. For every selected problem, the model samples `K` candidate solutions. The environment checks final-answer correctness when a gold answer exists, scores reasoning quality with a PRM, checks chain consistency and symbolic arithmetic where possible, checks answer formatting, and scores self-generated questions for clarity, novelty, difficulty fit, and solvability.
|
| 40 |
+
|
| 41 |
+
GRPO then compares the `K` attempts against each other. The model is not rewarded for a solution in isolation; the strongest attempt in the group becomes the direction for learning. Training starts grounded-only, gradually mixes in self-play groups, and falls back to grounded practice if generated-question quality or answer correctness drops.
|
| 42 |
+
|
| 43 |
+
## How Self-Improvement Works
|
| 44 |
+
|
| 45 |
+
Self-improvement comes from turning each problem into a small comparison. The model does not produce one solution and move on; the environment samples several attempts, scores each attempt, and asks which reasoning path was strongest.
|
| 46 |
+
|
| 47 |
+
GRPO uses that within-group comparison as the learning signal. Attempts with correct answers, stronger reasoning chains, and cleaner final-answer format are reinforced. Attempts with broken chains or unsupported answers become weaker examples.
|
| 48 |
+
|
| 49 |
+
```text
|
| 50 |
+
practice -> sample attempts -> verify steps and answer -> compare -> reinforce -> adjust difficulty
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
## Reward System
|
| 54 |
+
|
| 55 |
+
The reward is designed to avoid a common math-training failure: optimizing for either the final answer or the reasoning trace alone. A good solution should reach the right answer, explain the path clearly, and keep the final result consistent with the steps that produced it.
|
| 56 |
+
|
| 57 |
+
| Signal | What it checks | Why it matters |
|
| 58 |
+
| --- | --- | --- |
|
| 59 |
+
| Final answer | Matches the gold answer when one exists | Keeps grounded problems tied to objective correctness |
|
| 60 |
+
| Process score | PRM score over the reasoning steps | Rewards clear mathematical progress, not just the last line |
|
| 61 |
+
| Chain consistency | Correct-prefix and step-answer consistency signals | Gives partial learning signal when a solution goes wrong midway |
|
| 62 |
+
| Format | Parseable final answer and clean response structure | Makes automatic grading reliable |
|
| 63 |
+
| Question quality | Topic fit, difficulty fit, clarity, novelty, and solvability | Keeps self-play from generating vague or useless practice tasks |
|
| 64 |
+
|
| 65 |
+
Grounded problems use the gold answer as the anchor. Self-play problems add a question-quality score before the solution reward is trusted. Both paths produce one combined score for each sampled attempt, and GRPO uses those scores only in comparison with the other attempts from the same problem.
|
| 66 |
+
|
| 67 |
+
```text
|
| 68 |
+
grounded: answer correctness + process score + chain consistency + format
|
| 69 |
+
self-play: question quality + solution quality
|
| 70 |
+
both -> one combined score per attempt -> GRPO compares attempts within the group
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
## Training Phases
|
| 74 |
+
|
| 75 |
+
Training follows a simple three-phase schedule. It starts with grounded-only practice so the model learns to keep answers and reasoning stable on problems with known solutions. Self-play is then introduced gradually, while grounded questions remain as an anchor. Once both are stable, training continues with a mixed task source and falls back to grounded-only batches if answer quality drops.
|
| 76 |
+
|
| 77 |
+

|
| 78 |
+
|
| 79 |
+
## Training Script
|
| 80 |
+
|
| 81 |
+
The GRPO training loop is available in two forms:
|
| 82 |
+
|
| 83 |
+
- [`scripts/launch_grpo.sh`](scripts/launch_grpo.sh) β the primary launch script; sets CUDA/threading env vars, verifies Flash-Attention, and calls `run_grpo_training.py` with the full parameter set.
|
| 84 |
+
|
| 85 |
+
```bash
|
| 86 |
+
bash scripts/launch_grpo.sh
|
| 87 |
+
```
|
| 88 |
+
- [`train_grpo.ipynb`](train_grpo.ipynb) β notebook version with the same parameters, structured around `env.reset / env.step / env.state / env.close` for interactive inspection.
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
## Results
|
| 92 |
+
|
| 93 |
+
These plots come from a single GPU training run and focus on the core question: did the model get better at making its reasoning and final answer agree?
|
| 94 |
+
|
| 95 |
+
### Evaluation Quality Over Training
|
| 96 |
+
|
| 97 |
+

|
| 98 |
+
|
| 99 |
+
The environment tracks final correctness, solution quality, step validity, and how long the reasoning chain stays correct. All four move upward together, which suggests the model is not just finding better final answers. It is also producing reasoning that holds up longer.
|
| 100 |
+
|
| 101 |
+
### Training Journey
|
| 102 |
+
|
| 103 |
+

|
| 104 |
+
|
| 105 |
+
Training starts with grounded practice on problems with known answers. Self-play is introduced only after the grounded signal is stable, so the model does not train on its own generated problems too early. The transition is conditional, not just a timer.
|
| 106 |
+
|
| 107 |
+
### Self-Play Curriculum
|
| 108 |
+
|
| 109 |
+

|
| 110 |
+
|
| 111 |
+
By the end of training, most practice came from self-play. The important part is that generated problems stayed solvable and novel even after self-play became a larger share of training. That makes the ramp meaningful: self-play added useful practice instead of recycled noise.
|
| 112 |
+
|
| 113 |
+
### Reward Confidence
|
| 114 |
+
|
| 115 |
+

|
| 116 |
+
|
| 117 |
+
The reward spread shows how much contrast exists between the model's best and worst attempts. Wide spread gives GRPO something to learn from. Skipped groups are cases where attempts are too similar to compare usefully. That rate falls as harder material enters the curriculum, which suggests the comparison signal stays useful.
|
| 118 |
+
|
| 119 |
+
### Step-Level Reasoning Quality
|
| 120 |
+
|
| 121 |
+

|
| 122 |
+
|
| 123 |
+
Step accuracy checks whether each line of reasoning is valid. Chain integrity checks whether those valid steps form an unbroken path to the answer. Both improve together, which means the model is building solutions that hold together more often instead of only producing better-looking outputs.
|
| 124 |
+
|
| 125 |
+
## Why It Matters
|
| 126 |
+
|
| 127 |
+
Reliable math reasoning needs more than fluent explanations or lucky final answers. A system that can separate correct reasoning from unsupported answers gives the model a better training target: not just "get the number," but build a chain of logic that reaches the number.
|
| 128 |
+
|
| 129 |
+
AxiomForgeAI matters because it turns that target into an environment. The same pattern can extend beyond math to other verifiable domains where attempts can be checked, compared, and improved: code, logic, structured data transformations, and scientific problem solving.
|
| 130 |
+
|
| 131 |
+
---
|
| 132 |
+
*Engineered for the OpenEnv Hackathon India 2026*
|
__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Axiomforgeai Environment."""
|
| 8 |
+
|
| 9 |
+
from .client import AxiomforgeaiEnv
|
| 10 |
+
from .models import AxiomforgeaiAction, AxiomforgeaiObservation
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"AxiomforgeaiAction",
|
| 14 |
+
"AxiomforgeaiObservation",
|
| 15 |
+
"AxiomforgeaiEnv",
|
| 16 |
+
]
|
blog.md
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# AxiomForgeAI: Self-Improving Math Models Need More Than the Final Answer
|
| 2 |
+
|
| 3 |
+
Math models have a strange failure mode.
|
| 4 |
+
|
| 5 |
+
They can write a solution that looks careful, step-by-step, and confident, then end with the wrong answer. They can also produce the right final number with reasoning that is incomplete, inconsistent, or impossible to trust.
|
| 6 |
+
|
| 7 |
+
For math, that gap matters. The final answer is not enough. A proof, derivation, or word-problem solution only becomes useful when the path and the answer support each other.
|
| 8 |
+
|
| 9 |
+
AxiomForgeAI is built around that idea.
|
| 10 |
+
|
| 11 |
+
Instead of treating math reasoning as a one-shot generation problem, AxiomForgeAI turns it into a practice environment. The model does not simply answer a question and move on. It attempts the same problem multiple ways, receives feedback on both the reasoning path and the final answer, and learns from the attempts where the two agree.
|
| 12 |
+
|
| 13 |
+
## The Architecture
|
| 14 |
+
|
| 15 |
+

|
| 16 |
+
|
| 17 |
+
AxiomForgeAI is a training loop around one simple idea: a math solution should be judged by whether the reasoning path and the final answer support each other.
|
| 18 |
+
|
| 19 |
+
The environment first selects one task. It can come from a grounded dataset problem with a known answer, or from a self-play question written from a curriculum target. Only after that task is selected does the model sample `K` candidate solutions. The environment scores each attempt, and GRPO compares the attempts within that same problem group.
|
| 20 |
+
|
| 21 |
+
That is the important part. The model is not rewarded for sounding fluent. It is rewarded when the chain of reasoning and the final answer line up.
|
| 22 |
+
|
| 23 |
+
## Where Practice Comes From
|
| 24 |
+
|
| 25 |
+

|
| 26 |
+
|
| 27 |
+
The environment uses two sources of problems.
|
| 28 |
+
|
| 29 |
+
Grounded practice starts with dataset problems from sources like GSM8K or MATH. These problems come with known final answers, so the environment has a reliable anchor for correctness.
|
| 30 |
+
|
| 31 |
+
Self-play starts later. The curriculum selects a skill and difficulty, and the model writes a new question. That question is only useful if it is clear, solvable, on-topic, and appropriately difficult. This keeps self-play from becoming random problem generation.
|
| 32 |
+
|
| 33 |
+
Both sources eventually become the same interface: one selected problem. From there, the model samples multiple candidate solutions and the environment compares the resulting reasoning paths.
|
| 34 |
+
|
| 35 |
+
## What Gets Checked
|
| 36 |
+
|
| 37 |
+

|
| 38 |
+
|
| 39 |
+
AxiomForgeAI does not rely on a single reward signal. A final answer check is useful, but it is not enough. A process score is useful, but it is also not enough. The environment combines several signals so that a polished but wrong solution does not look good, and a lucky answer with weak reasoning does not look good either.
|
| 40 |
+
|
| 41 |
+
For grounded problems, the gold answer anchors correctness. For all attempts, the environment also looks at reasoning quality, chain consistency, symbolic arithmetic where possible, and whether the answer can be parsed cleanly. For self-play, the generated question itself is scored before the solution reward is trusted.
|
| 42 |
+
|
| 43 |
+
The result is one score per attempt. That score is not the end of training. It becomes useful because there are other attempts for the same problem.
|
| 44 |
+
|
| 45 |
+
## Why GRPO Fits
|
| 46 |
+
|
| 47 |
+

|
| 48 |
+
|
| 49 |
+
GRPO turns a problem into a small comparison game. The model samples several attempts for the same prompt. Some are wrong, some are partially right, and one may be clearly better because the answer follows from the steps.
|
| 50 |
+
|
| 51 |
+
Instead of asking whether an attempt is good in isolation, GRPO asks which attempts are stronger relative to the rest of the group. That relative signal is exactly what this project needs. The model learns from contrast: this reasoning path held together better than the others.
|
| 52 |
+
|
| 53 |
+
After the update, the improved model goes back into the environment for the next batch. The curriculum can keep it grounded, introduce more self-play, or fall back to grounded-only practice if quality drops.
|
| 54 |
+
|
| 55 |
+
## Why the 1.5B Constraint Matters
|
| 56 |
+
|
| 57 |
+
AxiomForgeAI is intentionally built around a compact math model.
|
| 58 |
+
|
| 59 |
+
That constraint makes the loop easier to see. A smaller model cannot hide every reasoning mistake behind scale. If the setup is wrong, if the arithmetic drifts, or if the final answer does not follow from the steps, the environment has to catch it and turn it into feedback.
|
| 60 |
+
|
| 61 |
+
The point is not that a compact model magically solves math. The point is that improvement has to come from better practice, better verification, and better selection of reasoning paths.
|
| 62 |
+
|
| 63 |
+
## What the Model Learns From
|
| 64 |
+
|
| 65 |
+
AxiomForgeAI rewards attempts that are mathematically useful, not just polished.
|
| 66 |
+
|
| 67 |
+
The model learns to solve problems with reasoning that supports the answer. It also learns, during self-play, to generate practice problems that are worth solving. A useful self-generated problem should be clear, solvable, on-topic, appropriately difficult, and not just a duplicate of what the model has already seen.
|
| 68 |
+
|
| 69 |
+
That makes the loop different from ordinary fine-tuning. The model is not only seeing more answers. It is practicing, being checked, and learning from the solution paths that survived verification.
|
| 70 |
+
|
| 71 |
+
## Where Examples Will Go
|
| 72 |
+
|
| 73 |
+
This section will include real model responses from the run.
|
| 74 |
+
|
| 75 |
+
- an example where the model had good steps but a wrong final answer
|
| 76 |
+
- an example where the model guessed correctly but the reasoning was weak
|
| 77 |
+
- an example after training where the reasoning chain and final answer agree
|
| 78 |
+
- a self-generated problem that passed the quality checks
|
| 79 |
+
|
| 80 |
+
These examples are important because the project is not only about a metric. The clearest evidence is seeing the model become better at making the path and the answer line up.
|
| 81 |
+
|
| 82 |
+
## Why This Matters
|
| 83 |
+
|
| 84 |
+
Math is a good starting point because mistakes are often checkable. Arithmetic can be verified. Final answers can be compared. Reasoning steps can be scored. That makes math a clean domain for building self-improvement loops.
|
| 85 |
+
|
| 86 |
+
But the pattern is bigger than math.
|
| 87 |
+
|
| 88 |
+
Many useful AI tasks have the same structure. Generate an attempt, check it, compare it against alternatives, and reinforce the better path. Code, logic, structured data transformation, and scientific problem solving all benefit from environments where progress can be verified.
|
| 89 |
+
|
| 90 |
+
AxiomForgeAI is one version of that pattern. It asks a simple question.
|
| 91 |
+
|
| 92 |
+
> What if a model could practice until its reasoning and answers agreed?
|
| 93 |
+
|
| 94 |
+
That is the loop this project builds.
|
client.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""AxiomForgeAI Math RL Environment Client."""
|
| 8 |
+
|
| 9 |
+
from typing import Any, Dict, Optional
|
| 10 |
+
|
| 11 |
+
from openenv.core import EnvClient
|
| 12 |
+
from openenv.core.client_types import StepResult
|
| 13 |
+
from openenv.core.env_server.types import State
|
| 14 |
+
|
| 15 |
+
from .models import AxiomforgeaiAction, AxiomforgeaiObservation
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class AxiomforgeaiEnv(
    EnvClient[AxiomforgeaiAction, AxiomforgeaiObservation, State]
):
    """Client for the AxiomForgeAI math RL environment.

    Keeps one persistent WebSocket connection open against the environment
    server; every client instance owns its own session with independent
    episode state.

    Episode flow::

        with AxiomforgeaiEnv(base_url="http://localhost:8000") as env:
            # 1. Reset -> receive a math question
            result = env.reset()
            question = result.observation.question

            # 2. Step -> submit a solution, receive reward + feedback
            solution = "Step 1: ... Final Answer: 42"
            result = env.step(AxiomforgeaiAction(solution=solution))
            print(result.reward, result.observation.feedback)

    Example with Docker::

        client = AxiomforgeaiEnv.from_docker_image("axiomforgeai-env:latest")
        try:
            result = client.reset()
            result = client.step(AxiomforgeaiAction(solution="Final Answer: 17"))
        finally:
            client.close()
    """

    def _step_payload(self, action: AxiomforgeaiAction) -> Dict[str, Any]:
        """Serialize an action into the JSON body sent to the step endpoint."""
        return {"solution": action.solution}

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[AxiomforgeaiObservation]:
        """Build a StepResult from the server's step response payload.

        Missing observation fields fall back to empty/neutral defaults so a
        partial server payload still yields a well-formed observation.
        """
        raw_obs: Dict[str, Any] = payload.get("observation", {})
        done_flag = payload.get("done", False)
        reward_value = payload.get("reward")
        parsed_obs = AxiomforgeaiObservation(
            question=raw_obs.get("question", ""),
            topic=raw_obs.get("topic", ""),
            difficulty=float(raw_obs.get("difficulty", 0.5)),
            feedback=raw_obs.get("feedback", ""),
            done=done_flag,
            reward=reward_value,
            metadata=raw_obs.get("metadata"),
        )
        return StepResult(
            observation=parsed_obs,
            reward=reward_value,
            done=done_flag,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> State:
        """Build a State object from the server's state response payload."""
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )
|
docs/environment-overview.puml
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@startuml environment_overview
|
| 2 |
+
!theme plain
|
| 3 |
+
top to bottom direction
|
| 4 |
+
skinparam backgroundColor #FEFEFE
|
| 5 |
+
skinparam defaultFontName Arial
|
| 6 |
+
skinparam defaultFontSize 14
|
| 7 |
+
skinparam ArrowColor #334155
|
| 8 |
+
skinparam RectangleBorderColor #64748B
|
| 9 |
+
skinparam RectangleFontColor #0F172A
|
| 10 |
+
skinparam roundcorner 10
|
| 11 |
+
skinparam linetype ortho
|
| 12 |
+
skinparam packageStyle rectangle
|
| 13 |
+
skinparam nodesep 42
|
| 14 |
+
skinparam ranksep 42
|
| 15 |
+
|
| 16 |
+
title AxiomForgeAI - Phase-Controlled Math Reasoning Loop
|
| 17 |
+
|
| 18 |
+
rectangle "Small Math Model\n1.5B parameters" as MODEL #DBEAFE
|
| 19 |
+
|
| 20 |
+
rectangle "Phase Controller\nwarmup: grounded only\nramp: gradual self-play\ncontinuous: capped mix + fallback" as PHASE #E2E8F0
|
| 21 |
+
|
| 22 |
+
rectangle "Task Source\nfor each GRPO group" as SELECT #E2E8F0
|
| 23 |
+
|
| 24 |
+
rectangle "Grounded Source\nKnown-answer practice" as GLANE #ECFDF5 {
|
| 25 |
+
rectangle "Dataset problem\nGSM8K / MATH" as GQ #CCFBF1
|
| 26 |
+
rectangle "Gold answer\navailable" as GOLD #CCFBF1
|
| 27 |
+
rectangle "Model samples\nK solutions" as GSOL #CCFBF1
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
rectangle "Self-Play Source\nModel-made challenges" as SLANE #EEF2FF {
|
| 31 |
+
rectangle "Curriculum picks\nskill + difficulty" as CURRIC #E0E7FF
|
| 32 |
+
rectangle "Model writes\na new question" as SQ #E0E7FF
|
| 33 |
+
rectangle "Model samples\nK solutions" as SSOL #E0E7FF
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
rectangle "Shared Grading\nanswer, steps, arithmetic, format\n+ question quality for self-play" as GRADERS #F1F5F9
|
| 37 |
+
|
| 38 |
+
rectangle "Group Comparison\nWhich attempts worked best?" as COMPARE #EDE9FE
|
| 39 |
+
rectangle "GRPO Update\nReinforce stronger reasoning" as GRPO #DDD6FE
|
| 40 |
+
rectangle "Improved Model\nfor the next round" as NEXT #DBEAFE
|
| 41 |
+
|
| 42 |
+
MODEL -down-> PHASE
|
| 43 |
+
PHASE -down-> SELECT
|
| 44 |
+
|
| 45 |
+
note right of PHASE
|
| 46 |
+
sets mix
|
| 47 |
+
end note
|
| 48 |
+
|
| 49 |
+
SELECT -left-> GQ : grounded slot
|
| 50 |
+
GQ --> GOLD
|
| 51 |
+
GOLD --> GSOL
|
| 52 |
+
|
| 53 |
+
SELECT -right-> CURRIC : self-play slot
|
| 54 |
+
CURRIC --> SQ
|
| 55 |
+
SQ --> SSOL
|
| 56 |
+
|
| 57 |
+
GSOL -down-> GRADERS
|
| 58 |
+
SSOL -down-> GRADERS
|
| 59 |
+
GRADERS -right-> COMPARE
|
| 60 |
+
COMPARE -right-> GRPO
|
| 61 |
+
GRPO -right-> NEXT
|
| 62 |
+
NEXT -up-> MODEL : repeat
|
| 63 |
+
|
| 64 |
+
note bottom of SELECT
|
| 65 |
+
Each batch is randomly interleaved.
|
| 66 |
+
Phase 1 uses grounded only.
|
| 67 |
+
Later phases add self-play slots by ratio.
|
| 68 |
+
end note
|
| 69 |
+
@enduml
|
docs/reward-system.puml
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@startuml reward_system
|
| 2 |
+
!theme plain
|
| 3 |
+
top to bottom direction
|
| 4 |
+
skinparam backgroundColor #FEFEFE
|
| 5 |
+
skinparam defaultFontName Arial
|
| 6 |
+
skinparam defaultFontSize 14
|
| 7 |
+
skinparam ArrowColor #334155
|
| 8 |
+
skinparam RectangleBorderColor #64748B
|
| 9 |
+
skinparam RectangleFontColor #0F172A
|
| 10 |
+
skinparam roundcorner 10
|
| 11 |
+
skinparam linetype ortho
|
| 12 |
+
skinparam packageStyle rectangle
|
| 13 |
+
skinparam nodesep 54
|
| 14 |
+
skinparam ranksep 60
|
| 15 |
+
|
| 16 |
+
title AxiomForgeAI - Reward System
|
| 17 |
+
|
| 18 |
+
rectangle "Sampled Solution Attempt" as ATTEMPT #DBEAFE
|
| 19 |
+
|
| 20 |
+
rectangle "Grounded Reward\nknown-answer problem" as GROUNDED #ECFDF5 {
|
| 21 |
+
rectangle "Final answer\nmatches gold" as GOLD #CCFBF1
|
| 22 |
+
rectangle "PRM process score\nreasoning quality" as GPRM #CCFBF1
|
| 23 |
+
rectangle "Chain consistency\ncorrect prefix + final check" as GCHAIN #CCFBF1
|
| 24 |
+
rectangle "Format score\nparseable final answer" as GFORMAT #CCFBF1
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
rectangle "Self-Play Reward\ngenerated challenge" as SELFPLAY #EEF2FF {
|
| 28 |
+
rectangle "Question quality\nclarity, novelty, solvability" as QUALITY #E0E7FF
|
| 29 |
+
rectangle "Solution quality\nPRM + chain checks" as SOLUTION #E0E7FF
|
| 30 |
+
rectangle "Format score\nparseable final answer" as SFORMAT #E0E7FF
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
rectangle "Combined Reward\none score per attempt" as SCORE #F1F5F9
|
| 34 |
+
rectangle "GRPO Group Comparison\nrank attempts within the same problem" as COMPARE #EDE9FE
|
| 35 |
+
rectangle "Step-Answer Alignment\nreward paths where reasoning supports the result" as ALIGN #DDD6FE
|
| 36 |
+
|
| 37 |
+
ATTEMPT -left-> GROUNDED : grounded
|
| 38 |
+
ATTEMPT -right-> SELFPLAY : self-play
|
| 39 |
+
|
| 40 |
+
GOLD --> GPRM
|
| 41 |
+
GPRM --> GCHAIN
|
| 42 |
+
GCHAIN --> GFORMAT
|
| 43 |
+
|
| 44 |
+
QUALITY --> SOLUTION
|
| 45 |
+
SOLUTION --> SFORMAT
|
| 46 |
+
|
| 47 |
+
GFORMAT -down-> SCORE
|
| 48 |
+
SFORMAT -down-> SCORE
|
| 49 |
+
SCORE -right-> COMPARE
|
| 50 |
+
COMPARE -right-> ALIGN
|
| 51 |
+
@enduml
|
docs/training-phases.puml
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@startuml training_phases
|
| 2 |
+
!theme plain
|
| 3 |
+
left to right direction
|
| 4 |
+
skinparam backgroundColor #FEFEFE
|
| 5 |
+
skinparam defaultFontName Arial
|
| 6 |
+
skinparam defaultFontSize 14
|
| 7 |
+
skinparam ArrowColor #334155
|
| 8 |
+
skinparam RectangleBorderColor #64748B
|
| 9 |
+
skinparam RectangleFontColor #0F172A
|
| 10 |
+
skinparam roundcorner 10
|
| 11 |
+
skinparam linetype ortho
|
| 12 |
+
skinparam packageStyle rectangle
|
| 13 |
+
skinparam nodesep 42
|
| 14 |
+
skinparam ranksep 42
|
| 15 |
+
|
| 16 |
+
title AxiomForgeAI - Training Phases
|
| 17 |
+
|
| 18 |
+
rectangle "Phase 1\nGrounded Only" as Warmup #ECFDF5
|
| 19 |
+
rectangle "Phase 2\nSelf-Play Ramp" as Ramp #EEF2FF
|
| 20 |
+
rectangle "Phase 3\nMixed Training" as Improve #F1F5F9
|
| 21 |
+
rectangle "Fallback\nGrounded Recovery" as Fallback #EDE9FE
|
| 22 |
+
|
| 23 |
+
Warmup --> Ramp
|
| 24 |
+
Ramp --> Improve
|
| 25 |
+
Improve --> Fallback : if quality drops
|
| 26 |
+
Fallback --> Improve : recover
|
| 27 |
+
@enduml
|
images/axiomforgeai_scenes/scene_01.svg
ADDED
|
|
images/axiomforgeai_scenes/scene_02.svg
ADDED
|
|
images/axiomforgeai_scenes/scene_03.svg
ADDED
|
|
images/axiomforgeai_scenes/scene_04.svg
ADDED
|
|
images/axiomforgeai_scenes/scene_05.svg
ADDED
|
|
images/axiomforgeai_scenes/scene_06.svg
ADDED
|
|
images/axiomforgeai_scenes/scene_07.svg
ADDED
|
|
images/axiomforgeai_scenes/scene_08.svg
ADDED
|
|
images/axiomforgeai_scenes/scene_09.svg
ADDED
|
|
images/axiomforgeai_scenes/scene_10.svg
ADDED
|
|
images/blog_flow/architecture.svg
ADDED
|
|
images/blog_flow/grading.svg
ADDED
|
|
images/blog_flow/grpo-loop.svg
ADDED
|
|
images/blog_flow/task-sources.svg
ADDED
|
|
images/environment_overview.svg
ADDED
|
|
images/training_phases.svg
ADDED
|
|
logs/grpo/grpo_20260426_024029.log
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2026-04-26 02:40:33,617 INFO __main__ - ======================================================================
|
| 2 |
+
2026-04-26 02:40:33,617 INFO __main__ - GRPO run: grpo_20260426_024029
|
| 3 |
+
2026-04-26 02:40:33,617 INFO __main__ - Checkpoints : checkpoints/grpo/grpo_20260426_024029
|
| 4 |
+
2026-04-26 02:40:33,618 INFO __main__ - Logs : logs/grpo/grpo_20260426_024029
|
| 5 |
+
2026-04-26 02:40:33,618 INFO __main__ - Console log : logs/grpo/grpo_20260426_024029/console_output.log
|
| 6 |
+
2026-04-26 02:40:33,618 INFO __main__ - ======================================================================
|
| 7 |
+
2026-04-26 02:40:33,736 INFO src.utils.attn_backend - Attention backend selected: flash_attention_2
|
| 8 |
+
2026-04-26 02:40:33,736 INFO __main__ - Device: cuda:0 | attn: flash_attention_2
|
| 9 |
+
2026-04-26 02:40:33,753 INFO __main__ - GPU: NVIDIA A100 80GB PCIe | 85.1 GB VRAM | capability sm_80
|
| 10 |
+
2026-04-26 02:40:33,753 INFO __main__ - Run config: K=8 K_q=2 N=16 lr=5.0e-06 T=0.80 max_new=800 | clip_eps=0.20 kl_coef=0.0400 warmup=6 | diff_alpha=3.0 | self_play=70% grounded=30% | math_mix=30% math_maxdiff=3 | overlong_filter=True | eval_every=5 eval_N=100 | grad_clip=0.50 save_every=5 keep_last=3 | question_GRPO=ENABLED (K_q=2)
|
| 11 |
+
2026-04-26 02:40:33,753 INFO __main__ - Loading model from checkpoints/dual_task_v1 ...
|
| 12 |
+
2026-04-26 02:40:34,405 INFO __main__ - Tokenizer has no chat_template; loading from base model Qwen/Qwen2.5-Math-1.5B-Instruct
|
| 13 |
+
2026-04-26 02:40:34,731 INFO __main__ - Chat template loaded successfully.
|
| 14 |
+
2026-04-26 02:40:34,731 INFO __main__ - Detected PEFT adapter β loading base Qwen/Qwen2.5-Math-1.5B-Instruct then merging checkpoints/dual_task_v1
|
| 15 |
+
2026-04-26 02:40:36,242 WARNING __main__ - All parameters were frozen on load (PEFT merge_and_unload bug). Re-enabled requires_grad β any prior frozen runs were training nothing.
|
| 16 |
+
2026-04-26 02:40:36,242 INFO __main__ - Flash-Attn 2 active β gradient checkpointing OFF (Flash already gives O(T) attention memory).
|
| 17 |
+
2026-04-26 02:40:36,243 INFO __main__ - Trainable parameters: 1,543,714,304 / 1,543,714,304 (100.0%)
|
| 18 |
+
2026-04-26 02:40:36,244 INFO __main__ - Creating frozen reference policy (kl_coef=0.0400, ~3.1 GB VRAM)...
|
| 19 |
+
2026-04-26 02:40:36,305 INFO __main__ - Reference policy ready.
|
| 20 |
+
2026-04-26 02:40:36,306 INFO __main__ - LR schedule: 5.0e-06 warmup(6 iters) β cosine decay(24 iters, min=5.0e-07)
|
| 21 |
+
2026-04-26 02:40:36,415 INFO __main__ - Loaded 8792 QA pairs from data/sft/gsm8k_sft.jsonl
|
| 22 |
+
2026-04-26 02:40:36,424 INFO __main__ - Loaded 4072 MATH pairs from data/math/math_numeric.jsonl
|
| 23 |
+
2026-04-26 02:40:36,424 INFO __main__ - MATH mixing: 30% MATH (4072 problems) + 70% GSM8K (8792 problems)
|
| 24 |
+
2026-04-26 02:40:36,424 INFO src.rl.prm_scorer - Loading PRM Qwen/Qwen2.5-Math-PRM-7B (4-bit=True, dtype=torch.bfloat16) on cuda:0 β¦
|
| 25 |
+
|
| 26 |
+
Some weights of the model checkpoint at Qwen/Qwen2.5-Math-PRM-7B were not used when initializing Qwen2ForProcessRewardModel: ['lm_head.weight']
|
| 27 |
+
- This IS expected if you are initializing Qwen2ForProcessRewardModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
|
| 28 |
+
- This IS NOT expected if you are initializing Qwen2ForProcessRewardModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
|
| 29 |
+
2026-04-26 02:40:40,150 INFO src.rl.prm_scorer - PRM ready. GPU memory allocated: 9.97 GB step_sep_id=151651
|
| 30 |
+
2026-04-26 02:40:40,151 INFO __main__ - PRM loaded: Qwen/Qwen2.5-Math-PRM-7B (4-bit)
|
| 31 |
+
2026-04-26 02:40:40,154 INFO src.rl.unified_accuracy - Extraction cache not found at data/extraction_cache.json β will build on first use
|
| 32 |
+
2026-04-26 02:40:40,154 INFO __main__ - Unified accuracy calculator ready (extractor=Qwen/Qwen2.5-0.5B-Instruct, cache=data/extraction_cache.json)
|
| 33 |
+
2026-04-26 02:40:40,154 INFO __main__ - Warming up step-chain extractor (eager load)...
|
| 34 |
+
2026-04-26 02:40:40,154 INFO src.rl.unified_accuracy - Loading step chain extractor: Qwen/Qwen2.5-0.5B-Instruct
|
| 35 |
+
2026-04-26 02:40:41,033 INFO src.rl.unified_accuracy - Step chain extractor loaded
|
| 36 |
+
2026-04-26 02:40:41,034 INFO __main__ - Extractor warmup complete
|
| 37 |
+
2026-04-26 02:40:41,034 INFO src.rl.llm_question_classifier - LLMQuestionClassifier ready (model=Qwen2ForCausalLM, cache=10000, topics=24)
|
| 38 |
+
2026-04-26 02:40:42,571 INFO __main__ - Detected structured dataset (8792 records) β bootstrapping curriculum from skill_ids instead of keyword classifier.
|
| 39 |
+
2026-04-26 02:40:42,575 INFO src.rl.curriculum_manager - Curriculum bootstrapped from 8792 records across 1 topics
|
| 40 |
+
2026-04-26 02:40:42,575 INFO __main__ - ======================================================================
|
| 41 |
+
2026-04-26 02:40:42,575 INFO __main__ - INITIAL EVALUATION (Iteration 0)
|
| 42 |
+
2026-04-26 02:40:42,575 INFO __main__ - ======================================================================
|
| 43 |
+
|
| 44 |
+
|
logs/grpo/grpo_20260426_032827.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
logs/grpo/grpo_20260426_032827/config.json
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"base_model": "checkpoints/dual_task_v1",
|
| 3 |
+
"output_dir": "checkpoints/grpo",
|
| 4 |
+
"gsm8k_data": "data/sft/gsm8k_sft.jsonl",
|
| 5 |
+
"eval_data_path": "data/sft/gsm8k_test.jsonl",
|
| 6 |
+
"num_iterations": 60,
|
| 7 |
+
"group_size": 10,
|
| 8 |
+
"q_group_size": 2,
|
| 9 |
+
"questions_per_iter": 20,
|
| 10 |
+
"learning_rate": 5e-06,
|
| 11 |
+
"max_new_tokens": 1000,
|
| 12 |
+
"temperature": 0.8,
|
| 13 |
+
"eval_every": 5,
|
| 14 |
+
"eval_max_samples": 150,
|
| 15 |
+
"eval_max_new_tokens": 1000,
|
| 16 |
+
"eval_pass_at_k": 0,
|
| 17 |
+
"use_prm": true,
|
| 18 |
+
"prm_model": "Qwen/Qwen2.5-Math-PRM-7B",
|
| 19 |
+
"skip_initial_eval": false,
|
| 20 |
+
"run_name": "grpo_20260426_032827",
|
| 21 |
+
"max_grad_norm": 0.5,
|
| 22 |
+
"kl_coef": 0.06,
|
| 23 |
+
"math_data": null,
|
| 24 |
+
"math_mix_ratio": 0.3,
|
| 25 |
+
"math_mix_ratio_late": 0.5,
|
| 26 |
+
"math_ramp_start": 18,
|
| 27 |
+
"math_max_difficulty": 3,
|
| 28 |
+
"clip_eps": 0.2,
|
| 29 |
+
"warmup_iters": 8,
|
| 30 |
+
"min_lr_ratio": 0.1,
|
| 31 |
+
"difficulty_alpha": 3.5,
|
| 32 |
+
"overlong_filter": true,
|
| 33 |
+
"save_every": 5,
|
| 34 |
+
"keep_last": 4,
|
| 35 |
+
"self_play_ratio": 0.7,
|
| 36 |
+
"min_warmup": 12,
|
| 37 |
+
"selfplay_gt_thresh": 0.65,
|
| 38 |
+
"selfplay_grounded_thresh": 0.65,
|
| 39 |
+
"selfplay_step_thresh": 0.68,
|
| 40 |
+
"selfplay_ramp_iters": 28,
|
| 41 |
+
"grounded_floor": 0.55,
|
| 42 |
+
"extractor_model": "Qwen/Qwen2.5-0.5B-Instruct",
|
| 43 |
+
"extraction_cache": "data/extraction_cache.json"
|
| 44 |
+
}
|
logs/grpo/grpo_20260426_032827/console_output.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
logs/grpo/grpo_20260426_032827/metrics.csv
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
iteration,timestamp,loss,mean_reward,std_reward,batch_accuracy,grounded_acc,gt_match_rate,step_accuracy,lccp,n_groups,skipped_groups,n_sp_groups,sp_ratio,sp_suspended,training_phase,learning_rate,iter_time_s,q_reward,q_valid_rate,q_novelty,q_solvability,chain_prm_corr,chain_scoring_on,eval_combined,eval_correct_rt,eval_prm,eval_step_acc,eval_lccp,eval_format,eval_n_scored,eval_final_ans
|
| 2 |
+
1,2026-04-26T03:38:38,0.000610,0.914309,0.163605,0.960000,0.960000,0.780000,0.894861,0.814111,12,8,0,0.000000,0,GROUNDED_ONLY,0.000001,127.637996,0.000000,0.000000,0.000000,0.000000,0.000000,0,,,,,,,,
|
| 3 |
+
2,2026-04-26T03:41:58,-0.000034,0.847892,0.216018,0.914141,0.914141,0.651500,0.866692,0.765381,18,2,0,0.000000,0,GROUNDED_ONLY,0.000002,199.518393,0.000000,0.000000,0.000000,0.000000,0.000000,0,,,,,,,,
|
| 4 |
+
3,2026-04-26T03:45:08,0.000366,0.896391,0.170699,0.954545,0.954545,0.707100,0.876898,0.765238,12,8,0,0.000000,0,GROUNDED_ONLY,0.000002,189.836063,0.000000,0.000000,0.000000,0.000000,0.000000,0,,,,,,,,
|
| 5 |
+
4,2026-04-26T03:48:10,0.000942,0.865431,0.218756,0.893939,0.893939,0.732300,0.858504,0.764982,11,9,0,0.000000,0,GROUNDED_ONLY,0.000003,182.125475,0.000000,0.000000,0.000000,0.000000,0.000000,0,,,,,,,,
|
| 6 |
+
5,2026-04-26T03:59:39,0.000081,0.856875,0.239487,0.884422,0.884422,0.693500,0.918500,0.843100,16,4,0,0.000000,0,GROUNDED_ONLY,0.000003,201.679190,0.000000,0.000000,0.000000,0.000000,0.000000,0,0.919200,0.793300,0.903500,0.918500,0.843100,0.997700,150,0.793333
|
| 7 |
+
6,2026-04-26T04:02:52,-0.000063,0.879253,0.215318,0.909548,0.909548,0.748700,0.884646,0.805897,12,8,0,0.000000,0,GROUNDED_ONLY,0.000004,193.350312,0.000000,0.000000,0.000000,0.000000,0.000000,0,,,,,,,,
|
| 8 |
+
7,2026-04-26T04:06:20,0.001071,0.837888,0.223356,0.883249,0.883249,0.639600,0.813073,0.658069,14,6,0,0.000000,0,GROUNDED_ONLY,0.000004,208.223944,0.000000,0.000000,0.000000,0.000000,0.000000,0,,,,,,,,
|
| 9 |
+
8,2026-04-26T04:09:11,-0.000257,0.875536,0.200109,0.895000,0.895000,0.690000,0.864722,0.747928,13,7,0,0.000000,0,GROUNDED_ONLY,0.000005,170.595953,0.000000,0.000000,0.000000,0.000000,0.000000,0,,,,,,,,
|
| 10 |
+
9,2026-04-26T04:12:52,0.000060,0.906506,0.176914,0.964646,0.964646,0.803000,0.893573,0.817532,15,5,0,0.000000,0,GROUNDED_ONLY,0.000005,221.350669,0.000000,0.000000,0.000000,0.000000,0.000000,0,,,,,,,,
|
| 11 |
+
10,2026-04-26T04:24:49,0.000425,0.880765,0.175501,0.954774,0.954774,0.683400,0.920500,0.842600,14,6,0,0.000000,0,GROUNDED_ONLY,0.000005,188.981772,0.000000,0.000000,0.000000,0.000000,0.000000,0,0.919900,0.793300,0.906600,0.920500,0.842600,0.998000,150,0.793333
|
| 12 |
+
11,2026-04-26T04:27:11,-0.000557,0.969814,0.098322,0.985000,0.985000,0.930000,0.966268,0.921810,8,12,0,0.000000,0,GROUNDED_ONLY,0.000005,141.966778,0.000000,0.000000,0.000000,0.000000,0.000000,0,,,,,,,,
|
| 13 |
+
12,2026-04-26T04:30:09,0.000073,0.849274,0.212864,0.900000,0.900000,0.650000,0.820526,0.687272,14,6,0,0.000000,0,SELFPLAY_RAMP,0.000005,177.954757,0.000000,0.000000,0.000000,0.000000,0.000000,0,,,,,,,,
|
| 14 |
+
13,2026-04-26T04:39:26,0.000268,0.898824,0.185992,0.930000,0.930000,0.780000,0.870960,0.788730,14,6,0,0.000000,0,SELFPLAY_RAMP,0.000005,556.185637,0.000000,0.000000,0.000000,0.000000,-0.040000,0,,,,,,,,
|
| 15 |
+
14,2026-04-26T04:48:54,0.000496,0.855832,0.208499,0.952381,0.947368,0.673700,0.857607,0.747807,18,3,1,0.036000,0,SELFPLAY_RAMP,0.000005,568.400518,0.763000,1.000000,0.428900,1.000000,0.209000,0,,,,,,,,
|
| 16 |
+
15,2026-04-26T05:06:28,0.000023,0.927972,0.167187,0.937799,0.931217,0.836000,0.924200,0.842400,12,9,1,0.071000,0,SELFPLAY_RAMP,0.000005,550.143772,0.721800,1.000000,0.458000,1.000000,0.079000,0,0.926200,0.800000,0.907200,0.924200,0.842400,1.000000,150,0.800000
|
| 17 |
+
16,2026-04-26T05:16:04,0.000330,0.914605,0.172733,0.949772,0.938547,0.832400,0.895523,0.843899,15,7,2,0.107000,0,SELFPLAY_RAMP,0.000005,575.528946,0.787800,1.000000,0.447500,0.960000,0.089000,0,,,,,,,,
|
| 18 |
+
17,2026-04-26T05:26:20,-0.000137,0.888123,0.195006,0.938326,0.916168,0.700600,0.855796,0.768235,20,3,3,0.143000,0,SELFPLAY_RAMP,0.000005,616.018573,0.798200,1.000000,0.461600,1.000000,-0.191000,0,,,,,,,,
|
| 19 |
+
18,2026-04-26T05:35:30,0.000079,0.866401,0.178010,0.953975,0.943396,0.591200,0.830780,0.692011,19,5,4,0.179000,0,SELFPLAY_RAMP,0.000005,550.572628,0.739400,1.000000,0.452000,0.976200,0.021000,0,,,,,,,,
|
| 20 |
+
19,2026-04-26T05:44:13,0.000151,0.891281,0.172665,0.953586,0.949045,0.764300,0.851398,0.756874,16,8,4,0.214000,0,SELFPLAY_RAMP,0.000005,522.428960,0.733100,1.000000,0.456400,0.972500,0.075000,0,,,,,,,,
|
| 21 |
+
20,2026-04-26T06:02:54,0.000244,0.896291,0.177842,0.927711,0.906040,0.798700,0.925300,0.842800,18,7,5,0.250000,0,SELFPLAY_RAMP,0.000004,619.886349,0.770000,1.000000,0.474100,0.945000,-0.118000,0,0.923400,0.800000,0.905600,0.925300,0.842800,1.000000,150,0.800000
|
| 22 |
+
21,2026-04-26T06:11:04,0.000192,0.841732,0.187981,0.923077,0.914286,0.735700,0.819504,0.693061,21,5,6,0.286000,0,SELFPLAY_RAMP,0.000004,490.366938,0.697200,1.000000,0.449300,0.962500,0.209000,0,,,,,,,,
|
| 23 |
+
22,2026-04-26T06:21:16,0.000579,0.917519,0.124242,0.984314,0.985294,0.904400,0.964735,0.928489,20,6,6,0.321000,0,SELFPLAY_RAMP,0.000004,611.872286,0.699800,1.000000,0.457100,0.979000,0.145000,0,,,,,,,,
|
| 24 |
+
23,2026-04-26T06:28:41,0.000614,0.920698,0.147419,0.977011,0.950820,0.803300,0.907500,0.847631,18,9,7,0.357000,0,SELFPLAY_RAMP,0.000004,444.320885,0.726000,1.000000,0.441200,0.988500,0.143000,0,,,,,,,,
|
| 25 |
+
24,2026-04-26T06:36:32,-0.000213,0.879590,0.173313,0.935714,0.933333,0.791700,0.898819,0.812292,20,8,8,0.393000,0,SELFPLAY_RAMP,0.000004,471.698962,0.662100,1.000000,0.440800,0.968800,0.082000,0,,,,,,,,
|
| 26 |
+
25,2026-04-26T06:53:36,0.000344,0.844528,0.208658,0.927336,0.853211,0.605500,0.919800,0.846800,28,1,9,0.429000,0,SELFPLAY_RAMP,0.000004,524.655717,0.647100,1.000000,0.439400,0.967200,0.127000,0,0.922100,0.793300,0.903400,0.919800,0.846800,1.000000,150,0.793333
|
| 27 |
+
26,2026-04-26T07:02:06,0.000421,0.866649,0.179636,0.920415,0.926606,0.789000,0.889846,0.794302,26,3,9,0.464000,0,SELFPLAY_RAMP,0.000004,509.677450,0.679200,1.000000,0.448800,0.931700,0.065000,0,,,,,,,,
|
| 28 |
+
27,2026-04-26T07:12:03,-0.000227,0.877934,0.162866,0.956376,0.939394,0.686900,0.861628,0.740657,25,5,10,0.500000,0,SELFPLAY_RAMP,0.000004,597.521238,0.683100,1.000000,0.458400,0.975900,0.067000,0,,,,,,,,
|
| 29 |
+
28,2026-04-26T07:22:06,0.000042,0.869600,0.159154,0.941935,0.877778,0.655600,0.833443,0.618623,29,2,11,0.536000,0,SELFPLAY_RAMP,0.000004,603.099793,0.669300,1.000000,0.448900,0.983600,0.047000,0,,,,,,,,
|
| 30 |
+
29,2026-04-26T07:31:46,0.000377,0.867441,0.170826,0.947020,0.892857,0.726200,0.867407,0.760394,28,3,11,0.571000,0,SELFPLAY_RAMP,0.000003,579.690467,0.649600,1.000000,0.442500,0.973900,0.123000,0,,,,,,,,
|
| 31 |
+
30,2026-04-26T07:48:26,-0.000299,0.870581,0.160260,0.965517,0.950000,0.800000,0.923200,0.850000,27,5,12,0.607000,0,SELFPLAY_RAMP,0.000003,503.087982,0.676400,1.000000,0.456600,0.969900,0.099000,0,0.920400,0.793300,0.904400,0.923200,0.850000,1.000000,150,0.793333
|
logs/metrics.jsonl
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"iteration": 0, "accuracy": 0.9162, "combined_score": 0.9162, "step_accuracy": 0.9111, "lccp": 0.8392, "correct_rate": 0.7867, "prm_mean": 0.8988, "prm_final": 0.9275, "sympy_mean": 0.0, "format_mean": 1.0, "n_scored": 150, "total": 150, "final_answer_correct": 118, "final_answer_accuracy": 0.7866666666666666}
|
| 2 |
+
{"iteration": 1, "loss": 0.0006103356778718686, "mean_reward": 0.914308755129325, "std_reward": 0.1636050993381563, "batch_accuracy": 0.96, "grounded_accuracy": 0.96, "gt_match_rate": 0.78, "step_accuracy": 0.8948611111111111, "lccp": 0.8141111111111111, "n_groups": 12, "skipped_groups": 8, "learning_rate": 1.0625000000000002e-06, "iter_time_s": 127.63799649500288, "training_phase": "GROUNDED_ONLY", "effective_sp_ratio": 0.0, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.0, "extraction_success_rate": 0.0, "chain_scoring_active": 0, "n_self_play_groups": 0, "q_gen_attempts": 0, "q_gen_valid": 0, "q_gen_valid_rate": 0.0, "mean_question_reward": 0.0, "q_quality_rate": 0.0, "q_topic_match": 0.0, "q_difficulty_fit": 0.0, "q_clarity": 0.0, "q_novelty": 0.0, "q_solvability": 0.0}
|
| 3 |
+
{"iteration": 2, "loss": -3.432962815471304e-05, "mean_reward": 0.8478923191518654, "std_reward": 0.2160182166583165, "batch_accuracy": 0.9141414141414141, "grounded_accuracy": 0.9141414141414141, "gt_match_rate": 0.6515, "step_accuracy": 0.8666916416916417, "lccp": 0.7653809153809155, "n_groups": 18, "skipped_groups": 2, "learning_rate": 1.6250000000000001e-06, "iter_time_s": 199.5183933188673, "training_phase": "GROUNDED_ONLY", "effective_sp_ratio": 0.0, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.0, "extraction_success_rate": 0.0, "chain_scoring_active": 0, "n_self_play_groups": 0, "q_gen_attempts": 0, "q_gen_valid": 0, "q_gen_valid_rate": 0.0, "mean_question_reward": 0.0, "q_quality_rate": 0.0, "q_topic_match": 0.0, "q_difficulty_fit": 0.0, "q_clarity": 0.0, "q_novelty": 0.0, "q_solvability": 0.0}
|
| 4 |
+
{"iteration": 3, "loss": 0.0003658987698145211, "mean_reward": 0.8963912433066207, "std_reward": 0.17069859725714537, "batch_accuracy": 0.9545454545454546, "grounded_accuracy": 0.9545454545454546, "gt_match_rate": 0.7071, "step_accuracy": 0.876897947731281, "lccp": 0.765237694404361, "n_groups": 12, "skipped_groups": 8, "learning_rate": 2.1875000000000002e-06, "iter_time_s": 189.83606291818433, "training_phase": "GROUNDED_ONLY", "effective_sp_ratio": 0.0, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.0, "extraction_success_rate": 0.0, "chain_scoring_active": 0, "n_self_play_groups": 0, "q_gen_attempts": 0, "q_gen_valid": 0, "q_gen_valid_rate": 0.0, "mean_question_reward": 0.0, "q_quality_rate": 0.0, "q_topic_match": 0.0, "q_difficulty_fit": 0.0, "q_clarity": 0.0, "q_novelty": 0.0, "q_solvability": 0.0}
|
| 5 |
+
{"iteration": 4, "loss": 0.0009415318305731158, "mean_reward": 0.8654313890820613, "std_reward": 0.21875612713334075, "batch_accuracy": 0.8939393939393939, "grounded_accuracy": 0.8939393939393939, "gt_match_rate": 0.7323, "step_accuracy": 0.8585036876703543, "lccp": 0.7649821628988295, "n_groups": 11, "skipped_groups": 9, "learning_rate": 2.7500000000000004e-06, "iter_time_s": 182.12547484994866, "training_phase": "GROUNDED_ONLY", "effective_sp_ratio": 0.0, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.0, "extraction_success_rate": 0.0, "chain_scoring_active": 0, "n_self_play_groups": 0, "q_gen_attempts": 0, "q_gen_valid": 0, "q_gen_valid_rate": 0.0, "mean_question_reward": 0.0, "q_quality_rate": 0.0, "q_topic_match": 0.0, "q_difficulty_fit": 0.0, "q_clarity": 0.0, "q_novelty": 0.0, "q_solvability": 0.0}
|
| 6 |
+
{"iteration": 5, "loss": 8.118284122815567e-05, "mean_reward": 0.8568747993989829, "std_reward": 0.23948718740823036, "batch_accuracy": 0.8844221105527639, "grounded_accuracy": 0.8844221105527639, "gt_match_rate": 0.6935, "step_accuracy": 0.9185, "lccp": 0.8431, "n_groups": 16, "skipped_groups": 4, "learning_rate": 3.3125000000000005e-06, "iter_time_s": 201.67919013393112, "training_phase": "GROUNDED_ONLY", "effective_sp_ratio": 0.0, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.0, "extraction_success_rate": 0.0, "chain_scoring_active": 0, "n_self_play_groups": 0, "q_gen_attempts": 0, "q_gen_valid": 0, "q_gen_valid_rate": 0.0, "mean_question_reward": 0.0, "q_quality_rate": 0.0, "q_topic_match": 0.0, "q_difficulty_fit": 0.0, "q_clarity": 0.0, "q_novelty": 0.0, "q_solvability": 0.0, "accuracy": 0.9192, "combined_score": 0.9192, "correct_rate": 0.7933, "prm_mean": 0.9035, "prm_final": 0.9305, "sympy_mean": 0.0, "format_mean": 0.9977, "n_scored": 150, "total": 150, "final_answer_correct": 119, "final_answer_accuracy": 0.7933333333333333}
|
| 7 |
+
{"iteration": 6, "loss": -6.271734067316477e-05, "mean_reward": 0.8792530329566163, "std_reward": 0.21531797453446344, "batch_accuracy": 0.9095477386934674, "grounded_accuracy": 0.9095477386934674, "gt_match_rate": 0.7487, "step_accuracy": 0.8846455219822055, "lccp": 0.8058971263242619, "n_groups": 12, "skipped_groups": 8, "learning_rate": 3.875e-06, "iter_time_s": 193.35031225602143, "training_phase": "GROUNDED_ONLY", "effective_sp_ratio": 0.0, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.0, "extraction_success_rate": 0.0, "chain_scoring_active": 0, "n_self_play_groups": 0, "q_gen_attempts": 0, "q_gen_valid": 0, "q_gen_valid_rate": 0.0, "mean_question_reward": 0.0, "q_quality_rate": 0.0, "q_topic_match": 0.0, "q_difficulty_fit": 0.0, "q_clarity": 0.0, "q_novelty": 0.0, "q_solvability": 0.0}
|
| 8 |
+
{"iteration": 7, "loss": 0.0010708057920315436, "mean_reward": 0.8378877251545859, "std_reward": 0.2233563664223874, "batch_accuracy": 0.883248730964467, "grounded_accuracy": 0.883248730964467, "gt_match_rate": 0.6396, "step_accuracy": 0.8130725309659319, "lccp": 0.6580686304671076, "n_groups": 14, "skipped_groups": 6, "learning_rate": 4.4375e-06, "iter_time_s": 208.22394350194372, "training_phase": "GROUNDED_ONLY", "effective_sp_ratio": 0.0, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.0, "extraction_success_rate": 0.0, "chain_scoring_active": 0, "n_self_play_groups": 0, "q_gen_attempts": 0, "q_gen_valid": 0, "q_gen_valid_rate": 0.0, "mean_question_reward": 0.0, "q_quality_rate": 0.0, "q_topic_match": 0.0, "q_difficulty_fit": 0.0, "q_clarity": 0.0, "q_novelty": 0.0, "q_solvability": 0.0}
|
| 9 |
+
{"iteration": 8, "loss": -0.0002566667799678376, "mean_reward": 0.8755362041151912, "std_reward": 0.20010863742401203, "batch_accuracy": 0.895, "grounded_accuracy": 0.895, "gt_match_rate": 0.69, "step_accuracy": 0.8647215007215007, "lccp": 0.7479280303030303, "n_groups": 13, "skipped_groups": 7, "learning_rate": 5e-06, "iter_time_s": 170.59595341305248, "training_phase": "GROUNDED_ONLY", "effective_sp_ratio": 0.0, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.0, "extraction_success_rate": 0.0, "chain_scoring_active": 0, "n_self_play_groups": 0, "q_gen_attempts": 0, "q_gen_valid": 0, "q_gen_valid_rate": 0.0, "mean_question_reward": 0.0, "q_quality_rate": 0.0, "q_topic_match": 0.0, "q_difficulty_fit": 0.0, "q_clarity": 0.0, "q_novelty": 0.0, "q_solvability": 0.0}
|
| 10 |
+
{"iteration": 9, "loss": 5.9516330460004004e-05, "mean_reward": 0.906506146327221, "std_reward": 0.1769136401553803, "batch_accuracy": 0.9646464646464646, "grounded_accuracy": 0.9646464646464646, "gt_match_rate": 0.803, "step_accuracy": 0.8935726310726311, "lccp": 0.8175324675324676, "n_groups": 15, "skipped_groups": 5, "learning_rate": 4.995894997002465e-06, "iter_time_s": 221.35066892812029, "training_phase": "GROUNDED_ONLY", "effective_sp_ratio": 0.0, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.0, "extraction_success_rate": 0.0, "chain_scoring_active": 0, "n_self_play_groups": 0, "q_gen_attempts": 0, "q_gen_valid": 0, "q_gen_valid_rate": 0.0, "mean_question_reward": 0.0, "q_quality_rate": 0.0, "q_topic_match": 0.0, "q_difficulty_fit": 0.0, "q_clarity": 0.0, "q_novelty": 0.0, "q_solvability": 0.0}
|
| 11 |
+
{"iteration": 10, "loss": 0.0004252615440886335, "mean_reward": 0.8807654454859567, "std_reward": 0.17550108931309533, "batch_accuracy": 0.9547738693467337, "grounded_accuracy": 0.9547738693467337, "gt_match_rate": 0.6834, "step_accuracy": 0.9205, "lccp": 0.8426, "n_groups": 14, "skipped_groups": 6, "learning_rate": 4.983594966720622e-06, "iter_time_s": 188.98177218902856, "training_phase": "GROUNDED_ONLY", "effective_sp_ratio": 0.0, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.0, "extraction_success_rate": 0.0, "chain_scoring_active": 0, "n_self_play_groups": 0, "q_gen_attempts": 0, "q_gen_valid": 0, "q_gen_valid_rate": 0.0, "mean_question_reward": 0.0, "q_quality_rate": 0.0, "q_topic_match": 0.0, "q_difficulty_fit": 0.0, "q_clarity": 0.0, "q_novelty": 0.0, "q_solvability": 0.0, "accuracy": 0.9199, "combined_score": 0.9199, "correct_rate": 0.7933, "prm_mean": 0.9066, "prm_final": 0.9408, "sympy_mean": 0.0, "format_mean": 0.998, "n_scored": 150, "total": 150, "final_answer_correct": 119, "final_answer_accuracy": 0.7933333333333333}
|
| 12 |
+
{"iteration": 11, "loss": -0.0005566358695432427, "mean_reward": 0.9698135460130081, "std_reward": 0.0983216960471261, "batch_accuracy": 0.985, "grounded_accuracy": 0.985, "gt_match_rate": 0.93, "step_accuracy": 0.9662678571428571, "lccp": 0.9218095238095237, "n_groups": 8, "skipped_groups": 12, "learning_rate": 4.963144790631074e-06, "iter_time_s": 141.96677790791728, "training_phase": "GROUNDED_ONLY", "effective_sp_ratio": 0.0, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.0, "extraction_success_rate": 0.0, "chain_scoring_active": 0, "n_self_play_groups": 0, "q_gen_attempts": 0, "q_gen_valid": 0, "q_gen_valid_rate": 0.0, "mean_question_reward": 0.0, "q_quality_rate": 0.0, "q_topic_match": 0.0, "q_difficulty_fit": 0.0, "q_clarity": 0.0, "q_novelty": 0.0, "q_solvability": 0.0}
|
| 13 |
+
{"iteration": 12, "loss": 7.270745637859883e-05, "mean_reward": 0.8492740230597824, "std_reward": 0.2128636238290247, "batch_accuracy": 0.9, "grounded_accuracy": 0.9, "gt_match_rate": 0.65, "step_accuracy": 0.8205257936507937, "lccp": 0.6872718253968253, "n_groups": 14, "skipped_groups": 6, "learning_rate": 4.934619089208618e-06, "iter_time_s": 177.9547567779664, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.0, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.0, "extraction_success_rate": 0.0, "chain_scoring_active": 0, "n_self_play_groups": 0, "q_gen_attempts": 0, "q_gen_valid": 0, "q_gen_valid_rate": 0.0, "mean_question_reward": 0.0, "q_quality_rate": 0.0, "q_topic_match": 0.0, "q_difficulty_fit": 0.0, "q_clarity": 0.0, "q_novelty": 0.0, "q_solvability": 0.0}
|
| 14 |
+
{"iteration": 13, "loss": 0.00026773045517204864, "mean_reward": 0.8988236995312778, "std_reward": 0.18599151493605476, "batch_accuracy": 0.93, "grounded_accuracy": 0.93, "gt_match_rate": 0.78, "step_accuracy": 0.8709603174603174, "lccp": 0.7887301587301587, "n_groups": 14, "skipped_groups": 6, "learning_rate": 4.898121949644228e-06, "iter_time_s": 556.1856374200433, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.0, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": -0.04, "extraction_success_rate": 1.0, "chain_scoring_active": 0, "n_self_play_groups": 0, "q_gen_attempts": 0, "q_gen_valid": 0, "q_gen_valid_rate": 0.0, "mean_question_reward": 0.0, "q_quality_rate": 0.0, "q_topic_match": 0.0, "q_difficulty_fit": 0.0, "q_clarity": 0.0, "q_novelty": 0.0, "q_solvability": 0.0}
|
| 15 |
+
{"iteration": 14, "loss": 0.0004961729192069066, "mean_reward": 0.8558324048863098, "std_reward": 0.20849902292009304, "batch_accuracy": 0.9523809523809523, "grounded_accuracy": 0.9473684210526315, "gt_match_rate": 0.6737, "step_accuracy": 0.8576065162907268, "lccp": 0.7478070175438597, "n_groups": 18, "skipped_groups": 3, "learning_rate": 4.853786546042184e-06, "iter_time_s": 568.4005180909298, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.036, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.209, "extraction_success_rate": 1.0, "chain_scoring_active": 0, "n_self_play_groups": 1, "q_gen_attempts": 1, "q_gen_valid": 1, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.763, "q_quality_rate": 1.0, "q_topic_match": 0.575, "q_difficulty_fit": 0.89, "q_clarity": 1.0, "q_novelty": 0.4289, "q_solvability": 1.0}
|
| 16 |
+
{"iteration": 15, "loss": 2.3262581635208335e-05, "mean_reward": 0.927972135586315, "std_reward": 0.16718736928397065, "batch_accuracy": 0.937799043062201, "grounded_accuracy": 0.9312169312169312, "gt_match_rate": 0.836, "step_accuracy": 0.9242, "lccp": 0.8424, "n_groups": 12, "skipped_groups": 9, "learning_rate": 4.801774653482204e-06, "iter_time_s": 550.1437717408407, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.071, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.079, "extraction_success_rate": 0.98, "chain_scoring_active": 0, "n_self_play_groups": 1, "q_gen_attempts": 1, "q_gen_valid": 1, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.7218, "q_quality_rate": 1.0, "q_topic_match": 0.35, "q_difficulty_fit": 0.9511, "q_clarity": 1.0, "q_novelty": 0.458, "q_solvability": 1.0, "accuracy": 0.9262, "combined_score": 0.9262, "correct_rate": 0.8, "prm_mean": 0.9072, "prm_final": 0.9404, "sympy_mean": 0.0, "format_mean": 1.0, "n_scored": 150, "total": 150, "final_answer_correct": 120, "final_answer_accuracy": 0.8}
|
| 17 |
+
{"iteration": 16, "loss": 0.0003296181123005226, "mean_reward": 0.9146047620088099, "std_reward": 0.17273258044260062, "batch_accuracy": 0.9497716894977168, "grounded_accuracy": 0.9385474860335196, "gt_match_rate": 0.8324, "step_accuracy": 0.8955234709424654, "lccp": 0.8438994897095455, "n_groups": 15, "skipped_groups": 7, "learning_rate": 4.742276057719723e-06, "iter_time_s": 575.5289459908381, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.107, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.089, "extraction_success_rate": 0.94, "chain_scoring_active": 0, "n_self_play_groups": 2, "q_gen_attempts": 2, "q_gen_valid": 2, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.7878, "q_quality_rate": 1.0, "q_topic_match": 0.875, "q_difficulty_fit": 0.5838, "q_clarity": 1.0, "q_novelty": 0.4475, "q_solvability": 0.96}
|
| 18 |
+
{"iteration": 17, "loss": -0.00013719029248022708, "mean_reward": 0.8881227328092163, "std_reward": 0.1950058307020988, "batch_accuracy": 0.9383259911894273, "grounded_accuracy": 0.9161676646706587, "gt_match_rate": 0.7006, "step_accuracy": 0.8557955517536356, "lccp": 0.7682349586541203, "n_groups": 20, "skipped_groups": 3, "learning_rate": 4.675507862678258e-06, "iter_time_s": 616.0185732548125, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.143, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": -0.191, "extraction_success_rate": 1.0, "chain_scoring_active": 0, "n_self_play_groups": 3, "q_gen_attempts": 3, "q_gen_valid": 3, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.7982, "q_quality_rate": 1.0, "q_topic_match": 0.69, "q_difficulty_fit": 0.8892, "q_clarity": 1.0, "q_novelty": 0.4616, "q_solvability": 1.0}
|
| 19 |
+
{"iteration": 18, "loss": 7.917114673641903e-05, "mean_reward": 0.8664005137011263, "std_reward": 0.178010205898339, "batch_accuracy": 0.9539748953974896, "grounded_accuracy": 0.9433962264150944, "gt_match_rate": 0.5912, "step_accuracy": 0.830780173704702, "lccp": 0.6920110811620246, "n_groups": 19, "skipped_groups": 5, "learning_rate": 4.601713698260728e-06, "iter_time_s": 550.572628196096, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.179, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.021, "extraction_success_rate": 0.98, "chain_scoring_active": 0, "n_self_play_groups": 4, "q_gen_attempts": 4, "q_gen_valid": 4, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.7394, "q_quality_rate": 1.0, "q_topic_match": 0.6375, "q_difficulty_fit": 0.6293, "q_clarity": 1.0, "q_novelty": 0.452, "q_solvability": 0.9762}
|
| 20 |
+
{"iteration": 19, "loss": 0.00015087392284840462, "mean_reward": 0.8912812767256229, "std_reward": 0.1726645221785555, "batch_accuracy": 0.9535864978902954, "grounded_accuracy": 0.9490445859872612, "gt_match_rate": 0.7643, "step_accuracy": 0.8513975055376328, "lccp": 0.7568744772566428, "n_groups": 16, "skipped_groups": 8, "learning_rate": 4.521162831370364e-06, "iter_time_s": 522.4289600129705, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.214, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.075, "extraction_success_rate": 0.98, "chain_scoring_active": 0, "n_self_play_groups": 4, "q_gen_attempts": 4, "q_gen_valid": 4, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.7331, "q_quality_rate": 1.0, "q_topic_match": 0.4813, "q_difficulty_fit": 0.8466, "q_clarity": 1.0, "q_novelty": 0.4564, "q_solvability": 0.9725}
|
| 21 |
+
{"iteration": 20, "loss": 0.00024373266084391312, "mean_reward": 0.8962914079724992, "std_reward": 0.1778417367801085, "batch_accuracy": 0.927710843373494, "grounded_accuracy": 0.9060402684563759, "gt_match_rate": 0.7987, "step_accuracy": 0.9253, "lccp": 0.8428, "n_groups": 18, "skipped_groups": 7, "learning_rate": 4.434149183384978e-06, "iter_time_s": 619.8863487117924, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.25, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": -0.118, "extraction_success_rate": 0.96, "chain_scoring_active": 0, "n_self_play_groups": 5, "q_gen_attempts": 5, "q_gen_valid": 5, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.77, "q_quality_rate": 1.0, "q_topic_match": 0.723, "q_difficulty_fit": 0.703, "q_clarity": 1.0, "q_novelty": 0.4741, "q_solvability": 0.945, "accuracy": 0.9234, "combined_score": 0.9234, "correct_rate": 0.8, "prm_mean": 0.9056, "prm_final": 0.9353, "sympy_mean": 0.0, "format_mean": 1.0, "n_scored": 150, "total": 150, "final_answer_correct": 120, "final_answer_accuracy": 0.8}
|
| 22 |
+
{"iteration": 21, "loss": 0.0001916794737033862, "mean_reward": 0.8417323480901788, "std_reward": 0.1879809468583581, "batch_accuracy": 0.9230769230769231, "grounded_accuracy": 0.9142857142857143, "gt_match_rate": 0.7357, "step_accuracy": 0.8195039682539682, "lccp": 0.6930612244897959, "n_groups": 21, "skipped_groups": 5, "learning_rate": 4.340990257669732e-06, "iter_time_s": 490.36693838005885, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.286, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.209, "extraction_success_rate": 0.92, "chain_scoring_active": 0, "n_self_play_groups": 6, "q_gen_attempts": 6, "q_gen_valid": 6, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.6972, "q_quality_rate": 1.0, "q_topic_match": 0.5742, "q_difficulty_fit": 0.4754, "q_clarity": 1.0, "q_novelty": 0.4493, "q_solvability": 0.9625}
|
| 23 |
+
{"iteration": 22, "loss": 0.000578732604299148, "mean_reward": 0.9175190043251262, "std_reward": 0.12424225720214971, "batch_accuracy": 0.984313725490196, "grounded_accuracy": 0.9852941176470589, "gt_match_rate": 0.9044, "step_accuracy": 0.9647345301757068, "lccp": 0.9284886681945506, "n_groups": 20, "skipped_groups": 6, "learning_rate": 4.2420259810417895e-06, "iter_time_s": 611.8722857821267, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.321, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.145, "extraction_success_rate": 0.92, "chain_scoring_active": 0, "n_self_play_groups": 6, "q_gen_attempts": 6, "q_gen_valid": 6, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.6998, "q_quality_rate": 1.0, "q_topic_match": 0.6189, "q_difficulty_fit": 0.3856, "q_clarity": 1.0, "q_novelty": 0.4571, "q_solvability": 0.979}
|
| 24 |
+
{"iteration": 23, "loss": 0.0006137362383419208, "mean_reward": 0.9206978778568132, "std_reward": 0.14741914089456262, "batch_accuracy": 0.9770114942528736, "grounded_accuracy": 0.9508196721311475, "gt_match_rate": 0.8033, "step_accuracy": 0.9075003548364204, "lccp": 0.847631466893762, "n_groups": 18, "skipped_groups": 9, "learning_rate": 4.137617463414222e-06, "iter_time_s": 444.32088500098325, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.357, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.143, "extraction_success_rate": 1.0, "chain_scoring_active": 0, "n_self_play_groups": 7, "q_gen_attempts": 7, "q_gen_valid": 7, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.726, "q_quality_rate": 1.0, "q_topic_match": 0.5621, "q_difficulty_fit": 0.6634, "q_clarity": 1.0, "q_novelty": 0.4412, "q_solvability": 0.9885}
|
| 25 |
+
{"iteration": 24, "loss": -0.00021296025724950595, "mean_reward": 0.8795895609748888, "std_reward": 0.1733128827089799, "batch_accuracy": 0.9357142857142857, "grounded_accuracy": 0.9333333333333333, "gt_match_rate": 0.7917, "step_accuracy": 0.8988194444444446, "lccp": 0.8122916666666666, "n_groups": 20, "skipped_groups": 8, "learning_rate": 4.0281456801451e-06, "iter_time_s": 471.6989622868132, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.393, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.082, "extraction_success_rate": 0.98, "chain_scoring_active": 0, "n_self_play_groups": 8, "q_gen_attempts": 8, "q_gen_valid": 8, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.6621, "q_quality_rate": 1.0, "q_topic_match": 0.5344, "q_difficulty_fit": 0.3108, "q_clarity": 1.0, "q_novelty": 0.4408, "q_solvability": 0.9688}
|
| 26 |
+
{"iteration": 25, "loss": 0.0003441530472758002, "mean_reward": 0.8445275205076134, "std_reward": 0.20865777545087066, "batch_accuracy": 0.9273356401384083, "grounded_accuracy": 0.8532110091743119, "gt_match_rate": 0.6055, "step_accuracy": 0.9198, "lccp": 0.8468, "n_groups": 28, "skipped_groups": 1, "learning_rate": 3.9140100818997275e-06, "iter_time_s": 524.655717118876, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.429, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.127, "extraction_success_rate": 0.94, "chain_scoring_active": 0, "n_self_play_groups": 9, "q_gen_attempts": 9, "q_gen_valid": 9, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.6471, "q_quality_rate": 1.0, "q_topic_match": 0.505, "q_difficulty_fit": 0.2634, "q_clarity": 1.0, "q_novelty": 0.4394, "q_solvability": 0.9672, "accuracy": 0.9221, "combined_score": 0.9221, "correct_rate": 0.7933, "prm_mean": 0.9034, "prm_final": 0.9329, "sympy_mean": 0.0, "format_mean": 1.0, "n_scored": 150, "total": 150, "final_answer_correct": 119, "final_answer_accuracy": 0.7933333333333333}
|
| 27 |
+
{"iteration": 26, "loss": 0.0004209962865808428, "mean_reward": 0.8666489827432893, "std_reward": 0.1796360842988206, "batch_accuracy": 0.9204152249134948, "grounded_accuracy": 0.926605504587156, "gt_match_rate": 0.789, "step_accuracy": 0.8898463666812292, "lccp": 0.7943024610455803, "n_groups": 26, "skipped_groups": 3, "learning_rate": 3.795627137098479e-06, "iter_time_s": 509.6774504878558, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.464, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.065, "extraction_success_rate": 0.94, "chain_scoring_active": 0, "n_self_play_groups": 9, "q_gen_attempts": 9, "q_gen_valid": 9, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.6792, "q_quality_rate": 1.0, "q_topic_match": 0.6639, "q_difficulty_fit": 0.2476, "q_clarity": 1.0, "q_novelty": 0.4488, "q_solvability": 0.9317}
|
| 28 |
+
{"iteration": 27, "loss": -0.00022697661013808103, "mean_reward": 0.877933982604161, "std_reward": 0.1628662024521015, "batch_accuracy": 0.9563758389261745, "grounded_accuracy": 0.9393939393939394, "gt_match_rate": 0.6869, "step_accuracy": 0.8616281866281865, "lccp": 0.7406565656565657, "n_groups": 25, "skipped_groups": 5, "learning_rate": 3.673428812268702e-06, "iter_time_s": 597.5212381640449, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.5, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.067, "extraction_success_rate": 0.92, "chain_scoring_active": 0, "n_self_play_groups": 10, "q_gen_attempts": 10, "q_gen_valid": 10, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.6831, "q_quality_rate": 1.0, "q_topic_match": 0.5699, "q_difficulty_fit": 0.3583, "q_clarity": 1.0, "q_novelty": 0.4584, "q_solvability": 0.9759}
|
| 29 |
+
{"iteration": 28, "loss": 4.199455770111822e-05, "mean_reward": 0.8695997487614422, "std_reward": 0.15915376074701193, "batch_accuracy": 0.9419354838709677, "grounded_accuracy": 0.8777777777777778, "gt_match_rate": 0.6556, "step_accuracy": 0.8334434828062279, "lccp": 0.6186230200445887, "n_groups": 29, "skipped_groups": 2, "learning_rate": 3.5478609958457035e-06, "iter_time_s": 603.0997926741838, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.536, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.047, "extraction_success_rate": 0.8, "chain_scoring_active": 0, "n_self_play_groups": 11, "q_gen_attempts": 11, "q_gen_valid": 11, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.6693, "q_quality_rate": 1.0, "q_topic_match": 0.5931, "q_difficulty_fit": 0.23, "q_clarity": 1.0, "q_novelty": 0.4489, "q_solvability": 0.9836}
|
| 30 |
+
{"iteration": 29, "loss": 0.0003765096731578004, "mean_reward": 0.8674408392873937, "std_reward": 0.17082623284979875, "batch_accuracy": 0.9470198675496688, "grounded_accuracy": 0.8928571428571429, "gt_match_rate": 0.7262, "step_accuracy": 0.8674065194639727, "lccp": 0.7603936306964257, "n_groups": 28, "skipped_groups": 3, "learning_rate": 3.419381871174205e-06, "iter_time_s": 579.6904674370307, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.571, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.123, "extraction_success_rate": 0.84, "chain_scoring_active": 0, "n_self_play_groups": 11, "q_gen_attempts": 11, "q_gen_valid": 11, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.6496, "q_quality_rate": 1.0, "q_topic_match": 0.5636, "q_difficulty_fit": 0.1695, "q_clarity": 1.0, "q_novelty": 0.4425, "q_solvability": 0.9739}
|
| 31 |
+
{"iteration": 30, "loss": -0.00029927124827130075, "mean_reward": 0.8705812118012987, "std_reward": 0.16025951815561293, "batch_accuracy": 0.9655172413793104, "grounded_accuracy": 0.95, "gt_match_rate": 0.8, "step_accuracy": 0.9232, "lccp": 0.85, "n_groups": 27, "skipped_groups": 5, "learning_rate": 3.2884602446470037e-06, "iter_time_s": 503.08798154001124, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.607, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.099, "extraction_success_rate": 0.92, "chain_scoring_active": 0, "n_self_play_groups": 12, "q_gen_attempts": 12, "q_gen_valid": 12, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.6764, "q_quality_rate": 1.0, "q_topic_match": 0.6752, "q_difficulty_fit": 0.1485, "q_clarity": 1.0, "q_novelty": 0.4566, "q_solvability": 0.9699, "accuracy": 0.9204, "combined_score": 0.9204, "correct_rate": 0.7933, "prm_mean": 0.9044, "prm_final": 0.9289, "sympy_mean": 0.0, "format_mean": 1.0, "n_scored": 150, "total": 150, "final_answer_correct": 119, "final_answer_accuracy": 0.7933333333333333}
|
models.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Data models for the AxiomForgeAI math RL environment.
|
| 9 |
+
|
| 10 |
+
The AxiomForgeAI environment presents math questions drawn from an adaptive
|
| 11 |
+
curriculum; external agents submit step-by-step solutions and receive scored
|
| 12 |
+
observations. The environment integrates with the GRPO training pipeline
|
| 13 |
+
defined in scripts/run_grpo_training.py.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from openenv.core.env_server.types import Action, Observation
|
| 17 |
+
from pydantic import Field
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class AxiomforgeaiAction(Action):
    """Single-step action for the AxiomForgeAI math environment.

    The agent answers the currently posed question by submitting one
    written, step-by-step solution. The expected layout of ``solution`` is::

        Step 1: <reasoning>
        Step 2: <reasoning>
        ...
        Final Answer: <numeric value>
    """

    # Free-form text; the environment parses the "Step N:" lines and the
    # trailing "Final Answer:" line when scoring.
    solution: str = Field(
        description=(
            "Step-by-step solution to the current math question. "
            "Use 'Step N: ...' lines and end with 'Final Answer: <value>'."
        ),
        default="",
    )
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class AxiomforgeaiObservation(Observation):
    """What the agent sees from the AxiomForgeAI math environment.

    Immediately after ``reset`` only the question fields are filled in and
    ``feedback`` is empty. After a ``step`` the reward and ``feedback``
    describe how the submitted solution was judged, and ``done=True`` marks
    the end of the one-step episode.
    """

    # The problem statement the agent must answer.
    question: str = Field(
        description="Math question the agent must solve.",
        default="",
    )
    # Curriculum category the question was drawn from.
    topic: str = Field(
        description="Mathematical topic of the question (e.g. 'algebra', 'geometry').",
        default="",
    )
    # Scalar difficulty estimate; 0.5 is the neutral default.
    difficulty: float = Field(
        description="Estimated difficulty of the question in [0, 1].",
        default=0.5,
    )
    # Grader commentary; empty until a solution has been scored.
    feedback: str = Field(
        description=(
            "Human-readable feedback on the submitted solution "
            "(empty on reset, populated after step)."
        ),
        default="",
    )
|
openenv.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# OpenEnv Space manifest for the AxiomForgeAI environment.
spec_version: 1
name: AxiomForgeAI
type: space
# Served as a FastAPI/ASGI application.
runtime: fastapi
# Import path of the ASGI app object: module "server.app", attribute "app".
app: server.app:app
port: 8000
|
| 7 |
+
|
pyproject.toml
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

[build-system]
requires = ["setuptools>=45", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "openenv-AxiomForgeAI"
version = "0.1.0"
description = "Axiomforgeai environment for OpenEnv"
requires-python = ">=3.10"
dependencies = [
    # Core OpenEnv runtime (provides FastAPI server + HTTP client types)
    # install from github
    # "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
    "openenv-core[core]>=0.2.2",
    # Environment-specific dependencies
    # Add all dependencies needed for your environment here
    # Examples:
    # "numpy>=1.19.0",
    # "torch>=2.0.0",
    # "gymnasium>=0.29.0",
    # "openspiel>=1.0.0",
    # "smolagents>=1.22.0,<2",
]

[project.optional-dependencies]
dev = [
    "pytest>=8.0.0",
    "pytest-cov>=4.0.0",
]

[project.scripts]
# Server entry point - enables running via: uv run --project . server
# or: python -m AxiomForgeAI.server.app
server = "AxiomForgeAI.server.app:main"

[tool.setuptools]
include-package-data = true
# NOTE(review): this package list mixes the remapped "AxiomForgeAI" packages
# with plain "src.*" and "scripts" packages rooted at the repo top level —
# confirm that setuptools resolves both layouts as intended when building.
packages = [
    "AxiomForgeAI",
    "AxiomForgeAI.server",
    "src",
    "src.config",
    "src.rl",
    "src.sft",
    "src.utils",
    "src.self_play",
    "scripts",
]
# Maps the repo root to the "AxiomForgeAI" package and ./server to its
# "AxiomForgeAI.server" subpackage.
package-dir = { "AxiomForgeAI" = ".", "AxiomForgeAI.server" = "server" }
|
requirements.txt
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate==1.2.1
|
| 2 |
+
aiohappyeyeballs==2.6.1
|
| 3 |
+
aiohttp==3.13.5
|
| 4 |
+
aiohttp-cors==0.8.1
|
| 5 |
+
aiosignal==1.4.0
|
| 6 |
+
airportsdata==20260315
|
| 7 |
+
annotated-doc==0.0.4
|
| 8 |
+
annotated-types==0.7.0
|
| 9 |
+
anyio==4.13.0
|
| 10 |
+
astor==0.8.1
|
| 11 |
+
attrs==26.1.0
|
| 12 |
+
bitsandbytes==0.44.1
|
| 13 |
+
blake3==1.0.8
|
| 14 |
+
certifi==2026.4.22
|
| 15 |
+
cffi==2.0.0
|
| 16 |
+
charset-normalizer==3.4.7
|
| 17 |
+
click==8.3.2
|
| 18 |
+
cloudpickle==3.1.2
|
| 19 |
+
colorful==0.5.8
|
| 20 |
+
compressed-tensors==0.9.0
|
| 21 |
+
cryptography==46.0.7
|
| 22 |
+
datasets==3.2.0
|
| 23 |
+
depyf==0.18.0
|
| 24 |
+
dill==0.3.8
|
| 25 |
+
diskcache==5.6.3
|
| 26 |
+
distlib==0.4.0
|
| 27 |
+
distro==1.9.0
|
| 28 |
+
einops==0.8.2
|
| 29 |
+
fastapi==0.136.0
|
| 30 |
+
filelock==3.29.0
|
| 31 |
+
frozenlist==1.8.0
|
| 32 |
+
fsspec==2024.9.0
|
| 33 |
+
gguf==0.10.0
|
| 34 |
+
google-api-core==2.30.3
|
| 35 |
+
google-auth==2.49.2
|
| 36 |
+
googleapis-common-protos==1.74.0
|
| 37 |
+
grpcio==1.80.0
|
| 38 |
+
h11==0.16.0
|
| 39 |
+
hf-xet==1.4.3
|
| 40 |
+
hjson==3.1.0
|
| 41 |
+
httpcore==1.0.9
|
| 42 |
+
httptools==0.7.1
|
| 43 |
+
httpx==0.28.1
|
| 44 |
+
huggingface-hub==0.36.2
|
| 45 |
+
idna==3.12
|
| 46 |
+
importlib-metadata==9.0.0
|
| 47 |
+
interegular==0.3.3
|
| 48 |
+
jinja2==3.1.6
|
| 49 |
+
jiter==0.14.0
|
| 50 |
+
jsonschema==4.26.0
|
| 51 |
+
jsonschema-specifications==2025.9.1
|
| 52 |
+
lark==1.2.2
|
| 53 |
+
linkify-it-py==2.1.0
|
| 54 |
+
lm-format-enforcer==0.10.12
|
| 55 |
+
markdown-it-py==4.0.0
|
| 56 |
+
markupsafe==3.0.3
|
| 57 |
+
mdit-py-plugins==0.5.0
|
| 58 |
+
mdurl==0.1.2
|
| 59 |
+
memray==1.19.3
|
| 60 |
+
mistral-common==1.11.0
|
| 61 |
+
mpmath==1.3.0
|
| 62 |
+
msgpack==1.1.2
|
| 63 |
+
msgspec==0.21.1
|
| 64 |
+
multidict==6.7.1
|
| 65 |
+
multiprocess==0.70.16
|
| 66 |
+
nest-asyncio==1.6.0
|
| 67 |
+
networkx==3.6.1
|
| 68 |
+
ninja==1.13.0
|
| 69 |
+
numpy==1.26.4
|
| 70 |
+
nvidia-cublas-cu12==12.4.5.8
|
| 71 |
+
nvidia-cuda-cupti-cu12==12.4.127
|
| 72 |
+
nvidia-cuda-nvrtc-cu12==12.4.127
|
| 73 |
+
nvidia-cuda-runtime-cu12==12.4.127
|
| 74 |
+
nvidia-cudnn-cu12==9.1.0.70
|
| 75 |
+
nvidia-cufft-cu12==11.2.1.3
|
| 76 |
+
nvidia-curand-cu12==10.3.5.147
|
| 77 |
+
nvidia-cusolver-cu12==11.6.1.9
|
| 78 |
+
nvidia-cusparse-cu12==12.3.1.170
|
| 79 |
+
nvidia-ml-py==13.595.45
|
| 80 |
+
nvidia-nccl-cu12==2.21.5
|
| 81 |
+
nvidia-nvjitlink-cu12==12.4.127
|
| 82 |
+
nvidia-nvtx-cu12==12.4.127
|
| 83 |
+
openai==2.32.0
|
| 84 |
+
opencensus==0.11.4
|
| 85 |
+
opencensus-context==0.1.3
|
| 86 |
+
opencv-python-headless==4.11.0.86
|
| 87 |
+
outlines==0.1.11
|
| 88 |
+
outlines-core==0.1.26
|
| 89 |
+
packaging==26.1
|
| 90 |
+
pandas==3.0.2
|
| 91 |
+
partial-json-parser==0.2.1.1.post7
|
| 92 |
+
peft==0.19.1
|
| 93 |
+
pillow==12.2.0
|
| 94 |
+
platformdirs==4.9.6
|
| 95 |
+
prometheus-client==0.25.0
|
| 96 |
+
prometheus-fastapi-instrumentator==7.1.0
|
| 97 |
+
propcache==0.4.1
|
| 98 |
+
proto-plus==1.27.2
|
| 99 |
+
protobuf==7.34.1
|
| 100 |
+
psutil==7.2.2
|
| 101 |
+
py-cpuinfo==9.0.0
|
| 102 |
+
py-spy==0.4.1
|
| 103 |
+
pyarrow==24.0.0
|
| 104 |
+
pyasn1==0.6.3
|
| 105 |
+
pyasn1-modules==0.4.2
|
| 106 |
+
pycountry==26.2.16
|
| 107 |
+
pycparser==3.0
|
| 108 |
+
pydantic==2.13.3
|
| 109 |
+
pydantic-core==2.46.3
|
| 110 |
+
pydantic-extra-types==2.11.1
|
| 111 |
+
pygments==2.20.0
|
| 112 |
+
python-dateutil==2.9.0.post0
|
| 113 |
+
python-discovery==1.2.2
|
| 114 |
+
python-dotenv==1.2.2
|
| 115 |
+
pyyaml==6.0.3
|
| 116 |
+
pyzmq==27.1.0
|
| 117 |
+
ray==2.39.0
|
| 118 |
+
referencing==0.37.0
|
| 119 |
+
regex==2026.4.4
|
| 120 |
+
requests==2.33.1
|
| 121 |
+
rich==15.0.0
|
| 122 |
+
rpds-py==0.30.0
|
| 123 |
+
safetensors==0.7.0
|
| 124 |
+
scipy>=1.14.0
|
| 125 |
+
sentencepiece==0.2.1
|
| 126 |
+
setuptools==82.0.1
|
| 127 |
+
six==1.17.0
|
| 128 |
+
smart-open==7.6.0
|
| 129 |
+
sniffio==1.3.1
|
| 130 |
+
starlette==0.52.1
|
| 131 |
+
sympy==1.13.1
|
| 132 |
+
textual==8.2.4
|
| 133 |
+
tiktoken==0.12.0
|
| 134 |
+
tokenizers==0.20.3
|
| 135 |
+
torch==2.5.1
|
| 136 |
+
torchaudio==2.5.1
|
| 137 |
+
torchvision==0.20.1
|
| 138 |
+
tqdm==4.67.3
|
| 139 |
+
transformers==4.46.3
|
| 140 |
+
triton==3.1.0
|
| 141 |
+
trl==0.12.1
|
| 142 |
+
typing-extensions==4.15.0
|
| 143 |
+
typing-inspection==0.4.2
|
| 144 |
+
uc-micro-py==2.0.0
|
| 145 |
+
urllib3==2.6.3
|
| 146 |
+
uvicorn==0.45.0
|
| 147 |
+
uvloop==0.22.1
|
| 148 |
+
virtualenv==21.2.4
|
| 149 |
+
vllm==0.7.0
|
| 150 |
+
watchfiles==1.1.1
|
| 151 |
+
websockets==16.0
|
| 152 |
+
wrapt==2.1.2
|
| 153 |
+
xformers==0.0.28.post3
|
| 154 |
+
xgrammar==0.1.33
|
| 155 |
+
xxhash==3.6.0
|
| 156 |
+
yarl==1.23.0
|
| 157 |
+
zipp==3.23.1
|
| 158 |
+
matplotlib==3.10.9
|
| 159 |
+
flash-attn==2.8.3
|
| 160 |
+
gradio>=4.44.0
|
scripts/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Training and evaluation scripts for math reasoning models."""
|
scripts/convert_gsm8k_to_sft.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Convert OpenAI GSM8K to SFT JSONL aligned with MathAgent solver format:
|
| 4 |
+
|
| 5 |
+
Step 1: ...
|
| 6 |
+
Step 2: ...
|
| 7 |
+
...
|
| 8 |
+
Final Answer: <integer>
|
| 9 |
+
|
| 10 |
+
Each record uses a chat messages list for Qwen-style fine-tuning.
|
| 11 |
+
|
| 12 |
+
Usage
|
| 13 |
+
-----
|
| 14 |
+
# From Hugging Face (default; same data as in test.ipynb)
|
| 15 |
+
python scripts/convert_gsm8k_to_sft.py \\
|
| 16 |
+
--output data/sft/gsm8k_sft.jsonl \\
|
| 17 |
+
--splits train test
|
| 18 |
+
|
| 19 |
+
# From a saved JSONL with columns \"question\" and \"answer\" (GSM8K schema)
|
| 20 |
+
python scripts/convert_gsm8k_to_sft.py \\
|
| 21 |
+
--source jsonl \\
|
| 22 |
+
--input path/to/file.jsonl \\
|
| 23 |
+
--output data/sft/gsm8k_sft.jsonl
|
| 24 |
+
|
| 25 |
+
Requires: pip install datasets (and datasets will pull pyarrow as needed)
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
from __future__ import annotations
|
| 29 |
+
|
| 30 |
+
import argparse
|
| 31 |
+
import json
|
| 32 |
+
import re
|
| 33 |
+
from pathlib import Path
|
| 34 |
+
from typing import Any, Iterator
|
| 35 |
+
|
| 36 |
+
# Keep in sync with src.agent.math_agent.SOLVER_SYSTEM_PROMPT
|
| 37 |
+
SOLVER_SYSTEM_PROMPT = (
|
| 38 |
+
"You are a step-by-step math solver. "
|
| 39 |
+
"Solve the given problem one step at a time. "
|
| 40 |
+
"Each step must be on its own line, starting with 'Step N:'. "
|
| 41 |
+
"End with a line starting with 'Final Answer:'. "
|
| 42 |
+
"Write every mathematical expression in Python/SymPy syntax "
|
| 43 |
+
"so it can be verified programmatically."
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
USER_WRAPPER = (
|
| 47 |
+
"Solve the following problem. Show your reasoning as numbered steps, "
|
| 48 |
+
"then give the final numeric answer on the last line.\n\nProblem:\n{question}"
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def parse_gsm8k_answer(raw_answer: str) -> tuple[str, str]:
    """Split a GSM8K ``answer`` field into (reasoning, final-integer-string).

    GSM8K solutions end with a marker line such as ``#### 42``: everything
    before the marker is the worked reasoning, everything after is the final
    answer. The final answer is normalized by stripping commas/whitespace and
    extracting the first (optionally negative) integer.
    """
    stripped = raw_answer.strip()
    pieces = re.split(r"\s*####\s*", stripped, maxsplit=1)
    head = pieces[0].strip()
    if len(pieces) > 1:
        tail = pieces[1].strip()
    else:
        tail = ""
    # Normalize: "1, 234" -> "1234" before pulling out the integer token.
    tail = re.sub(r"[,\s]+", "", tail)
    match = re.search(r"-?\d+", tail)
    final_clean = match.group(0) if match else tail
    return head, final_clean
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def reasoning_to_step_lines(reasoning: str) -> list[str]:
    """Turn reasoning into non-empty lines; each line becomes one Step N:."""
    lines = [text for text in (raw.strip() for raw in reasoning.splitlines()) if text]
    if not lines:
        # NOTE(review): this branch only runs when every line was blank, in
        # which case ``blob`` is empty too, so the sentence split below is
        # effectively unreachable. Kept unchanged to preserve behavior.
        blob = reasoning.strip()
        if blob:
            sentences = re.split(r"(?<=[.!?])\s+", blob)
            lines = [s.strip() for s in sentences if s.strip()]
    return lines
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def build_assistant_content(reasoning: str, final_answer: str) -> str:
    """Render reasoning as 'Step N:' lines followed by a 'Final Answer:' line."""
    steps: list[str] = []
    for idx, line in enumerate(reasoning_to_step_lines(reasoning), start=1):
        # '^' -> '**' keeps expressions in Python/SymPy-friendly syntax.
        steps.append(f"Step {idx}: {line.replace('^', '**')}")
    body = "\n".join(steps)
    if not final_answer:
        return body
    if body:
        return f"{body}\nFinal Answer: {final_answer}"
    return f"Final Answer: {final_answer}"
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def row_to_record(
    question: str,
    answer: str,
    example_id: str,
    split: str,
) -> dict[str, Any] | None:
    """Build one SFT chat record from a GSM8K (question, answer) pair.

    Returns ``None`` when no usable final answer or assistant text can be
    produced from the row.
    """
    reasoning, final_answer = parse_gsm8k_answer(answer)
    # Rows with neither a normalized final answer nor a '####' marker are unusable.
    if not final_answer and "####" not in answer:
        return None

    assistant = build_assistant_content(reasoning, final_answer)
    if not assistant.strip():
        return None

    user_content = USER_WRAPPER.format(question=question.strip())
    chat = [
        {"role": "system", "content": SOLVER_SYSTEM_PROMPT},
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": assistant},
    ]
    return {
        "id": f"gsm8k_{example_id}",
        "skill_id": "gsm8k_grade_school",
        "source": "openai/gsm8k",
        "split": split,
        "messages": chat,
        # Convenience for non-chat trainers
        "text": f"<|system|>\n{SOLVER_SYSTEM_PROMPT}\n<|user|>\n{user_content}\n<|assistant|>\n{assistant}",
    }
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def iter_hf_rows(dataset_name: str, config: str, splits: list[str]) -> Iterator[tuple[str, str, dict]]:
    """Yield (example_id, split, row) triples from a Hugging Face dataset.

    The import is deferred so the script can run in jsonl-only mode without
    the ``datasets`` package installed.
    """
    from datasets import load_dataset

    dataset = load_dataset(dataset_name, config)
    for split in splits:
        if split not in dataset:
            raise KeyError(f"Split {split!r} not in dataset. Available: {list(dataset.keys())}")
        for index, row in enumerate(dataset[split]):
            yield f"{split}_{index}", split, row
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def main() -> None:
    """CLI entry point.

    Streams GSM8K rows from either Hugging Face or a local JSONL file,
    converts each to an SFT chat record via ``row_to_record``, and writes
    them to the output JSONL. Prints a summary count at the end.
    """
    p = argparse.ArgumentParser(description="Convert GSM8K to SFT JSONL (chat messages).")
    p.add_argument(
        "--source",
        choices=("hf", "jsonl"),
        default="hf",
        help="Load from Hugging Face dataset or a local JSONL file.",
    )
    p.add_argument("--dataset", default="openai/gsm8k", help="HF dataset id when --source hf.")
    p.add_argument("--config", default="main", help="HF config name when --source hf.")
    p.add_argument("--splits", nargs="+", default=["train", "test"], help="HF splits to export.")
    p.add_argument("--input", type=Path, help="Local JSONL path when --source jsonl.")
    p.add_argument(
        "--output",
        type=Path,
        default=Path("data/sft/gsm8k_sft.jsonl"),
        help="Output JSONL path.",
    )
    args = p.parse_args()

    if args.source == "jsonl" and not args.input:
        raise SystemExit("--input is required when --source jsonl")

    args.output.parent.mkdir(parents=True, exist_ok=True)

    n_ok, n_skip = 0, 0

    def process(example_id: str, split: str, row: dict) -> None:
        # Closure over the counters and over ``out_f`` (bound by the `with`
        # block below before the first call).
        nonlocal n_ok, n_skip
        q = row.get("question", "")
        a = row.get("answer", "")
        rec = row_to_record(q, a, example_id, split)
        if rec is None:
            n_skip += 1
            return
        out_f.write(json.dumps(rec, ensure_ascii=False) + "\n")
        n_ok += 1

    with args.output.open("w", encoding="utf-8") as out_f:
        if args.source == "hf":
            for example_id, split, row in iter_hf_rows(args.dataset, args.config, args.splits):
                process(example_id, split, row)
        else:
            # Fix: the input file was previously opened inline in the for-loop
            # and never closed; use a context manager so the handle is released
            # deterministically.
            with args.input.open(encoding="utf-8") as in_f:
                for i, line in enumerate(in_f):
                    line = line.strip()
                    if not line:
                        continue
                    row = json.loads(line)
                    process(str(i), "jsonl", row)

    print(f"Wrote {n_ok} examples to {args.output} ({n_skip} skipped).")


if __name__ == "__main__":
    main()
|
scripts/create_dual_task_dataset.py
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Create dual-task training dataset by mixing question-generation and solution-generation examples.
|
| 4 |
+
|
| 5 |
+
This script:
|
| 6 |
+
1. Loads existing solution data (GSM8K format)
|
| 7 |
+
2. Loads question-generation data (synthetic)
|
| 8 |
+
3. Adds task prefixes to distinguish tasks
|
| 9 |
+
4. Mixes datasets according to specified ratio
|
| 10 |
+
5. Shuffles and splits into train/validation
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
python scripts/create_dual_task_dataset.py \
|
| 14 |
+
--solution-data data/sft/gsm8k_sft.jsonl \
|
| 15 |
+
--question-data data/sft/question_generation.jsonl \
|
| 16 |
+
--output-train data/sft/dual_task_train.jsonl \
|
| 17 |
+
--output-val data/sft/dual_task_val.jsonl \
|
| 18 |
+
--mix-ratio 0.8 \
|
| 19 |
+
--val-split 0.1
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
from __future__ import annotations
|
| 23 |
+
|
| 24 |
+
import argparse
|
| 25 |
+
import json
|
| 26 |
+
import random
|
| 27 |
+
import sys
|
| 28 |
+
from pathlib import Path
|
| 29 |
+
from typing import Any
|
| 30 |
+
|
| 31 |
+
ROOT = Path(__file__).resolve().parents[1]
|
| 32 |
+
sys.path.insert(0, str(ROOT))
|
| 33 |
+
|
| 34 |
+
from src.config.prompts import SOLVE_TASK_PREFIX, GENERATE_TASK_PREFIX
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def load_jsonl(path: Path) -> list[dict[str, Any]]:
    """Load JSONL file into list of records (blank lines are skipped)."""
    with path.open(encoding="utf-8") as handle:
        stripped = (line.strip() for line in handle)
        return [json.loads(text) for text in stripped if text]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def add_solve_prefix(record: dict[str, Any]) -> dict[str, Any]:
    """
    Add 'Solve Problem' task prefix to user message.

    This signals the model to generate a step-by-step solution. The record is
    shallow-copied; the original is not mutated.
    """
    updated = record.copy()

    rewritten: list[dict[str, Any]] = []
    for message in record["messages"]:
        message = message.copy()
        if message["role"] == "user" and not message["content"].startswith(SOLVE_TASK_PREFIX):
            message["content"] = SOLVE_TASK_PREFIX + message["content"]
        rewritten.append(message)
    updated["messages"] = rewritten

    # Keep the flat "text" rendering (used by non-chat trainers) in sync.
    if "text" in updated:
        flat = updated["text"]
        if "<|user|>" in flat:
            segments = flat.split("<|user|>")
            if not segments[1].strip().startswith(SOLVE_TASK_PREFIX):
                segments[1] = f"\n{SOLVE_TASK_PREFIX}" + segments[1]
            updated["text"] = "<|user|>".join(segments)

    # Mark as solve task
    updated["task_type"] = "solve"
    return updated
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def verify_question_prefix(record: dict[str, Any]) -> dict[str, Any]:
    """
    Verify question generation record has proper prefix.

    Records should already carry the prefix from the generation script; this
    adds it defensively when missing. The record is shallow-copied.
    """
    updated = record.copy()

    rewritten: list[dict[str, Any]] = []
    for message in record["messages"]:
        message = message.copy()
        if message["role"] == "user" and not message["content"].startswith(GENERATE_TASK_PREFIX):
            message["content"] = GENERATE_TASK_PREFIX + message["content"]
        rewritten.append(message)
    updated["messages"] = rewritten

    # Keep the flat "text" rendering (used by non-chat trainers) in sync.
    if "text" in updated:
        flat = updated["text"]
        if "<|user|>" in flat:
            segments = flat.split("<|user|>")
            if not segments[1].strip().startswith(GENERATE_TASK_PREFIX):
                segments[1] = f"\n{GENERATE_TASK_PREFIX}" + segments[1]
            updated["text"] = "<|user|>".join(segments)

    # Mark as question generation task
    updated["task_type"] = "generate"
    return updated
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def sample_with_ratio(
|
| 119 |
+
solution_records: list[dict[str, Any]],
|
| 120 |
+
question_records: list[dict[str, Any]],
|
| 121 |
+
mix_ratio: float,
|
| 122 |
+
target_total: int | None = None,
|
| 123 |
+
) -> list[dict[str, Any]]:
|
| 124 |
+
"""
|
| 125 |
+
Sample and mix datasets according to specified ratio.
|
| 126 |
+
|
| 127 |
+
Args:
|
| 128 |
+
solution_records: Solution examples
|
| 129 |
+
question_records: Question generation examples
|
| 130 |
+
mix_ratio: Fraction of solutions in final dataset (0.8 = 80% solutions, 20% questions)
|
| 131 |
+
target_total: Target total examples (None = use all available data)
|
| 132 |
+
|
| 133 |
+
Returns:
|
| 134 |
+
Mixed dataset
|
| 135 |
+
"""
|
| 136 |
+
n_solutions = len(solution_records)
|
| 137 |
+
n_questions = len(question_records)
|
| 138 |
+
|
| 139 |
+
if target_total is None:
|
| 140 |
+
# Use all available data
|
| 141 |
+
target_total = n_solutions + n_questions
|
| 142 |
+
|
| 143 |
+
# Calculate target counts
|
| 144 |
+
n_sol_target = int(target_total * mix_ratio)
|
| 145 |
+
n_q_target = target_total - n_sol_target
|
| 146 |
+
|
| 147 |
+
# Check availability
|
| 148 |
+
if n_sol_target > n_solutions:
|
| 149 |
+
print(f"Warning: Requested {n_sol_target} solutions but only {n_solutions} available.")
|
| 150 |
+
n_sol_target = n_solutions
|
| 151 |
+
|
| 152 |
+
if n_q_target > n_questions:
|
| 153 |
+
print(f"Warning: Requested {n_q_target} questions but only {n_questions} available.")
|
| 154 |
+
n_q_target = n_questions
|
| 155 |
+
|
| 156 |
+
# Sample
|
| 157 |
+
selected_solutions = random.sample(solution_records, n_sol_target)
|
| 158 |
+
selected_questions = random.sample(question_records, n_q_target)
|
| 159 |
+
|
| 160 |
+
print(f"Sampled {n_sol_target} solutions and {n_q_target} questions")
|
| 161 |
+
print(f"Actual ratio: {n_sol_target/(n_sol_target+n_q_target):.2%} solutions, "
|
| 162 |
+
f"{n_q_target/(n_sol_target+n_q_target):.2%} questions")
|
| 163 |
+
|
| 164 |
+
return selected_solutions + selected_questions
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def write_jsonl(records: list[dict[str, Any]], path: Path) -> None:
    """Write records to JSONL file, creating parent directories as needed."""
    path.parent.mkdir(parents=True, exist_ok=True)
    serialized = (json.dumps(record, ensure_ascii=False) + "\n" for record in records)
    with path.open("w", encoding="utf-8") as sink:
        sink.writelines(serialized)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def main() -> None:
    """CLI entry point: build mixed solve/generate SFT splits.

    Loads solution and question-generation JSONL files, tags each record with
    its task prefix, mixes them per --mix-ratio, shuffles, splits train/val,
    and writes both splits as JSONL.
    """
    parser = argparse.ArgumentParser(
        description="Create dual-task training dataset from solution and question-generation examples."
    )
    parser.add_argument(
        "--solution-data",
        type=Path,
        required=True,
        help="Path to solution training data (GSM8K format)",
    )
    parser.add_argument(
        "--question-data",
        type=Path,
        required=True,
        help="Path to question-generation training data",
    )
    parser.add_argument(
        "--output-train",
        type=Path,
        required=True,
        help="Output path for training split",
    )
    parser.add_argument(
        "--output-val",
        type=Path,
        required=True,
        help="Output path for validation split",
    )
    parser.add_argument(
        "--mix-ratio",
        type=float,
        default=0.8,
        help="Fraction of solutions in mixed dataset (default: 0.8 = 80%% solutions)",
    )
    parser.add_argument(
        "--val-split",
        type=float,
        default=0.1,
        help="Fraction of data to use for validation (default: 0.1 = 10%%)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducibility",
    )
    parser.add_argument(
        "--max-total",
        type=int,
        default=None,
        help="Maximum total examples to include (None = use all available)",
    )
    args = parser.parse_args()

    # Validate inputs
    if not args.solution_data.exists():
        raise SystemExit(f"Error: Solution data not found at {args.solution_data}")
    if not args.question_data.exists():
        raise SystemExit(f"Error: Question data not found at {args.question_data}")

    if not (0 < args.mix_ratio < 1):
        raise SystemExit("Error: --mix-ratio must be between 0 and 1")
    if not (0 < args.val_split < 1):
        raise SystemExit("Error: --val-split must be between 0 and 1")

    # Set random seed (affects sample_with_ratio and the shuffle below)
    random.seed(args.seed)

    print("=" * 60)
    print("Dual-Task Dataset Creation")
    print("=" * 60)

    # Load data
    print("\n1. Loading data...")
    print(f"   Solution data: {args.solution_data}")
    solution_records = load_jsonl(args.solution_data)
    print(f"   Loaded {len(solution_records)} solution examples")

    print(f"   Question data: {args.question_data}")
    question_records = load_jsonl(args.question_data)
    print(f"   Loaded {len(question_records)} question-generation examples")

    # Add task prefixes (records are copied, originals untouched)
    print("\n2. Adding task prefixes...")
    print("   Adding 'Solve Problem' prefix to solution examples...")
    solution_records = [add_solve_prefix(r) for r in solution_records]

    print("   Verifying 'Generate Question' prefix on question examples...")
    question_records = [verify_question_prefix(r) for r in question_records]

    # Mix datasets
    print(f"\n3. Mixing datasets (ratio: {args.mix_ratio:.0%} solutions, {1-args.mix_ratio:.0%} questions)...")
    mixed_records = sample_with_ratio(
        solution_records=solution_records,
        question_records=question_records,
        mix_ratio=args.mix_ratio,
        target_total=args.max_total,
    )

    # Shuffle (sample_with_ratio returns solutions-then-questions order)
    print(f"\n4. Shuffling {len(mixed_records)} total examples...")
    random.shuffle(mixed_records)

    # Split train/val
    n_val = int(len(mixed_records) * args.val_split)
    n_train = len(mixed_records) - n_val

    train_records = mixed_records[:n_train]
    val_records = mixed_records[n_train:]

    print(f"\n5. Splitting data:")
    print(f"   Training: {len(train_records)} examples ({len(train_records)/len(mixed_records):.1%})")
    print(f"   Validation: {len(val_records)} examples ({len(val_records)/len(mixed_records):.1%})")

    # Verify split composition using the task_type tags set by the prefix helpers
    train_solve = sum(1 for r in train_records if r.get("task_type") == "solve")
    train_gen = sum(1 for r in train_records if r.get("task_type") == "generate")
    val_solve = sum(1 for r in val_records if r.get("task_type") == "solve")
    val_gen = sum(1 for r in val_records if r.get("task_type") == "generate")

    # NOTE(review): these percentages divide by len(train_records) /
    # len(val_records); an empty split would raise ZeroDivisionError.
    print(f"\n   Train composition:")
    print(f"     Solve: {train_solve} ({train_solve/len(train_records):.1%})")
    print(f"     Generate: {train_gen} ({train_gen/len(train_records):.1%})")

    print(f"   Val composition:")
    print(f"     Solve: {val_solve} ({val_solve/len(val_records):.1%})")
    print(f"     Generate: {val_gen} ({val_gen/len(val_records):.1%})")

    # Write outputs
    print(f"\n6. Writing output files...")
    print(f"   Training data: {args.output_train}")
    write_jsonl(train_records, args.output_train)

    print(f"   Validation data: {args.output_val}")
    write_jsonl(val_records, args.output_val)

    print("\n" + "=" * 60)
    print("Dual-task dataset creation complete!")
    print("=" * 60)
    print(f"\nOutput files:")
    print(f"  Train: {args.output_train} ({len(train_records)} examples)")
    print(f"  Val: {args.output_val} ({len(val_records)} examples)")
    print(f"\nNext step: Train dual-task model using these files")


if __name__ == "__main__":
    main()
|
scripts/demo_before_after.py
ADDED
|
@@ -0,0 +1,591 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Before / after demo β baseline vs GRPO-trained policy.
|
| 2 |
+
|
| 3 |
+
Designed for hackathon judges: loads both models, runs greedy evaluation on
|
| 4 |
+
a fixed problem set, and prints a clean side-by-side comparison with full
|
| 5 |
+
solution text for the most interesting examples.
|
| 6 |
+
|
| 7 |
+
Features
|
| 8 |
+
--------
|
| 9 |
+
* Handles all checkpoint types: HF model IDs, GRPO full-weight saves,
|
| 10 |
+
PEFT/LoRA adapter directories.
|
| 11 |
+
* Automatically loads the chat template from the base model when the
|
| 12 |
+
checkpoint tokenizer doesn't have one (fixes the 0% accuracy bug that
|
| 13 |
+
silently swallows TemplateErrors).
|
| 14 |
+
* Reads ``metrics.jsonl`` (if present) and prints the full accuracy curve,
|
| 15 |
+
showing judges the training progression at a glance.
|
| 16 |
+
* Saves machine-readable JSON (for grading scripts) and prints a human-
|
| 17 |
+
readable Markdown table.
|
| 18 |
+
* Shows full solution text for the best wins and worst regressions.
|
| 19 |
+
|
| 20 |
+
Quick-start
|
| 21 |
+
-----------
|
| 22 |
+
After a GRPO run, point at ``best_policy/``::
|
| 23 |
+
|
| 24 |
+
python scripts/demo_before_after.py \\
|
| 25 |
+
--baseline-model checkpoints/dual_task_v1 \\
|
| 26 |
+
--trained-model checkpoints/grpo/<run>/best_policy \\
|
| 27 |
+
--problems data/sft/gsm8k_sft.jsonl \\
|
| 28 |
+
--max-samples 100
|
| 29 |
+
|
| 30 |
+
Include the training curve::
|
| 31 |
+
|
| 32 |
+
python scripts/demo_before_after.py \\
|
| 33 |
+
--baseline-model checkpoints/dual_task_v1 \\
|
| 34 |
+
--trained-model checkpoints/grpo/<run>/best_policy \\
|
| 35 |
+
--metrics-jsonl checkpoints/grpo/<run>/metrics.jsonl \\
|
| 36 |
+
--problems data/sft/gsm8k_sft.jsonl \\
|
| 37 |
+
--max-samples 100 \\
|
| 38 |
+
--records-out results/demo.json
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
from __future__ import annotations
|
| 42 |
+
|
| 43 |
+
import argparse
|
| 44 |
+
import json
|
| 45 |
+
import logging
|
| 46 |
+
import re
|
| 47 |
+
import sys
|
| 48 |
+
import time
|
| 49 |
+
import types
|
| 50 |
+
from dataclasses import dataclass, field
|
| 51 |
+
from pathlib import Path
|
| 52 |
+
from typing import Dict, List, Optional, Tuple
|
| 53 |
+
|
| 54 |
+
import torch
|
| 55 |
+
from peft import PeftModel
|
| 56 |
+
from tqdm.auto import tqdm
|
| 57 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 58 |
+
|
| 59 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 60 |
+
from src.sft.solution_format import extract_final_answer_numeric_str
|
| 61 |
+
from src.utils.attn_backend import select_attn_implementation
|
| 62 |
+
|
| 63 |
+
logging.basicConfig(
|
| 64 |
+
level=logging.INFO,
|
| 65 |
+
format="%(asctime)s %(levelname)-8s %(name)s - %(message)s",
|
| 66 |
+
)
|
| 67 |
+
logger = logging.getLogger(__name__)
|
| 68 |
+
|
| 69 |
+
_SEP = "=" * 78
|
| 70 |
+
_SEP2 = "-" * 78
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# ---------------------------------------------------------------------------
|
| 74 |
+
# Data
|
| 75 |
+
# ---------------------------------------------------------------------------
|
| 76 |
+
|
| 77 |
+
@dataclass
class Problem:
    """One evaluation item: a prompt and its gold final answer."""
    # Problem statement passed to the model as the user turn.
    question: str
    # Gold final answer as a plain string (commas stripped by _parse_gold).
    gold_final: str
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _parse_gold(answer: str) -> str:
|
| 84 |
+
m = re.search(r"####\s*([-0-9.,/ ]+)", answer)
|
| 85 |
+
if m:
|
| 86 |
+
return m.group(1).strip().replace(",", "")
|
| 87 |
+
return answer.strip().splitlines()[-1].strip()
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _load_problems(path: Path, max_samples: int) -> List[Problem]:
    """Accept GSM8K ``{question, answer}`` or SFT ``{messages}`` JSONL."""
    problems: List[Problem] = []
    with path.open(encoding="utf-8") as fh:
        for raw in fh:
            # max_samples <= 0 means "no cap".
            if max_samples > 0 and len(problems) >= max_samples:
                break
            raw = raw.strip()
            if not raw:
                continue
            obj = json.loads(raw)
            if "question" in obj and "answer" in obj:
                problems.append(
                    Problem(
                        question=obj["question"].strip(),
                        gold_final=_parse_gold(obj["answer"]),
                    )
                )
            elif "messages" in obj:
                msgs = obj["messages"]
                user = next(
                    (m["content"] for m in msgs if m.get("role") == "user"), ""
                ).strip()
                asst = next(
                    (m["content"] for m in msgs if m.get("role") == "assistant"), ""
                )
                gold = extract_final_answer_numeric_str(asst) or ""
                problems.append(Problem(question=user, gold_final=gold.strip()))
    return problems
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
# ---------------------------------------------------------------------------
|
| 119 |
+
# Model loading β handles HF IDs, full-weight saves, and PEFT adapters
|
| 120 |
+
# ---------------------------------------------------------------------------
|
| 121 |
+
|
| 122 |
+
def _ensure_chat_template(
    tokenizer: AutoTokenizer,
    fallback_model: str = "Qwen/Qwen2.5-Math-1.5B-Instruct",
) -> None:
    """Copy a chat template from *fallback_model* when the tokenizer has none.

    Adapter checkpoints saved during SFT frequently omit ``chat_template``
    from their tokenizer config. Without one, ``apply_chat_template`` raises
    a TemplateError that downstream evaluation silently swallows, which shows
    up as 0% accuracy.
    """
    if tokenizer.chat_template is not None:
        return
    logger.info("Tokenizer missing chat_template — loading from %s", fallback_model)
    try:
        donor = AutoTokenizer.from_pretrained(fallback_model, trust_remote_code=True)
        if donor.chat_template is not None:
            tokenizer.chat_template = donor.chat_template
            logger.info("Chat template loaded.")
    except Exception as exc:
        # Best-effort: evaluation can still proceed on a raw prompt.
        logger.warning("Could not load chat template: %s", exc)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def _load_model(
    checkpoint: str,
    base_model_id: str,
    device: torch.device,
    dtype: torch.dtype,
    attn_impl: str,
) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load model + tokenizer from any checkpoint style.

    Handles:
    * HuggingFace model ID (e.g. ``Qwen/Qwen2.5-Math-1.5B-Instruct``)
    * GRPO full-weight save (directory with ``model.safetensors`` / pytorch_model*)
    * PEFT/LoRA adapter dir (directory with ``adapter_config.json``)

    Returns the model in eval mode together with a left-padding tokenizer
    whose chat template is guaranteed present (see ``_ensure_chat_template``).
    """
    # PEFT shim — prevents crash in merge_and_unload on some versions.
    if "transformers.integrations.tensor_parallel" not in sys.modules:
        sys.modules["transformers.integrations.tensor_parallel"] = types.ModuleType(
            "tensor_parallel"
        )

    ckpt_path = Path(checkpoint)
    is_adapter = ckpt_path.is_dir() and (ckpt_path / "adapter_config.json").exists()

    # Tokenizer: prefer the checkpoint's own tokenizer files, else the base model's.
    tok_src = checkpoint if (ckpt_path.is_dir() and (ckpt_path / "tokenizer_config.json").exists()) else base_model_id
    tokenizer = AutoTokenizer.from_pretrained(tok_src, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"  # standard for generation
    _ensure_chat_template(tokenizer, fallback_model=base_model_id)

    load_kw = dict(
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        device_map={"": device},
        trust_remote_code=True,
        attn_implementation=attn_impl,
    )

    if is_adapter:
        # Read base model from pipeline_meta.json if present.
        meta_file = ckpt_path / "pipeline_meta.json"
        _base = base_model_id
        if meta_file.exists():
            _base = json.loads(meta_file.read_text()).get("base_model", _base)
        logger.info("PEFT adapter — loading base %s then merging %s", _base, checkpoint)
        _base_mdl = AutoModelForCausalLM.from_pretrained(_base, **load_kw)
        model = PeftModel.from_pretrained(_base_mdl, checkpoint).merge_and_unload()
        model = model.to(device)
    else:
        # Full weights (GRPO save) or HF model ID — from_pretrained handles
        # both a local directory path and a hub ID, so no branching is needed.
        # (The original computed `src = checkpoint if is_local_full else
        # checkpoint`, a no-op conditional; removed.)
        logger.info("Loading full-weight model from %s", checkpoint)
        model = AutoModelForCausalLM.from_pretrained(checkpoint, **load_kw)

    # Re-enable requires_grad isn't needed for eval, but ensure eval mode.
    model.eval()
    n = sum(p.numel() for p in model.parameters())
    logger.info("Loaded: %s (%.2fB params, %.1f GB VRAM est.)",
                checkpoint, n / 1e9, n * 2 / 1e9)
    return model, tokenizer
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
# ---------------------------------------------------------------------------
|
| 209 |
+
# Generation
|
| 210 |
+
# ---------------------------------------------------------------------------
|
| 211 |
+
|
| 212 |
+
def _build_prompt(tokenizer: AutoTokenizer, question: str) -> str:
    """Render *question* through the model's chat template (training format).

    Falls back to the raw question text when no template is available or
    rendering fails.
    """
    if tokenizer.chat_template is None:
        return question
    conversation = [
        {"role": "system", "content": "You are a helpful math assistant. Solve the problem step-by-step and end with 'Final Answer: <number>'."},
        {"role": "user", "content": question},
    ]
    try:
        rendered = tokenizer.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        return question
    return rendered
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def _stop_ids(tokenizer: AutoTokenizer) -> List[int]:
    """Collect stop-token ids for generation: EOS plus Qwen's ``<|im_end|>``."""
    stop: List[int] = []
    eos = tokenizer.eos_token_id
    if eos is not None:
        stop.append(eos)
    im_end = tokenizer.convert_tokens_to_ids("<|im_end|>")
    if isinstance(im_end, int) and im_end not in stop:
        stop.append(im_end)
    # An empty list becomes None so generate() falls back to its defaults.
    return stop or None  # type: ignore[return-value]
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
@torch.no_grad()
def _generate(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    question: str,
    max_new_tokens: int,
    device: torch.device,
) -> str:
    """Greedily decode one completion for *question*, returning only new text."""
    rendered = _build_prompt(tokenizer, question)
    batch = tokenizer(
        rendered,
        return_tensors="pt",
        truncation=True,
        max_length=1024,
    ).to(device)
    n_prompt_tokens = batch["input_ids"].shape[1]

    generated = model.generate(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        max_new_tokens=max_new_tokens,
        do_sample=False,  # greedy — deterministic for reproducibility
        temperature=1.0,
        pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
        eos_token_id=_stop_ids(tokenizer),
        use_cache=True,
    )
    # Slice off the prompt so only newly generated tokens are decoded.
    completion = generated[0][n_prompt_tokens:]
    return tokenizer.decode(completion, skip_special_tokens=True)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
# ---------------------------------------------------------------------------
|
| 269 |
+
# Scoring
|
| 270 |
+
# ---------------------------------------------------------------------------
|
| 271 |
+
|
| 272 |
+
def _normalize(x: str) -> str:
|
| 273 |
+
if not x:
|
| 274 |
+
return ""
|
| 275 |
+
s = x.strip().replace(",", "").replace("$", "").strip()
|
| 276 |
+
try:
|
| 277 |
+
f = float(s)
|
| 278 |
+
return f"{int(f)}" if f == int(f) else f"{f}"
|
| 279 |
+
except ValueError:
|
| 280 |
+
return s
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
@dataclass
class Record:
    """One scored problem: the model's prediction graded against gold."""
    question: str       # problem statement as fed to the model
    gold: str           # gold final answer string
    pred: str           # extracted final answer from the model output ("" if none)
    correct: bool       # True when pred matches gold after normalization
    solution_text: str  # full generated solution text (or an error marker)
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def _score_model(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    problems: List[Problem],
    max_new_tokens: int,
    device: torch.device,
    label: str,
) -> Tuple[int, List[Record]]:
    """Generate and grade an answer for every problem.

    Returns the number of correct answers and a per-problem ``Record`` list
    in the same order as *problems*. Generation failures are recorded as
    incorrect rather than aborting the run.
    """
    results: List[Record] = []
    n_correct = 0
    for problem in tqdm(problems, desc=f"Scoring {label}", unit="q", dynamic_ncols=True):
        try:
            solution = _generate(model, tokenizer, problem.question, max_new_tokens, device)
        except Exception as exc:
            solution = f"[generation error: {exc}]"
        prediction = extract_final_answer_numeric_str(solution) or ""
        is_correct = bool(prediction) and _normalize(prediction) == _normalize(problem.gold_final)
        if is_correct:
            n_correct += 1
        results.append(
            Record(
                question=problem.question,
                gold=problem.gold_final,
                pred=prediction,
                correct=is_correct,
                solution_text=solution,
            )
        )
    return n_correct, results
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
# ---------------------------------------------------------------------------
|
| 322 |
+
# Metrics curve
|
| 323 |
+
# ---------------------------------------------------------------------------
|
| 324 |
+
|
| 325 |
+
def _load_metrics_curve(path: Path) -> List[Dict]:
|
| 326 |
+
"""Read metrics.jsonl and return rows that contain GSM8K accuracy."""
|
| 327 |
+
rows = []
|
| 328 |
+
if not path.exists():
|
| 329 |
+
return rows
|
| 330 |
+
with path.open(encoding="utf-8") as f:
|
| 331 |
+
for line in f:
|
| 332 |
+
line = line.strip()
|
| 333 |
+
if not line:
|
| 334 |
+
continue
|
| 335 |
+
try:
|
| 336 |
+
obj = json.loads(line)
|
| 337 |
+
if "accuracy" in obj or "iteration" in obj:
|
| 338 |
+
rows.append(obj)
|
| 339 |
+
except json.JSONDecodeError:
|
| 340 |
+
pass
|
| 341 |
+
return rows
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def _print_curve(rows: List[Dict]) -> None:
    """Pretty-print the per-iteration training curve parsed from metrics.jsonl."""
    if not rows:
        return
    print(f"\n{_SEP}")
    print("TRAINING ACCURACY CURVE (from metrics.jsonl)")
    print(_SEP)
    print(f"{'Iter':>5} {'GSM8K%':>7} {'Reward':>7} {'Batch%':>7} {'LR':>10} {'Time(s)':>8}")
    print(_SEP2)
    for row in rows:
        iteration = row.get("iteration", "")
        accuracy = row.get("accuracy", None)
        reward = row.get("mean_reward", None)
        batch_acc = row.get("batch_accuracy", None)
        lr = row.get("learning_rate", None)
        elapsed = row.get("iter_time_s", None)
        # Missing metrics render as a dash so the columns stay aligned.
        acc_cell = f"{100 * accuracy:.1f}%" if accuracy is not None else "—"
        rwd_cell = f"{reward:.3f}" if reward is not None else "—"
        bat_cell = f"{100 * batch_acc:.1f}%" if batch_acc is not None else "—"
        lr_cell = f"{lr:.2e}" if lr is not None else "—"
        ts_cell = f"{elapsed:.1f}" if elapsed is not None else "—"
        print(f"{iteration:>5} {acc_cell:>7} {rwd_cell:>7} {bat_cell:>7} {lr_cell:>10} {ts_cell:>8}")
    print()
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
# ---------------------------------------------------------------------------
|
| 369 |
+
# Output
|
| 370 |
+
# ---------------------------------------------------------------------------
|
| 371 |
+
|
| 372 |
+
def _print_summary(
    base_correct: int,
    tr_correct: int,
    base_records: List[Record],
    tr_records: List[Record],
    baseline_name: str,
    trained_name: str,
    n_solutions: int = 3,
) -> None:
    """Print the before/after comparison table plus example wins/regressions.

    *base_records* and *tr_records* must be in the same problem order (as
    produced by ``_score_model`` over the same problem list). *n_solutions*
    caps how many win examples are printed in full.
    """
    n = len(base_records)
    # FIX: the original zipped base_records against itself as a third element
    # and then discarded the duplicate; pairs of (baseline, trained) suffice.
    wins = [(b, t) for b, t in zip(base_records, tr_records) if not b.correct and t.correct]
    losses = [(b, t) for b, t in zip(base_records, tr_records) if b.correct and not t.correct]
    both_wrong = sum(1 for b, t in zip(base_records, tr_records) if not b.correct and not t.correct)
    both_right = sum(1 for b, t in zip(base_records, tr_records) if b.correct and t.correct)

    delta = tr_correct - base_correct
    sign = "+" if delta >= 0 else ""

    print(f"\n{_SEP}")
    print("BEFORE vs AFTER — GSM8K accuracy (greedy decoding, fixed seed)")
    print(_SEP)
    print(f"  Baseline : {baseline_name}")
    print(f"  Trained  : {trained_name}")
    print(_SEP2)
    print(f"  Baseline accuracy : {base_correct}/{n} ({100*base_correct/n:.1f}%)")
    print(f"  Trained accuracy  : {tr_correct}/{n} ({100*tr_correct/n:.1f}%)")
    print(f"  Delta             : {sign}{delta} problems ({sign}{100*delta/n:.1f} pp)")
    print(_SEP2)
    print(f"  Newly correct (wins)  : {len(wins)}")
    print(f"  Newly wrong (losses)  : {len(losses)}")
    print(f"  Both correct          : {both_right}")
    print(f"  Both wrong            : {both_wrong}")
    print(_SEP)

    if wins:
        print(f"\n{'='*78}")
        print("WINS — problems the RL model now solves that the baseline could not")
        print(f"{'='*78}")
        for i, (base_r, tr_r) in enumerate(wins[:n_solutions]):
            print(f"\n[Win {i+1}/{min(n_solutions, len(wins))}]")
            _print_problem(base_r, tr_r)

    if losses:
        print(f"\n{'='*78}")
        print("REGRESSIONS — problems the baseline solved but the RL model now misses")
        print(f"{'='*78}")
        for i, (base_r, tr_r) in enumerate(losses[:min(2, len(losses))]):
            print(f"\n[Regression {i+1}/{min(2, len(losses))}]")
            _print_problem(base_r, tr_r, is_regression=True)

    print(f"\n{_SEP}")
    # Relative gain over the pool the baseline got wrong; max() avoids /0
    # when the baseline already solved everything.
    pct_gain = 100 * delta / max(n - base_correct, 1)
    print(f"SUMMARY: RL training fixed {len(wins)} problems, regressed {len(losses)}.")
    print(f"         Net: {sign}{delta} pts. Relative gain on previously-wrong: {pct_gain:+.1f}%")
    print(_SEP)
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
def _print_problem(base_r: Record, tr_r: Record, is_regression: bool = False) -> None:
    """Print one before/after example: question, gold, and both predictions."""
    question = base_r.question
    if len(question) > 250:
        # Keep console output readable for very long problems.
        question = question[:247] + "..."
    print(f"  Q      : {question}")
    print(f"  Gold   : {base_r.gold}")
    if is_regression:
        # Baseline was right, trained model is now wrong.
        print(f"  Before : {base_r.pred!r:30s} ✓")
        print(f"  After  : {tr_r.pred!r:30s} ✗")
        return
    # Win: baseline was wrong, trained model is right.
    print(f"  Before : {base_r.pred!r:30s} ✗")
    print(f"  After  : {tr_r.pred!r:30s} ✓")
    solution = tr_r.solution_text.strip()
    if solution:
        sol_lines = solution.splitlines()
        shown = "\n    ".join(sol_lines[:12])
        if len(sol_lines) > 12:
            shown += f"\n    ... ({len(sol_lines)-12} more lines)"
        print(f"\n  Solution (trained model):\n    {shown}")
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
# ---------------------------------------------------------------------------
|
| 453 |
+
# CLI
|
| 454 |
+
# ---------------------------------------------------------------------------
|
| 455 |
+
|
| 456 |
+
def main() -> int:
    """CLI entry point: score baseline and trained checkpoints on the same
    problem set and print a before/after comparison.

    Returns a process exit code: 0 on success, 2 when the problems file is
    missing or yields no problems.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--baseline-model", default="checkpoints/dual_task_v1",
        help="Pre-RL checkpoint. HF model ID, full-weight dir, or PEFT adapter dir.",
    )
    parser.add_argument(
        "--trained-model", required=True,
        help="Post-RL checkpoint (GRPO best_policy/ dir, or iteration checkpoint).",
    )
    parser.add_argument(
        "--base-model-for-adapter", default="Qwen/Qwen2.5-Math-1.5B-Instruct",
        help="Base model used when loading a PEFT adapter checkpoint.",
    )
    parser.add_argument(
        "--problems", type=Path, default=Path("data/sft/gsm8k_sft.jsonl"),
        help="JSONL eval set. Defaults to GSM8K training split (first --max-samples rows).",
    )
    parser.add_argument("--max-samples", type=int, default=100)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument(
        "--metrics-jsonl", type=Path, default=None,
        help="Path to metrics.jsonl from a GRPO run — prints the accuracy curve.",
    )
    parser.add_argument(
        "--n-solutions", type=int, default=3,
        help="Number of win/loss examples to print in full.",
    )
    parser.add_argument(
        "--records-out", type=Path, default=None,
        help="Save full per-problem JSON records here (for judge grading scripts).",
    )
    parser.add_argument(
        "--device", default="cuda" if torch.cuda.is_available() else "cpu",
    )
    parser.add_argument(
        "--dtype", default="bfloat16",
        choices=["float32", "float16", "bfloat16"],
    )
    args = parser.parse_args()

    if not args.problems.is_file():
        logger.error("Problems file not found: %s", args.problems)
        return 2

    dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16}
    dtype = dtype_map[args.dtype]
    device = torch.device(args.device)
    attn = select_attn_implementation()
    logger.info("Device: %s | dtype: %s | attn: %s", device, args.dtype, attn)

    # Print training curve if available
    if args.metrics_jsonl:
        curve = _load_metrics_curve(args.metrics_jsonl)
        _print_curve(curve)

    problems = _load_problems(args.problems, args.max_samples)
    if not problems:
        logger.error("No problems loaded from %s", args.problems)
        return 2
    logger.info("Evaluating on %d problems from %s", len(problems), args.problems)

    # ── Baseline ──────────────────────────────────────────────────────────
    logger.info("%s\nScoring BASELINE: %s\n%s", _SEP, args.baseline_model, _SEP)
    t0 = time.perf_counter()
    base_model, base_tok = _load_model(
        args.baseline_model, args.base_model_for_adapter, device, dtype, attn
    )
    base_correct, base_records = _score_model(
        base_model, base_tok, problems, args.max_new_tokens, device, "baseline"
    )
    # Free VRAM before loading the trained model.
    del base_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    logger.info("Baseline done in %.1fs — accuracy: %d/%d (%.1f%%)",
                time.perf_counter() - t0,
                base_correct, len(problems),
                100 * base_correct / len(problems))

    # ── Trained ───────────────────────────────────────────────────────────
    logger.info("%s\nScoring TRAINED: %s\n%s", _SEP, args.trained_model, _SEP)
    t0 = time.perf_counter()
    tr_model, tr_tok = _load_model(
        args.trained_model, args.base_model_for_adapter, device, dtype, attn
    )
    tr_correct, tr_records = _score_model(
        tr_model, tr_tok, problems, args.max_new_tokens, device, "trained"
    )
    del tr_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    logger.info("Trained done in %.1fs — accuracy: %d/%d (%.1f%%)",
                time.perf_counter() - t0,
                tr_correct, len(problems),
                100 * tr_correct / len(problems))

    # ── Summary ───────────────────────────────────────────────────────────
    _print_summary(
        base_correct, tr_correct,
        base_records, tr_records,
        baseline_name=args.baseline_model,
        trained_name=args.trained_model,
        n_solutions=args.n_solutions,
    )

    # ── Save records ──────────────────────────────────────────────────────
    if args.records_out:
        args.records_out.parent.mkdir(parents=True, exist_ok=True)
        payload = {
            "baseline_model": args.baseline_model,
            "trained_model": args.trained_model,
            "n_problems": len(problems),
            "baseline": {
                "correct": base_correct,
                "accuracy": base_correct / len(problems),
                "records": [vars(r) for r in base_records],
            },
            "trained": {
                "correct": tr_correct,
                "accuracy": tr_correct / len(problems),
                "records": [vars(r) for r in tr_records],
            },
        }
        args.records_out.write_text(
            json.dumps(payload, indent=2, ensure_ascii=False), encoding="utf-8"
        )
        logger.info("Per-problem records saved to %s", args.records_out)

    return 0
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit code.
    sys.exit(main())
|
scripts/dual_task_sft_pipeline.py
ADDED
|
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Dual-task SFT pipeline: train model on both question generation and solution tasks.
|
| 3 |
+
|
| 4 |
+
This pipeline trains a single model that can:
|
| 5 |
+
1. Generate math questions when prompted with "### Task: Generate Question"
|
| 6 |
+
2. Solve math problems when prompted with "### Task: Solve Problem"
|
| 7 |
+
|
| 8 |
+
Examples
|
| 9 |
+
--------
|
| 10 |
+
# Train dual-task model
|
| 11 |
+
python scripts/dual_task_sft_pipeline.py train \\
|
| 12 |
+
--data data/sft/dual_task_train.jsonl \\
|
| 13 |
+
--output-dir checkpoints/dual_task_v1 \\
|
| 14 |
+
--epochs 2
|
| 15 |
+
|
| 16 |
+
# Infer - Question Generation
|
| 17 |
+
python scripts/dual_task_sft_pipeline.py infer \\
|
| 18 |
+
--adapter checkpoints/dual_task_v1 \\
|
| 19 |
+
--task generate \\
|
| 20 |
+
--prompt "Create a word problem about fractions and money requiring 3 steps."
|
| 21 |
+
|
| 22 |
+
# Infer - Solution Generation
|
| 23 |
+
python scripts/dual_task_sft_pipeline.py infer \\
|
| 24 |
+
--adapter checkpoints/dual_task_v1 \\
|
| 25 |
+
--task solve \\
|
| 26 |
+
--problem "Janet has 16 eggs. She eats 3. How many are left?"
|
| 27 |
+
|
| 28 |
+
Dependencies: torch, transformers, peft, datasets, accelerate, bitsandbytes, trl
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
from __future__ import annotations
|
| 32 |
+
|
| 33 |
+
import os
|
| 34 |
+
|
| 35 |
+
if "HF_HUB_DISABLE_XET" not in os.environ:
|
| 36 |
+
os.environ["HF_HUB_DISABLE_XET"] = "1"
|
| 37 |
+
|
| 38 |
+
import argparse
|
| 39 |
+
import json
|
| 40 |
+
import math
|
| 41 |
+
import sys
|
| 42 |
+
from pathlib import Path
|
| 43 |
+
|
| 44 |
+
# Make the repository root importable so `src.*` modules resolve when this
# script is run directly (e.g. `python scripts/dual_task_sft_pipeline.py`).
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
|
| 46 |
+
|
| 47 |
+
from src.config.prompts import (
|
| 48 |
+
SOLVE_TASK_PREFIX,
|
| 49 |
+
GENERATE_TASK_PREFIX,
|
| 50 |
+
SOLVER_SYSTEM_PROMPT,
|
| 51 |
+
GENERATOR_SYSTEM_PROMPT,
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _warmup_steps_from_ratio(
|
| 56 |
+
num_examples: int,
|
| 57 |
+
per_device_train_batch_size: int,
|
| 58 |
+
gradient_accumulation_steps: int,
|
| 59 |
+
num_train_epochs: float,
|
| 60 |
+
warmup_ratio: float,
|
| 61 |
+
) -> int:
|
| 62 |
+
"""Calculate warmup steps from ratio."""
|
| 63 |
+
if warmup_ratio <= 0:
|
| 64 |
+
return 0
|
| 65 |
+
num_batches = max(
|
| 66 |
+
1,
|
| 67 |
+
(num_examples + per_device_train_batch_size - 1) // per_device_train_batch_size,
|
| 68 |
+
)
|
| 69 |
+
num_update_steps_per_epoch = max(1, num_batches // gradient_accumulation_steps)
|
| 70 |
+
total_optimizer_steps = max(1, math.ceil(num_train_epochs * num_update_steps_per_epoch))
|
| 71 |
+
return min(total_optimizer_steps, int(total_optimizer_steps * warmup_ratio))
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def cmd_train(args: argparse.Namespace) -> None:
    """Fine-tune a 4-bit-quantized base model with LoRA on the dual-task JSONL.

    Reads ``args`` from the ``train`` sub-command (data path, output dir,
    base model id, LoRA hyper-parameters, SFT schedule) and writes the
    trained adapter, tokenizer, and a ``pipeline_meta.json`` manifest to
    ``args.output_dir``. Raises SystemExit on missing dependencies or data.
    """
    try:
        import torch
        from datasets import load_dataset
        from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
        from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
        from trl import SFTConfig, SFTTrainer
    except ImportError as e:
        raise SystemExit(
            "Missing dependency for training. Install:\n"
            "  pip install torch transformers peft datasets accelerate bitsandbytes trl\n"
            f"Original error: {e}"
        ) from e

    data_path = Path(args.data)
    if not data_path.is_file():
        raise SystemExit(f"Data file not found: {data_path}")

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # 4-bit NF4 quantization with double quantization (QLoRA-style).
    compute_dtype = getattr(torch, args.bnb_compute_dtype)
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Right-padding for training (generation code elsewhere uses left).
    tokenizer.padding_side = "right"

    print(f"Loading model {args.model} …")
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        dtype=compute_dtype,
    )
    model = prepare_model_for_kbit_training(model)

    peft = LoraConfig(
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=list(args.target_modules.split(",")),
    )
    model = get_peft_model(model, peft)
    # use_cache conflicts with gradient checkpointing during training.
    model.config.use_cache = False
    model.print_trainable_parameters()

    print(f"Loading dual-task dataset from {data_path} …")
    ds = load_dataset("json", data_files=str(data_path), split="train")
    if args.max_samples and args.max_samples > 0:
        ds = ds.select(range(min(args.max_samples, len(ds))))

    # Report the solve/generate mix so imbalanced datasets are visible up front.
    task_counts = {"solve": 0, "generate": 0, "unknown": 0}
    for example in ds:
        task_type = example.get("task_type", "unknown")
        task_counts[task_type] = task_counts.get(task_type, 0) + 1

    print(f"Dataset composition:")
    print(f"  Total examples: {len(ds)}")
    print(f"  Solve tasks: {task_counts['solve']} ({task_counts['solve']/len(ds):.1%})")
    print(f"  Generate tasks: {task_counts['generate']} ({task_counts['generate']/len(ds):.1%})")
    if task_counts['unknown'] > 0:
        print(f"  Unknown tasks: {task_counts['unknown']}")

    def formatting_func(example):
        # Render the chat messages into the model's training-text format.
        return tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
            add_generation_prompt=False,
        )

    # Explicit --warmup-steps wins; otherwise derive from --warmup-ratio.
    if args.warmup_steps is not None:
        warmup_steps = max(0, args.warmup_steps)
    else:
        warmup_steps = _warmup_steps_from_ratio(
            len(ds),
            args.batch_size,
            args.grad_accum,
            args.epochs,
            args.warmup_ratio,
        )

    sft_args = SFTConfig(
        output_dir=str(out_dir),
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.learning_rate,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        save_total_limit=3,
        bf16=args.bf16 and torch.cuda.is_available(),
        fp16=args.fp16 and torch.cuda.is_available() and not args.bf16,
        max_length=args.max_seq_length,
        warmup_steps=warmup_steps,
        lr_scheduler_type="cosine",
        report_to="none",
        gradient_checkpointing=True,
    )

    print("\nStarting dual-task training...")
    trainer = SFTTrainer(
        model=model,
        args=sft_args,
        train_dataset=ds,
        processing_class=tokenizer,
        formatting_func=formatting_func,
    )

    trainer.train()
    trainer.save_model(str(out_dir))
    tokenizer.save_pretrained(str(out_dir))

    # Manifest lets downstream inference recover the base model and settings.
    with (out_dir / "pipeline_meta.json").open("w", encoding="utf-8") as f:
        json.dump(
            {
                "pipeline_type": "dual_task",
                "base_model": args.model,
                "data": str(data_path),
                "lora_rank": args.lora_rank,
                "epochs": args.epochs,
                "task_distribution": task_counts,
            },
            f,
            indent=2,
        )
    print(f"\nSaved dual-task adapter and tokenizer to {out_dir}")
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def cmd_infer(args: argparse.Namespace) -> None:
    """Run single-prompt inference with a trained dual-task QLoRA adapter.

    Loads the 4-bit quantized base model plus the PEFT adapter, builds a
    chat prompt for either the 'solve' or 'generate' task, generates one
    completion, and prints it. For 'solve' outputs, also runs the SymPy
    solution-format validator on the generated text.
    """
    # Heavy ML imports are deferred so that `--help` and argument errors
    # stay fast and do not require CUDA/transformers to be importable.
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    adapter = Path(args.adapter)
    meta_path = adapter / "pipeline_meta.json"
    base_model = args.base_model

    # Prefer the base model recorded at training time over the CLI default,
    # and warn if the adapter came from a different pipeline type.
    if meta_path.is_file():
        meta = json.loads(meta_path.read_text(encoding="utf-8"))
        base_model = meta.get("base_model", base_model)
        pipeline_type = meta.get("pipeline_type", "unknown")
        if pipeline_type != "dual_task":
            print(f"Warning: Adapter trained with pipeline_type='{pipeline_type}', expected 'dual_task'")

    # 4-bit NF4 double-quantized loading (same quantization as training).
    compute_dtype = getattr(torch, args.bnb_compute_dtype)
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )

    # Tokenizer is loaded from the adapter dir so any added tokens are kept.
    tokenizer = AutoTokenizer.from_pretrained(adapter, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading base {base_model} + adapter {adapter} …")
    base = AutoModelForCausalLM.from_pretrained(
        base_model,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    model = PeftModel.from_pretrained(base, str(adapter))
    model.eval()

    # Select system prompt and user message format by task. The task
    # prefixes must match the ones used when building the training data.
    if args.task == "solve":
        system_prompt = SOLVER_SYSTEM_PROMPT
        user_content = (
            f"{SOLVE_TASK_PREFIX}"
            "Solve the following problem. Show your reasoning as numbered steps, "
            "then give the final numeric answer on the last line.\n\n"
            f"Problem:\n{args.problem.strip()}"
        )
    elif args.task == "generate":
        system_prompt = GENERATOR_SYSTEM_PROMPT
        user_content = f"{GENERATE_TASK_PREFIX}{args.prompt.strip()}"
    else:
        raise ValueError(f"Unknown task: {args.task}. Must be 'solve' or 'generate'")

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_content},
    ]

    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    print(f"\nTask: {args.task}")
    print(f"Prompt length: {inputs['input_ids'].shape[1]} tokens")
    print("\nGenerating...")

    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            do_sample=not args.greedy,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode only the newly generated tokens (slice off the prompt).
    gen_ids = out[0, inputs["input_ids"].shape[1] :]
    text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()

    print("\n" + "=" * 60)
    print("Generated Output")
    print("=" * 60)
    print(text)
    print("=" * 60)

    # Solutions get an extra structural check; question generation does not.
    if args.task == "solve":
        print("\n--- Format Validation ---")
        from src.sft.solution_format import validate_sympy_solution_format
        r = validate_sympy_solution_format(text)
        print(json.dumps(r.__dict__, indent=2))
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser with `train` and `infer` subcommands."""
    parser = argparse.ArgumentParser(description="Dual-task SFT pipeline (train / infer)")
    subparsers = parser.add_subparsers(dest="command", required=True)

    # --- train subcommand -------------------------------------------------
    train_p = subparsers.add_parser("train", help="Train dual-task model on mixed dataset")
    train_p.add_argument("--data", type=str, required=True, help="Dual-task training JSONL")
    train_p.add_argument("--output-dir", type=str, required=True, help="Output directory for adapter")
    train_p.add_argument("--model", type=str, default="Qwen/Qwen2.5-Math-1.5B-Instruct", help="Base model")
    train_p.add_argument("--epochs", type=float, default=2.0, help="Training epochs (default: 2.0 for dual-task)")
    train_p.add_argument("--batch-size", type=int, default=1)
    train_p.add_argument("--grad-accum", type=int, default=8)
    train_p.add_argument("--learning-rate", type=float, default=2e-4)
    train_p.add_argument("--max-samples", type=int, default=0, help="0 = use full dataset")
    train_p.add_argument("--lora-rank", type=int, default=16)
    train_p.add_argument("--lora-alpha", type=int, default=32)
    train_p.add_argument("--lora-dropout", type=float, default=0.05)
    train_p.add_argument("--target-modules", type=str, default="q_proj,v_proj,o_proj,gate_proj")
    train_p.add_argument("--max-seq-length", type=int, default=2048)
    train_p.add_argument("--save-steps", type=int, default=200)
    train_p.add_argument("--logging-steps", type=int, default=10)
    train_p.add_argument("--warmup-ratio", type=float, default=0.03)
    train_p.add_argument("--warmup-steps", type=int, default=None)
    # bf16 is on by default; --no-bf16 flips the same destination off.
    train_p.add_argument("--bf16", action="store_true", default=True)
    train_p.add_argument("--no-bf16", dest="bf16", action="store_false")
    train_p.add_argument("--fp16", action="store_true")
    train_p.add_argument("--bnb-compute-dtype", type=str, default="bfloat16")
    train_p.set_defaults(func=cmd_train)

    # --- infer subcommand -------------------------------------------------
    infer_p = subparsers.add_parser("infer", help="Generate with dual-task model")
    infer_p.add_argument("--adapter", type=str, required=True, help="Adapter directory")
    infer_p.add_argument(
        "--base-model",
        type=str,
        default="Qwen/Qwen2.5-Math-1.5B-Instruct",
        help="Base model (auto-detected from pipeline_meta.json if present)",
    )
    infer_p.add_argument(
        "--task",
        type=str,
        required=True,
        choices=["solve", "generate"],
        help="Task type: 'solve' for problem solving, 'generate' for question generation",
    )
    infer_p.add_argument(
        "--problem",
        type=str,
        default="",
        help="Math problem to solve (required if --task solve)",
    )
    infer_p.add_argument(
        "--prompt",
        type=str,
        default="",
        help="Question generation prompt (required if --task generate)",
    )
    infer_p.add_argument("--max-new-tokens", type=int, default=1024)
    infer_p.add_argument("--temperature", type=float, default=0.7)
    infer_p.add_argument("--top-p", type=float, default=0.95)
    infer_p.add_argument("--greedy", action="store_true", help="Use greedy decoding")
    infer_p.add_argument("--bnb-compute-dtype", type=str, default="bfloat16")
    infer_p.set_defaults(func=cmd_infer)

    return parser
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
def main() -> None:
    """Parse CLI arguments, validate task-specific flags, and dispatch."""
    cli = build_parser()
    parsed = cli.parse_args()

    # Cross-argument validation that argparse cannot express on its own:
    # each inference task requires its matching input string.
    if parsed.command == "infer":
        if parsed.task == "solve" and not parsed.problem:
            raise SystemExit("Error: --problem is required when --task solve")
        if parsed.task == "generate" and not parsed.prompt:
            raise SystemExit("Error: --prompt is required when --task generate")

    # Make project-root imports resolvable before the handler runs.
    root_str = str(ROOT)
    if root_str not in sys.path:
        sys.path.insert(0, root_str)

    parsed.func(parsed)
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
if __name__ == "__main__":
|
| 390 |
+
main()
|
scripts/eval_sft_inference.py
ADDED
|
@@ -0,0 +1,565 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Run batch inference for a trained QLoRA adapter and report quality metrics.
|
| 4 |
+
|
| 5 |
+
This helps decide whether another SFT epoch is needed before RL.
|
| 6 |
+
|
| 7 |
+
Examples
|
| 8 |
+
--------
|
| 9 |
+
# Evaluate on GSM8K test split (first 100 samples)
|
| 10 |
+
python scripts/eval_sft_inference.py \
|
| 11 |
+
--adapter checkpoints/gsm8k_sft \
|
| 12 |
+
--max-samples 100
|
| 13 |
+
|
| 14 |
+
# Evaluate on local JSONL with {question, answer} rows
|
| 15 |
+
python scripts/eval_sft_inference.py \
|
| 16 |
+
--adapter checkpoints/gsm8k_sft \
|
| 17 |
+
--source jsonl \
|
| 18 |
+
--input data/raw/gsm8k_test.jsonl \
|
| 19 |
+
--max-samples 50 \
|
| 20 |
+
--output-json reports/sft_eval.json
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
from __future__ import annotations
|
| 24 |
+
|
| 25 |
+
import argparse
|
| 26 |
+
import json
|
| 27 |
+
import os
|
| 28 |
+
import re
|
| 29 |
+
import sys
|
| 30 |
+
from dataclasses import asdict, dataclass
|
| 31 |
+
from pathlib import Path
|
| 32 |
+
from typing import Any, Optional
|
| 33 |
+
|
| 34 |
+
# Prefer classic HTTP Hub downloads by default.
|
| 35 |
+
if "HF_HUB_DISABLE_XET" not in os.environ:
|
| 36 |
+
os.environ["HF_HUB_DISABLE_XET"] = "1"
|
| 37 |
+
|
| 38 |
+
# Ensure project-root imports work when invoked as `python scripts/...`.
|
| 39 |
+
ROOT = Path(__file__).resolve().parents[1]
|
| 40 |
+
if str(ROOT) not in sys.path:
|
| 41 |
+
sys.path.insert(0, str(ROOT))
|
| 42 |
+
|
| 43 |
+
import torch
|
| 44 |
+
from datasets import load_dataset
|
| 45 |
+
from peft import PeftModel
|
| 46 |
+
from sympy import simplify
|
| 47 |
+
from sympy.parsing.sympy_parser import parse_expr
|
| 48 |
+
from tqdm.auto import tqdm
|
| 49 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 50 |
+
|
| 51 |
+
from scripts.convert_gsm8k_to_sft import parse_gsm8k_answer
|
| 52 |
+
from src.config.prompts import create_solver_messages
|
| 53 |
+
from src.sft.solution_format import extract_final_answer_numeric_str, validate_sympy_solution_format
|
| 54 |
+
from src.sft.sympy_normalize import normalize_for_parse_expr
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@dataclass
class EvalRow:
    """Per-example evaluation record serialized into the JSON report."""

    index: int  # position of the example in the evaluation set
    question: str  # problem statement sent to the model
    gold_final: str  # reference final answer parsed from the dataset
    pred_final: str  # final answer extracted from the model output ("" if none found)
    exact_match: Optional[bool]  # equivalence verdict; None when either answer is missing
    format_ok: bool  # whether output passed the SymPy solution-format validator
    step_count: int  # number of reasoning steps counted by the validator
    scratchpad_leak: bool  # True if GSM8K-style "<< >>" calculator markup leaked into output
    output_text: str  # full decoded model output
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _norm_expr(s: str) -> str:
|
| 71 |
+
s = s.strip()
|
| 72 |
+
s = s.replace("^", "**")
|
| 73 |
+
s = re.sub(r"[,$β¬Β£\s]+", "", s)
|
| 74 |
+
return s
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def _equiv_expr(a: str, b: str) -> Optional[bool]:
    """Check if two answer strings are mathematically equivalent.

    Uses the same normalization as CurriculumMathEnvironment._answers_equivalent
    so eval and training agree on what counts as "correct".

    Returns None when either side is empty (no verdict possible); falls back
    to normalized string equality when SymPy cannot parse the expressions.
    """
    if not a or not b:
        return None
    left = normalize_for_parse_expr(_norm_expr(a))
    right = normalize_for_parse_expr(_norm_expr(b))
    try:
        difference = parse_expr(left) - parse_expr(right)
        return bool(simplify(difference) == 0)
    except Exception:
        # Unparseable (e.g. non-numeric text): compare normalized strings.
        return left == right
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _iter_examples(args: argparse.Namespace) -> list[dict[str, str]]:
    """Load evaluation examples as {question, gold_final} dicts.

    Two sources are supported:
      * ``hf``    — a Hugging Face dataset split (GSM8K-style rows with
                    ``question``/``answer`` fields).
      * ``jsonl`` — a local JSONL file whose rows are either
                    ``{question, answer}`` or chat-format ``{messages}``.
    """
    rows: list[dict[str, str]] = []
    if args.source == "hf":
        ds = load_dataset(args.dataset, args.config, split=args.split)
        # Cap BEFORE iterating so we never download/decode more than needed.
        if args.max_samples > 0:
            ds = ds.select(range(min(args.max_samples, len(ds))))
        for row in ds:
            # parse_gsm8k_answer splits rationale from the final "#### N" answer.
            _, final = parse_gsm8k_answer(row["answer"])
            rows.append({"question": row["question"].strip(), "gold_final": final})
        return rows

    in_path = Path(args.input)
    if not in_path.is_file():
        raise SystemExit(f"Input JSONL not found: {in_path}")
    with in_path.open(encoding="utf-8") as f:
        for line in f:
            if args.max_samples > 0 and len(rows) >= args.max_samples:
                break
            line = line.strip()
            if not line:
                continue  # tolerate blank lines
            o = json.loads(line)
            # Row shape 1: raw GSM8K-style {question, answer}.
            if "question" in o and "answer" in o:
                _, final = parse_gsm8k_answer(o["answer"])
                rows.append({"question": o["question"].strip(), "gold_final": final})
                continue
            # Row shape 2: chat-format {messages}; the gold answer is parsed
            # out of the assistant turn, and the solver-prompt boilerplate is
            # stripped from the user turn to recover the bare problem text.
            if "messages" in o:
                user = next((m["content"] for m in o["messages"] if m.get("role") == "user"), "").strip()
                asst = next((m["content"] for m in o["messages"] if m.get("role") == "assistant"), "")
                gold = extract_final_answer_numeric_str(asst) or ""
                user = re.sub(r"^Solve the following problem\..*?Problem:\n", "", user, flags=re.S)
                rows.append({"question": user.strip(), "gold_final": gold.strip()})
                continue
            raise SystemExit("JSONL rows must contain either {question, answer} or {messages}.")
    return rows
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def _generate(
    model: Any,
    tokenizer: Any,
    problem: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    greedy: bool,
) -> str:
    """Generate one solution for *problem* and return the decoded completion.

    Builds the canonical solver chat prompt (identical system + user format
    to GRPO training) so evaluation measures the model under the exact
    distribution it was trained on. In greedy mode, `temperature`/`top_p`
    are omitted entirely so HuggingFace does not emit a warning per call
    about sampling kwargs combined with ``do_sample=False``.
    """
    messages = create_solver_messages(problem.strip())
    chat_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    encoded = tokenizer(chat_text, return_tensors="pt").to(model.device)

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": not greedy,
        "pad_token_id": tokenizer.pad_token_id,
    }
    if not greedy:
        generation_kwargs["temperature"] = temperature
        generation_kwargs["top_p"] = top_p

    with torch.no_grad():
        output = model.generate(**encoded, **generation_kwargs)

    # Keep only tokens produced after the prompt.
    prompt_len = encoded["input_ids"].shape[1]
    completion_ids = output[0, prompt_len:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def main() -> None:
    """Batch-evaluate an SFT adapter and print/save quality metrics.

    Loads the 4-bit base model + PEFT adapter, generates one solution per
    evaluation example, scores format compliance, scratchpad leakage, and
    final-answer exact match, then prints a summary and optionally writes
    a detailed JSON report.
    """
    p = argparse.ArgumentParser(description="Batch eval for SFT adapter inference.")
    p.add_argument("--adapter", type=Path, required=True, help="Adapter directory from training step.")
    p.add_argument("--base-model", type=str, default="Qwen/Qwen2.5-Math-1.5B-Instruct")
    p.add_argument("--source", choices=("hf", "jsonl"), default="hf")
    p.add_argument("--dataset", type=str, default="openai/gsm8k")
    p.add_argument("--config", type=str, default="main")
    p.add_argument("--split", type=str, default="test")
    p.add_argument("--input", type=Path, help="JSONL path for --source jsonl")
    p.add_argument("--max-samples", type=int, default=100)
    p.add_argument("--max-new-tokens", type=int, default=512)
    p.add_argument("--temperature", type=float, default=0.0)
    p.add_argument("--top-p", type=float, default=1.0)
    # Greedy decoding is the default; --no-greedy enables sampling.
    p.add_argument("--greedy", action="store_true", default=True)
    p.add_argument("--no-greedy", dest="greedy", action="store_false")
    p.add_argument("--bnb-compute-dtype", type=str, default="bfloat16")
    p.add_argument("--show-samples", type=int, default=3)
    p.add_argument("--output-json", type=Path, default=None)
    args = p.parse_args()

    if args.source == "jsonl" and not args.input:
        raise SystemExit("--input is required when --source jsonl")

    # Prefer the base model recorded at training time, when available.
    meta_path = args.adapter / "pipeline_meta.json"
    base_model = args.base_model
    if meta_path.is_file():
        meta = json.loads(meta_path.read_text(encoding="utf-8"))
        base_model = meta.get("base_model", base_model)

    rows = _iter_examples(args)
    if not rows:
        raise SystemExit("No evaluation examples loaded.")
    print(f"Loaded {len(rows)} evaluation examples.")

    # 4-bit NF4 double-quantized loading, matching the training setup.
    compute_dtype = getattr(torch, args.bnb_compute_dtype)
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )

    print(f"Loading base {base_model} + adapter {args.adapter} …")
    # Tokenizer comes from the adapter dir so any tokens added in SFT are kept.
    tokenizer = AutoTokenizer.from_pretrained(args.adapter, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    base = AutoModelForCausalLM.from_pretrained(
        base_model,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    model = PeftModel.from_pretrained(base, str(args.adapter))
    model.eval()

    results: list[EvalRow] = []
    for i, row in enumerate(rows):
        text = _generate(
            model=model,
            tokenizer=tokenizer,
            problem=row["question"],
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            greedy=args.greedy,
        )
        fmt = validate_sympy_solution_format(text)
        pred_final = extract_final_answer_numeric_str(text) or ""
        # None (not False) when either side is empty — excluded from EM rate.
        exact = _equiv_expr(pred_final, row["gold_final"])
        results.append(
            EvalRow(
                index=i,
                question=row["question"],
                gold_final=row["gold_final"],
                pred_final=pred_final,
                exact_match=exact,
                format_ok=fmt.ok,
                step_count=fmt.step_count,
                # GSM8K calculator markup should never appear in clean output.
                scratchpad_leak=("<<" in text and ">>" in text),
                output_text=text,
            )
        )
        # Echo the first few samples for quick manual inspection.
        if i < args.show_samples:
            print(f"\n=== Sample {i} ===")
            print("Q:", row["question"])
            print("Gold:", row["gold_final"])
            print("Pred:", pred_final)
            print("Format OK:", fmt.ok, "| Steps:", fmt.step_count)
            print(text)

    # --- Aggregate metrics (n > 0 is guaranteed by the earlier rows check) ---
    n = len(results)
    n_format_ok = sum(1 for r in results if r.format_ok)
    n_scratch = sum(1 for r in results if r.scratchpad_leak)
    em_scored = [r for r in results if r.exact_match is not None]
    n_em = sum(1 for r in em_scored if r.exact_match)

    print("\n=== Summary ===")
    print(f"Samples: {n}")
    print(f"Format OK: {n_format_ok}/{n} ({100.0 * n_format_ok / n:.2f}%)")
    print(f"Scratchpad leakage (<< >>): {n_scratch}/{n} ({100.0 * n_scratch / n:.2f}%)")
    if em_scored:
        print(f"Exact match (final answer): {n_em}/{len(em_scored)} ({100.0 * n_em / len(em_scored):.2f}%)")
    else:
        print("Exact match (final answer): N/A (missing gold labels)")

    # Optional machine-readable report with both summary and per-row detail.
    if args.output_json is not None:
        args.output_json.parent.mkdir(parents=True, exist_ok=True)
        payload = {
            "summary": {
                "samples": n,
                "format_ok": n_format_ok,
                "format_ok_rate": n_format_ok / n,
                "scratchpad_leakage": n_scratch,
                "scratchpad_leakage_rate": n_scratch / n,
                "exact_match_scored": len(em_scored),
                "exact_match": n_em,
                "exact_match_rate": (n_em / len(em_scored)) if em_scored else None,
            },
            "results": [asdict(r) for r in results],
        }
        args.output_json.write_text(json.dumps(payload, indent=2, ensure_ascii=False), encoding="utf-8")
        print(f"Wrote detailed report to {args.output_json}")
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def _infer_dataset_name(data_path: str) -> str:
|
| 286 |
+
"""Derive a short human-readable dataset label from the file path."""
|
| 287 |
+
stem = Path(data_path).stem.lower() # e.g. "aqua_validation", "gsm8k_test"
|
| 288 |
+
if "aqua" in stem:
|
| 289 |
+
return "AQuA-RAT"
|
| 290 |
+
if "math" in stem:
|
| 291 |
+
return "MATH"
|
| 292 |
+
if "gsm" in stem:
|
| 293 |
+
return "GSM8K"
|
| 294 |
+
return Path(data_path).stem # fallback: raw filename stem
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def evaluate_gsm8k(
|
| 298 |
+
model: Any,
|
| 299 |
+
tokenizer: Any,
|
| 300 |
+
data_path: str = "data/sft/gsm8k_test.jsonl",
|
| 301 |
+
max_samples: int = 500,
|
| 302 |
+
max_new_tokens: int = 512,
|
| 303 |
+
temperature: float = 0.0,
|
| 304 |
+
top_p: float = 1.0,
|
| 305 |
+
reward_fn: Any = None,
|
| 306 |
+
pass_at_k: int = 0,
|
| 307 |
+
dataset_name: str = "",
|
| 308 |
+
pass_at_k_temperature: float = 0.8,
|
| 309 |
+
) -> dict:
|
| 310 |
+
"""
|
| 311 |
+
Evaluate *model* on a math JSONL file using the SAME scoring
|
| 312 |
+
function used during GRPO training.
|
| 313 |
+
|
| 314 |
+
Args:
|
| 315 |
+
model : AutoModelForCausalLM (already on correct device).
|
| 316 |
+
tokenizer : Matching AutoTokenizer.
|
| 317 |
+
data_path : Path to JSONL with {question, answer} rows.
|
| 318 |
+
max_samples : Evaluation cap.
|
| 319 |
+
max_new_tokens / temperature / top_p : generation hyper-params.
|
| 320 |
+
reward_fn : callable(question: str, solution: str, gold: str) -> dict
|
| 321 |
+
Must return at minimum {"combined_score": float} and
|
| 322 |
+
optionally {"gt_match": bool, "prm_mean_score": float,
|
| 323 |
+
"sympy_score": float, "format_score": float}.
|
| 324 |
+
When supplied the primary accuracy metric becomes the
|
| 325 |
+
mean combined_score β identical to the GRPO training
|
| 326 |
+
objective β so every component (correctness, PRM step
|
| 327 |
+
quality, SymPy verification, format) contributes and
|
| 328 |
+
improvements in any of them show up immediately.
|
| 329 |
+
When None the function falls back to final-answer
|
| 330 |
+
exact-match accuracy (coarse binary).
|
| 331 |
+
|
| 332 |
+
Returns dict keys:
|
| 333 |
+
accuracy β mean combined_score per solution (or exact-match if no reward_fn)
|
| 334 |
+
combined_score β same as accuracy (alias)
|
| 335 |
+
correct_rate β fraction of solutions with gt_match == True
|
| 336 |
+
prm_mean β mean PRM step-quality score per solution
|
| 337 |
+
sympy_mean β mean SymPy verification score
|
| 338 |
+
format_mean β mean format compliance score
|
| 339 |
+
n_scored β solutions successfully scored by reward_fn
|
| 340 |
+
total β total solutions evaluated
|
| 341 |
+
# fallback (no reward_fn):
|
| 342 |
+
exact_match_rate β fraction of final answers matching gold
|
| 343 |
+
"""
|
| 344 |
+
import logging as _logging
|
| 345 |
+
_logger = _logging.getLogger(__name__)
|
| 346 |
+
|
| 347 |
+
greedy = temperature < 1e-6
|
| 348 |
+
rows: list[dict] = []
|
| 349 |
+
|
| 350 |
+
p = Path(data_path)
|
| 351 |
+
if p.is_file():
|
| 352 |
+
with p.open(encoding="utf-8") as fh:
|
| 353 |
+
for line in fh:
|
| 354 |
+
if max_samples > 0 and len(rows) >= max_samples:
|
| 355 |
+
break
|
| 356 |
+
line = line.strip()
|
| 357 |
+
if not line:
|
| 358 |
+
continue
|
| 359 |
+
obj = json.loads(line)
|
| 360 |
+
if "question" in obj and "gold_final" in obj and obj["gold_final"]:
|
| 361 |
+
# Pre-extracted format (our gsm8k_test.jsonl)
|
| 362 |
+
rows.append({"question": obj["question"].strip(), "gold_final": obj["gold_final"].strip()})
|
| 363 |
+
elif "question" in obj and "answer" in obj:
|
| 364 |
+
_, final = parse_gsm8k_answer(obj["answer"])
|
| 365 |
+
if final:
|
| 366 |
+
rows.append({"question": obj["question"].strip(), "gold_final": final})
|
| 367 |
+
elif "messages" in obj:
|
| 368 |
+
task_type = obj.get("task_type", "solve")
|
| 369 |
+
if task_type != "solve":
|
| 370 |
+
continue # skip question-generation entries
|
| 371 |
+
user = next(
|
| 372 |
+
(m["content"] for m in obj["messages"] if m.get("role") == "user"), ""
|
| 373 |
+
).strip()
|
| 374 |
+
asst = next(
|
| 375 |
+
(m["content"] for m in obj["messages"] if m.get("role") == "assistant"), ""
|
| 376 |
+
)
|
| 377 |
+
gold = extract_final_answer_numeric_str(asst) or ""
|
| 378 |
+
if not gold:
|
| 379 |
+
continue # skip entries with no parseable gold answer
|
| 380 |
+
user = re.sub(r"^Solve the following problem\..*?Problem:\n", "", user, flags=re.S)
|
| 381 |
+
rows.append({"question": user.strip(), "gold_final": gold.strip()})
|
| 382 |
+
else:
|
| 383 |
+
_logger.warning(
|
| 384 |
+
f"evaluate_gsm8k: {data_path} not found; loading openai/gsm8k from Hub."
|
| 385 |
+
)
|
| 386 |
+
try:
|
| 387 |
+
ds = load_dataset("openai/gsm8k", "main", split="test")
|
| 388 |
+
if max_samples > 0:
|
| 389 |
+
ds = ds.select(range(min(max_samples, len(ds))))
|
| 390 |
+
for row in ds:
|
| 391 |
+
_, final = parse_gsm8k_answer(row["answer"])
|
| 392 |
+
rows.append({"question": row["question"].strip(), "gold_final": final})
|
| 393 |
+
except Exception as exc:
|
| 394 |
+
_logger.error(f"Could not load GSM8K: {exc}")
|
| 395 |
+
return {"accuracy": 0.0, "correct": 0, "total": 0, "exact_match_rate": 0.0}
|
| 396 |
+
|
| 397 |
+
if not rows:
|
| 398 |
+
return {"accuracy": 0.0, "correct": 0, "total": 0, "exact_match_rate": 0.0}
|
| 399 |
+
|
| 400 |
+
correct = 0
|
| 401 |
+
total = len(rows)
|
| 402 |
+
_n_errors = 0
|
| 403 |
+
_MAX_ERROR_WARNINGS = 3
|
| 404 |
+
|
| 405 |
+
# Per-solution reward accumulators (populated when reward_fn is supplied).
|
| 406 |
+
_combined: list[float] = []
|
| 407 |
+
_gt_match: list[float] = []
|
| 408 |
+
_prm_comp: list[float] = []
|
| 409 |
+
_prm_final: list[float] = []
|
| 410 |
+
_step_acc: list[float] = [] # fraction of steps rated correct by PRM (>0.5)
|
| 411 |
+
_lccp: list[float] = [] # longest correct consecutive prefix ratio
|
| 412 |
+
_sympy_comp:list[float] = []
|
| 413 |
+
_fmt_comp: list[float] = []
|
| 414 |
+
|
| 415 |
+
# Pass@K accumulators: for each problem, did ANY of K samples get it right?
|
| 416 |
+
_pak_any_correct: list[int] = [] # 1 if any of K samples correct, else 0
|
| 417 |
+
|
| 418 |
+
_eval_label = dataset_name or _infer_dataset_name(data_path)
|
| 419 |
+
pbar = tqdm(
|
| 420 |
+
rows, total=total, desc=f"{_eval_label} eval",
|
| 421 |
+
unit="q", dynamic_ncols=True, leave=True,
|
| 422 |
+
)
|
| 423 |
+
for i, row in enumerate(pbar):
|
| 424 |
+
pred_text = ""
|
| 425 |
+
try:
|
| 426 |
+
pred_text = _generate(
|
| 427 |
+
model=model, tokenizer=tokenizer,
|
| 428 |
+
problem=row["question"],
|
| 429 |
+
max_new_tokens=max_new_tokens,
|
| 430 |
+
temperature=temperature, top_p=top_p, greedy=greedy,
|
| 431 |
+
)
|
| 432 |
+
pred_final = extract_final_answer_numeric_str(pred_text) or ""
|
| 433 |
+
if _equiv_expr(pred_final, row["gold_final"]):
|
| 434 |
+
correct += 1
|
| 435 |
+
except Exception as exc:
|
| 436 |
+
_n_errors += 1
|
| 437 |
+
if _n_errors <= _MAX_ERROR_WARNINGS:
|
| 438 |
+
_logger.warning(
|
| 439 |
+
"evaluate_gsm8k: sample %d raised %s: %s. "
|
| 440 |
+
"If all fail check that tokenizer has a chat_template.",
|
| 441 |
+
i, type(exc).__name__, exc,
|
| 442 |
+
)
|
| 443 |
+
elif _n_errors == _MAX_ERROR_WARNINGS + 1:
|
| 444 |
+
_logger.warning(
|
| 445 |
+
"evaluate_gsm8k: suppressing further errors (%d so far).",
|
| 446 |
+
_n_errors,
|
| 447 |
+
)
|
| 448 |
+
_logger.debug("Sample %d error: %s", i, exc, exc_info=True)
|
| 449 |
+
|
| 450 |
+
# ββ Pass@K: sample K solutions at T=0.8 and check if any is correct β
|
| 451 |
+
# This is the fair comparison to batch_acc during training (also K samples
|
| 452 |
+
# at T=0.8). Greedy (pass@1) is pessimistic; pass@k shows the upper bound
|
| 453 |
+
# the model can achieve with sampling, matching the training regime.
|
| 454 |
+
if pass_at_k > 1 and row.get("gold_final"):
|
| 455 |
+
_any = 0
|
| 456 |
+
for _ in range(pass_at_k):
|
| 457 |
+
try:
|
| 458 |
+
s = _generate(
|
| 459 |
+
model=model, tokenizer=tokenizer,
|
| 460 |
+
problem=row["question"],
|
| 461 |
+
max_new_tokens=max_new_tokens,
|
| 462 |
+
temperature=pass_at_k_temperature,
|
| 463 |
+
top_p=top_p, greedy=False,
|
| 464 |
+
)
|
| 465 |
+
pf = extract_final_answer_numeric_str(s) or ""
|
| 466 |
+
if _equiv_expr(pf, row["gold_final"]):
|
| 467 |
+
_any = 1
|
| 468 |
+
break
|
| 469 |
+
except Exception:
|
| 470 |
+
pass
|
| 471 |
+
_pak_any_correct.append(_any)
|
| 472 |
+
|
| 473 |
+
# ββ Apply the SAME reward function used during GRPO training ββββββββββ
|
| 474 |
+
if reward_fn is not None and pred_text:
|
| 475 |
+
try:
|
| 476 |
+
r = reward_fn(row["question"], pred_text, row["gold_final"])
|
| 477 |
+
_combined.append(float(r.get("combined_score", 0.0)))
|
| 478 |
+
_gt_match.append(1.0 if r.get("gt_match", False) else 0.0)
|
| 479 |
+
_prm_comp.append(float(r.get("prm_mean_score", 0.0)))
|
| 480 |
+
_prm_final.append(float(r.get("prm_final_score", 0.0)))
|
| 481 |
+
_step_acc.append(float(r.get("step_accuracy", 0.0)))
|
| 482 |
+
_lccp.append(float(r.get("lccp", 0.0)))
|
| 483 |
+
_sympy_comp.append(float(r.get("sympy_score", 0.0)))
|
| 484 |
+
_fmt_comp.append(float(r.get("format_score", 0.0)))
|
| 485 |
+
except Exception as rfn_exc:
|
| 486 |
+
_logger.debug("reward_fn failed for sample %d: %s", i, rfn_exc)
|
| 487 |
+
|
| 488 |
+
done = i + 1
|
| 489 |
+
# Periodically flush the CUDA allocator's free-block pool so that
|
| 490 |
+
# fragmentation from large KV-cache + PRM tensors doesn't accumulate
|
| 491 |
+
# and cause per-sample allocation time to grow throughout the run.
|
| 492 |
+
if done % 20 == 0:
|
| 493 |
+
import gc; gc.collect()
|
| 494 |
+
if torch.cuda.is_available():
|
| 495 |
+
torch.cuda.empty_cache()
|
| 496 |
+
|
| 497 |
+
# Live bar: show training-objective score when available, else acc.
|
| 498 |
+
if _combined:
|
| 499 |
+
_pf: dict = dict(
|
| 500 |
+
score=f"{sum(_combined) / len(_combined):.3f}",
|
| 501 |
+
correct=f"{sum(_gt_match):.0f}/{len(_combined)}",
|
| 502 |
+
step_acc=f"{sum(_step_acc)/len(_step_acc):.1%}" if _step_acc else "β",
|
| 503 |
+
lccp=f"{sum(_lccp)/len(_lccp):.1%}" if _lccp else "β",
|
| 504 |
+
)
|
| 505 |
+
else:
|
| 506 |
+
_pf = dict(acc=f"{correct / done:.1%}", correct=f"{correct}/{done}")
|
| 507 |
+
pbar.set_postfix(**_pf, refresh=False)
|
| 508 |
+
|
| 509 |
+
# ββ Aggregate ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 510 |
+
n_scored = len(_combined)
|
| 511 |
+
_avg = lambda lst: round(sum(lst) / len(lst), 4) if lst else 0.0
|
| 512 |
+
|
| 513 |
+
# Pass@K: fraction of problems where any of K sampled solutions was correct.
|
| 514 |
+
pass_at_k_score = _avg(_pak_any_correct) if _pak_any_correct else None
|
| 515 |
+
|
| 516 |
+
if reward_fn is not None:
|
| 517 |
+
combined_score = _avg(_combined)
|
| 518 |
+
result: dict = {
|
| 519 |
+
# PRIMARY: mean training-objective score.
|
| 520 |
+
# Formula: 0.50Γcorrect + 0.40Γprocess(prm_final, prm_mean) + 0.10Γformat
|
| 521 |
+
"accuracy": combined_score,
|
| 522 |
+
"combined_score": combined_score,
|
| 523 |
+
# PROCESS metrics β improve before correct_rate does
|
| 524 |
+
"step_accuracy": _avg(_step_acc),
|
| 525 |
+
"lccp": _avg(_lccp), # chain integrity: how far into solution stays correct
|
| 526 |
+
# Answer correctness
|
| 527 |
+
"correct_rate": _avg(_gt_match),
|
| 528 |
+
# PRM components
|
| 529 |
+
"prm_mean": _avg(_prm_comp),
|
| 530 |
+
"prm_final": _avg(_prm_final),
|
| 531 |
+
# Format / SymPy (informational)
|
| 532 |
+
"sympy_mean": _avg(_sympy_comp),
|
| 533 |
+
"format_mean": _avg(_fmt_comp),
|
| 534 |
+
"n_scored": n_scored,
|
| 535 |
+
"total": total,
|
| 536 |
+
"final_answer_correct": correct,
|
| 537 |
+
"final_answer_accuracy": correct / total if total else 0.0,
|
| 538 |
+
}
|
| 539 |
+
else:
|
| 540 |
+
_logger.warning(
|
| 541 |
+
"evaluate_gsm8k: no reward_fn provided β using final-answer accuracy. "
|
| 542 |
+
"Pass reward_fn=math_env.compute_grounded_reward for full training-objective eval."
|
| 543 |
+
)
|
| 544 |
+
fa_acc = correct / total if total else 0.0
|
| 545 |
+
result = {
|
| 546 |
+
"accuracy": fa_acc,
|
| 547 |
+
"combined_score": fa_acc,
|
| 548 |
+
"correct_rate": fa_acc,
|
| 549 |
+
"prm_mean": 0.0,
|
| 550 |
+
"sympy_mean": 0.0,
|
| 551 |
+
"format_mean": 0.0,
|
| 552 |
+
"n_scored": 0,
|
| 553 |
+
"total": total,
|
| 554 |
+
"final_answer_correct": correct,
|
| 555 |
+
"final_answer_accuracy": fa_acc,
|
| 556 |
+
}
|
| 557 |
+
# Attach pass@k if it was computed
|
| 558 |
+
if pass_at_k_score is not None:
|
| 559 |
+
result["pass_at_k"] = pass_at_k_score
|
| 560 |
+
result["pass_at_k_k"] = pass_at_k
|
| 561 |
+
return result
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
# Script entry point: run the CLI dispatcher when executed directly
# (no side effects on import).
if __name__ == "__main__":
    main()
|
scripts/gsm8k_sft_pipeline.py
ADDED
|
@@ -0,0 +1,475 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
End-to-end GSM8K pipeline: prepare JSONL β QLoRA SFT β save adapter β inference.
|
| 4 |
+
|
| 5 |
+
The trained model follows ``Step N:`` / ``Final Answer:`` formatting with SymPy-friendly
|
| 6 |
+
expressions (see ``src.agent.math_agent.SOLVER_SYSTEM_PROMPT``).
|
| 7 |
+
|
| 8 |
+
Examples
|
| 9 |
+
--------
|
| 10 |
+
# 1) Only build training JSONL from Hugging Face GSM8K
|
| 11 |
+
python scripts/gsm8k_sft_pipeline.py prepare --output data/sft/gsm8k_sft.jsonl
|
| 12 |
+
|
| 13 |
+
# 2) Fine-tune (requires GPU recommended)
|
| 14 |
+
python scripts/gsm8k_sft_pipeline.py train \\
|
| 15 |
+
--data data/sft/gsm8k_sft.jsonl \\
|
| 16 |
+
--output-dir checkpoints/gsm8k_sft
|
| 17 |
+
|
| 18 |
+
# 3) Run inference with saved adapter
|
| 19 |
+
python scripts/gsm8k_sft_pipeline.py infer \\
|
| 20 |
+
--adapter checkpoints/gsm8k_sft \\
|
| 21 |
+
--problem \"Janet has 16 eggs. She eats 3. How many are left?\"
|
| 22 |
+
|
| 23 |
+
# Full chain
|
| 24 |
+
python scripts/gsm8k_sft_pipeline.py all --output-dir checkpoints/gsm8k_sft
|
| 25 |
+
|
| 26 |
+
Dependencies: torch, transformers, peft, datasets, accelerate, bitsandbytes, trl, sympy
|
| 27 |
+
|
| 28 |
+
Tip: if downloads fail with XET / "Background writer channel closed", export ``HF_HUB_DISABLE_XET=1``
|
| 29 |
+
before running (this script sets it by default unless already set).
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
from __future__ import annotations
|
| 33 |
+
|
| 34 |
+
import os
|
| 35 |
+
|
| 36 |
+
# hf-xet can error or segfault on interrupted/large shards; classic HTTP download is more robust.
|
| 37 |
+
if "HF_HUB_DISABLE_XET" not in os.environ:
|
| 38 |
+
os.environ["HF_HUB_DISABLE_XET"] = "1"
|
| 39 |
+
|
| 40 |
+
import argparse
|
| 41 |
+
import json
|
| 42 |
+
import math
|
| 43 |
+
import subprocess
|
| 44 |
+
import sys
|
| 45 |
+
from pathlib import Path
|
| 46 |
+
|
| 47 |
+
# Project root (β¦/Maths_LLM)
|
| 48 |
+
ROOT = Path(__file__).resolve().parents[1]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def cmd_prepare(args: argparse.Namespace) -> None:
    """Build the SFT JSONL by shelling out to ``convert_gsm8k_to_sft.py``.

    Runs the converter with the requested output path and splits; when the
    source is a local JSONL file, forwards ``--source jsonl --input``.
    Optionally strips GSM8K ``<<...>>`` scratchpads from the result.
    """
    converter = ROOT / "scripts" / "convert_gsm8k_to_sft.py"
    argv = [
        sys.executable,
        str(converter),
        "--output",
        str(Path(args.output)),
        "--splits",
    ]
    argv += list(args.splits)
    if args.source == "jsonl":
        argv += ["--source", "jsonl", "--input", str(args.input)]
    print("Running:", " ".join(argv))
    # cwd=ROOT so the converter resolves project-relative paths consistently.
    subprocess.check_call(argv, cwd=str(ROOT))
    if args.strip_scratchpads:
        _rewrite_jsonl_strip_scratchpads(Path(args.output))
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _rewrite_jsonl_strip_scratchpads(jsonl_path: Path) -> None:
    """Rewrite *jsonl_path* in place, removing GSM8K ``<<...>>`` scratchpads.

    Each record's assistant message content is cleaned with
    ``strip_gsm8k_scratchpads``; when a record also carries a pre-rendered
    ``text`` field, it is rebuilt from the (cleaned) system/user/assistant
    messages.  The rewrite is atomic: output goes to a sibling ``.tmp``
    file that replaces the original only after all records are written.
    """
    from src.sft.solution_format import strip_gsm8k_scratchpads

    tmp = jsonl_path.with_suffix(".jsonl.tmp")
    n = 0
    with jsonl_path.open(encoding="utf-8") as fin, tmp.open("w", encoding="utf-8") as fout:
        for line in fin:
            # Robustness fix: tolerate blank lines (e.g. a trailing newline
            # or accidental empty record) instead of aborting the whole
            # rewrite with a JSONDecodeError.
            if not line.strip():
                continue
            o = json.loads(line)
            for m in o.get("messages", []):
                if m.get("role") == "assistant":
                    m["content"] = strip_gsm8k_scratchpads(m["content"])
            if "text" in o:
                # Rebuild the flattened prompt so it matches the cleaned
                # message contents (assumes one message per role — TODO confirm).
                sys_p = next(x["content"] for x in o["messages"] if x["role"] == "system")
                usr = next(x["content"] for x in o["messages"] if x["role"] == "user")
                asst = next(x["content"] for x in o["messages"] if x["role"] == "assistant")
                o["text"] = (
                    f"<|system|>\n{sys_p}\n<|user|>\n{usr}\n<|assistant|>\n{asst}"
                )
            fout.write(json.dumps(o, ensure_ascii=False) + "\n")
            n += 1
    tmp.replace(jsonl_path)
    print(f"Stripped <<>> scratchpads in {n} records β {jsonl_path}")
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _warmup_steps_from_ratio(
|
| 93 |
+
num_examples: int,
|
| 94 |
+
per_device_train_batch_size: int,
|
| 95 |
+
gradient_accumulation_steps: int,
|
| 96 |
+
num_train_epochs: float,
|
| 97 |
+
warmup_ratio: float,
|
| 98 |
+
) -> int:
|
| 99 |
+
"""Approximate HF Trainer optimizer steps; used to map legacy warmup_ratio β warmup_steps."""
|
| 100 |
+
if warmup_ratio <= 0:
|
| 101 |
+
return 0
|
| 102 |
+
num_batches = max(
|
| 103 |
+
1,
|
| 104 |
+
(num_examples + per_device_train_batch_size - 1) // per_device_train_batch_size,
|
| 105 |
+
)
|
| 106 |
+
num_update_steps_per_epoch = max(1, num_batches // gradient_accumulation_steps)
|
| 107 |
+
total_optimizer_steps = max(1, math.ceil(num_train_epochs * num_update_steps_per_epoch))
|
| 108 |
+
return min(total_optimizer_steps, int(total_optimizer_steps * warmup_ratio))
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def cmd_train(args: argparse.Namespace) -> None:
    """QLoRA supervised fine-tuning on a ``messages``-style JSONL dataset.

    Loads the base model 4-bit quantized (NF4 + double quant), wraps it with
    a LoRA adapter, and trains with TRL's SFTTrainer.  The adapter, tokenizer
    and a small ``pipeline_meta.json`` are saved to ``args.output_dir``.
    """
    # Heavy ML deps are imported lazily so `prepare` works without them.
    try:
        import torch
        from datasets import load_dataset
        from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
        from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
        from trl import SFTConfig, SFTTrainer
    except ImportError as e:
        raise SystemExit(
            "Missing dependency for training. Install:\n"
            "  pip install torch transformers peft datasets accelerate bitsandbytes trl sympy\n"
            f"Original error: {e}"
        ) from e

    data_path = Path(args.data)
    if not data_path.is_file():
        raise SystemExit(f"Data file not found: {data_path}")

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # e.g. "bfloat16" -> torch.bfloat16 (raises AttributeError for bad names).
    compute_dtype = getattr(torch, args.bnb_compute_dtype)
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Right padding is required for causal-LM training (loss on the tail).
    tokenizer.padding_side = "right"

    print(f"Loading model {args.model} β¦")
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        dtype=compute_dtype,
    )
    # Casts norms/embeddings and enables input grads for k-bit training.
    model = prepare_model_for_kbit_training(model)
    peft = LoraConfig(
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=list(args.target_modules.split(",")),
    )
    model = get_peft_model(model, peft)
    # KV cache is incompatible with gradient checkpointing during training.
    model.config.use_cache = False
    model.print_trainable_parameters()

    ds = load_dataset("json", data_files=str(data_path), split="train")
    if args.max_samples and args.max_samples > 0:
        ds = ds.select(range(min(args.max_samples, len(ds))))

    def formatting_func(example):
        # Render the chat messages into the model's template; no generation
        # prompt since the assistant turn (the target) is already present.
        return tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
            add_generation_prompt=False,
        )

    # Explicit --warmup-steps wins; otherwise derive from the legacy ratio.
    if args.warmup_steps is not None:
        warmup_steps = max(0, args.warmup_steps)
    else:
        warmup_steps = _warmup_steps_from_ratio(
            len(ds),
            args.batch_size,
            args.grad_accum,
            args.epochs,
            args.warmup_ratio,
        )

    sft_args = SFTConfig(
        output_dir=str(out_dir),
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.learning_rate,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        save_total_limit=3,
        # bf16 takes precedence over fp16; both require CUDA.
        bf16=args.bf16 and torch.cuda.is_available(),
        fp16=args.fp16 and torch.cuda.is_available() and not args.bf16,
        max_length=args.max_seq_length,
        warmup_steps=warmup_steps,
        lr_scheduler_type="cosine",
        report_to="none",
        gradient_checkpointing=True,
    )

    trainer = SFTTrainer(
        model=model,
        args=sft_args,
        train_dataset=ds,
        processing_class=tokenizer,
        formatting_func=formatting_func,
    )

    trainer.train()
    trainer.save_model(str(out_dir))
    tokenizer.save_pretrained(str(out_dir))

    # Record how the adapter was produced so `infer` can recover the base model.
    with (out_dir / "pipeline_meta.json").open("w", encoding="utf-8") as f:
        json.dump(
            {
                "base_model": args.model,
                "data": str(data_path),
                "lora_rank": args.lora_rank,
                "epochs": args.epochs,
            },
            f,
            indent=2,
        )
    print(f"Saved adapter and tokenizer to {out_dir}")
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def cmd_infer(args: argparse.Namespace) -> None:
    """Generate a solution for one problem using a saved LoRA adapter.

    The base model is resolved from ``pipeline_meta.json`` when present
    (falling back to ``--base-model``), loaded 4-bit, and composed with the
    adapter.  The generated text is printed and then validated against the
    ``Step N:`` / ``Final Answer:`` format.
    """
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    from src.agent.math_agent import SOLVER_SYSTEM_PROMPT

    adapter = Path(args.adapter)
    meta_path = adapter / "pipeline_meta.json"
    base_model = args.base_model
    # Prefer the base recorded at train time; a mismatched base silently
    # degrades adapter quality.
    if meta_path.is_file():
        meta = json.loads(meta_path.read_text(encoding="utf-8"))
        base_model = meta.get("base_model", base_model)

    compute_dtype = getattr(torch, args.bnb_compute_dtype)
    # Same 4-bit NF4 quantization as training for consistent behavior.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )

    # Tokenizer is loaded from the adapter dir (saved alongside it by `train`).
    tokenizer = AutoTokenizer.from_pretrained(adapter, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading base {base_model} + adapter {adapter} β¦")
    base = AutoModelForCausalLM.from_pretrained(
        base_model,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    model = PeftModel.from_pretrained(base, str(adapter))
    model.eval()

    user_content = (
        "Solve the following problem. Show your reasoning as numbered steps, "
        "then give the final numeric answer on the last line.\n\n"
        f"Problem:\n{args.problem.strip()}"
    )
    messages = [
        {"role": "system", "content": SOLVER_SYSTEM_PROMPT},
        {"role": "user", "content": user_content},
    ]
    # add_generation_prompt=True appends the assistant header so the model
    # starts a fresh answer turn.
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            do_sample=not args.greedy,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Drop the echoed prompt tokens; decode only the newly generated tail.
    gen_ids = out[0, inputs["input_ids"].shape[1] :]
    text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
    print("\n--- Generated ---\n")
    print(text)
    print("\n--- Format check ---")
    from src.sft.solution_format import validate_sympy_solution_format

    r = validate_sympy_solution_format(text)
    print(json.dumps(r.__dict__, indent=2))
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def cmd_all(args: argparse.Namespace) -> None:
    """Run the full chain: prepare, then train, then (optionally) infer.

    Inference only runs when ``--problem`` was supplied.  Each stage is fed
    a synthetic Namespace mirroring its own subcommand's arguments.
    """
    default_jsonl = ROOT / "data" / "sft" / "gsm8k_sft.jsonl"
    out_jsonl = Path(args.data) if args.data else default_jsonl

    prepare_opts = {
        "output": out_jsonl,
        "source": args.prepare_source,
        "input": args.input,
        "splits": args.splits,
        "strip_scratchpads": args.strip_scratchpads,
    }
    cmd_prepare(argparse.Namespace(**prepare_opts))

    train_opts = {
        "data": str(out_jsonl),
        "output_dir": args.output_dir,
        "model": args.model,
        "epochs": args.epochs,
        "batch_size": args.batch_size,
        "grad_accum": args.grad_accum,
        "learning_rate": args.learning_rate,
        "max_samples": args.max_samples,
        "lora_rank": args.lora_rank,
        "lora_alpha": args.lora_alpha,
        "lora_dropout": args.lora_dropout,
        "target_modules": args.target_modules,
        "max_seq_length": args.max_seq_length,
        "save_steps": args.save_steps,
        "logging_steps": args.logging_steps,
        "warmup_ratio": args.warmup_ratio,
        "warmup_steps": args.warmup_steps,
        "bf16": args.bf16,
        "fp16": args.fp16,
        "bnb_compute_dtype": args.bnb_compute_dtype,
    }
    cmd_train(argparse.Namespace(**train_opts))

    if not args.problem:
        return
    infer_opts = {
        "adapter": Path(args.output_dir),
        "base_model": args.model,
        "problem": args.problem,
        "max_new_tokens": args.max_new_tokens,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "greedy": args.greedy,
        "bnb_compute_dtype": args.bnb_compute_dtype,
    }
    cmd_infer(argparse.Namespace(**infer_opts))
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
def build_parser() -> argparse.ArgumentParser:
    """Construct the CLI with its four subcommands: prepare / train / infer / all.

    Each subparser stores its handler in ``func`` so ``main`` can dispatch
    with ``args.func(args)``.
    """
    parser = argparse.ArgumentParser(description="GSM8K SFT pipeline (prepare / train / infer / all)")
    subcommands = parser.add_subparsers(dest="command", required=True)

    # -- prepare: wrap convert_gsm8k_to_sft.py -------------------------------
    prep_p = subcommands.add_parser("prepare", help="Run convert_gsm8k_to_sft.py")
    prep_p.add_argument("--output", type=str, default=str(ROOT / "data" / "sft" / "gsm8k_sft.jsonl"))
    prep_p.add_argument("--source", choices=("hf", "jsonl"), default="hf")
    prep_p.add_argument("--input", type=str, help="JSONL path for --source jsonl")
    prep_p.add_argument("--splits", nargs="+", default=["train", "test"])
    prep_p.add_argument("--strip-scratchpads", action="store_true",
                        help="Remove GSM8K <<...>> traces from assistant text after conversion.")
    prep_p.set_defaults(func=cmd_prepare)

    # -- train: QLoRA SFT ----------------------------------------------------
    train_p = subcommands.add_parser("train", help="QLoRA SFT on JSONL with messages field")
    train_p.add_argument("--data", type=str, required=True, help="JSONL from prepare step")
    train_p.add_argument("--output-dir", type=str, required=True)
    train_p.add_argument("--model", type=str, default="Qwen/Qwen2.5-Math-1.5B-Instruct")
    train_p.add_argument("--epochs", type=float, default=1.0)
    train_p.add_argument("--batch-size", type=int, default=1)
    train_p.add_argument("--grad-accum", type=int, default=8)
    train_p.add_argument("--learning-rate", type=float, default=2e-4)
    train_p.add_argument("--max-samples", type=int, default=0, help="0 = use full dataset")
    train_p.add_argument("--lora-rank", type=int, default=16)
    train_p.add_argument("--lora-alpha", type=int, default=32)
    train_p.add_argument("--lora-dropout", type=float, default=0.05)
    train_p.add_argument("--target-modules", type=str, default="q_proj,v_proj,o_proj,gate_proj")
    train_p.add_argument("--max-seq-length", type=int, default=2048)
    train_p.add_argument("--save-steps", type=int, default=200)
    train_p.add_argument("--logging-steps", type=int, default=10)
    train_p.add_argument("--warmup-ratio", type=float, default=0.03,
                         help="Used only if --warmup-steps is not set; converted to warmup_steps.")
    train_p.add_argument("--warmup-steps", type=int, default=None,
                         help="LR warmup steps; if set, overrides --warmup-ratio.")
    # bf16 defaults on; --no-bf16 flips the same destination off.
    train_p.add_argument("--bf16", action="store_true", default=True)
    train_p.add_argument("--no-bf16", dest="bf16", action="store_false")
    train_p.add_argument("--fp16", action="store_true")
    train_p.add_argument("--bnb-compute-dtype", type=str, default="bfloat16")
    train_p.set_defaults(func=cmd_train)

    # -- infer: single-problem generation with a saved adapter ---------------
    infer_p = subcommands.add_parser("infer", help="Generate with saved adapter")
    infer_p.add_argument("--adapter", type=str, required=True, help="Directory from train step")
    infer_p.add_argument("--base-model", type=str, default="Qwen/Qwen2.5-Math-1.5B-Instruct",
                         help="Must match base used in training if no pipeline_meta.json")
    infer_p.add_argument("--problem", type=str, required=True)
    infer_p.add_argument("--max-new-tokens", type=int, default=1024)
    infer_p.add_argument("--temperature", type=float, default=0.7)
    infer_p.add_argument("--top-p", type=float, default=0.95)
    infer_p.add_argument("--greedy", action="store_true")
    infer_p.add_argument("--bnb-compute-dtype", type=str, default="bfloat16")
    infer_p.set_defaults(func=cmd_infer)

    # -- all: prepare + train (+ infer when --problem is set) ----------------
    all_p = subcommands.add_parser("all", help="prepare + train [+ infer if --problem]")
    all_p.add_argument("--data", type=str, default=None,
                       help="Output JSONL path (default data/sft/gsm8k_sft.jsonl)")
    all_p.add_argument("--prepare-source", choices=("hf", "jsonl"), default="hf")
    all_p.add_argument("--input", type=str, help="For jsonl prepare")
    all_p.add_argument("--splits", nargs="+", default=["train", "test"])
    all_p.add_argument("--strip-scratchpads", action="store_true")
    all_p.add_argument("--output-dir", type=str, required=True)
    all_p.add_argument("--model", type=str, default="Qwen/Qwen2.5-Math-1.5B-Instruct")
    all_p.add_argument("--epochs", type=float, default=1.0)
    all_p.add_argument("--batch-size", type=int, default=1)
    all_p.add_argument("--grad-accum", type=int, default=8)
    all_p.add_argument("--learning-rate", type=float, default=2e-4)
    all_p.add_argument("--max-samples", type=int, default=0)
    all_p.add_argument("--lora-rank", type=int, default=16)
    all_p.add_argument("--lora-alpha", type=int, default=32)
    all_p.add_argument("--lora-dropout", type=float, default=0.05)
    all_p.add_argument("--target-modules", type=str, default="q_proj,v_proj,o_proj,gate_proj")
    all_p.add_argument("--max-seq-length", type=int, default=2048)
    all_p.add_argument("--save-steps", type=int, default=200)
    all_p.add_argument("--logging-steps", type=int, default=10)
    all_p.add_argument("--warmup-ratio", type=float, default=0.03,
                       help="Used only if --warmup-steps is not set; converted to warmup_steps.")
    all_p.add_argument("--warmup-steps", type=int, default=None,
                       help="LR warmup steps; if set, overrides --warmup-ratio.")
    all_p.add_argument("--bf16", action="store_true", default=True)
    all_p.add_argument("--no-bf16", dest="bf16", action="store_false")
    all_p.add_argument("--fp16", action="store_true")
    all_p.add_argument("--bnb-compute-dtype", type=str, default="bfloat16")
    all_p.add_argument("--problem", type=str, default="", help="If set, run infer after train")
    all_p.add_argument("--max-new-tokens", type=int, default=1024)
    all_p.add_argument("--temperature", type=float, default=0.7)
    all_p.add_argument("--top-p", type=float, default=0.95)
    all_p.add_argument("--greedy", action="store_true")
    all_p.set_defaults(func=cmd_all)

    return parser
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
def main() -> None:
    """CLI entry point: parse arguments, make ROOT importable, dispatch."""
    args = build_parser().parse_args()
    # Subcommands import ``src.*`` lazily; ensure the project root is on
    # sys.path before any handler runs.
    root_str = str(ROOT)
    if root_str not in sys.path:
        sys.path.insert(0, root_str)
    args.func(args)


if __name__ == "__main__":
    main()
|
scripts/launch_grpo.sh
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
set -euo pipefail
|
| 2 |
+
|
| 3 |
+
# ββ Flash-Attention 2 install (if missing) ββββββββββββββββββββββββββββββββββββ
|
| 4 |
+
# flash-attn requires (torch version, CUDA version, Python version) alignment.
|
| 5 |
+
# MAX_JOBS caps parallel compilation; prebuilt wheel installs in <30 s.
|
| 6 |
+
# In the prior run (grpo_20260425_151304), flash-attn was absent β SDPA fallback
|
| 7 |
+
# β iter times of 262-330 s once question-gen started (vs ~150 s with Flash).
|
| 8 |
+
if ! python -c "import flash_attn; assert int(flash_attn.__version__.split('.')[0]) >= 2" 2>/dev/null; then
|
| 9 |
+
echo "[launch] flash-attn not found or < v2 β installing now β¦"
|
| 10 |
+
MAX_JOBS=4 pip install flash-attn --no-build-isolation -q
|
| 11 |
+
echo "[launch] flash-attn installed."
|
| 12 |
+
else
|
| 13 |
+
FLASH_VER=$(python -c "import flash_attn; print(flash_attn.__version__)" 2>/dev/null)
|
| 14 |
+
echo "[launch] flash-attn ${FLASH_VER} already installed β skipping install."
|
| 15 |
+
fi
|
| 16 |
+
|
| 17 |
+
# ββ GPU / allocator βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 18 |
+
export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0}
|
| 19 |
+
# expandable_segments: recovers 2-4 GB fragmented VRAM during long Flash+HF runs
|
| 20 |
+
export PYTORCH_CUDA_ALLOC_CONF=${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}
|
| 21 |
+
|
| 22 |
+
# ββ CPU / threading βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 23 |
+
export OMP_NUM_THREADS=${OMP_NUM_THREADS:-8}
|
| 24 |
+
export MKL_NUM_THREADS=${MKL_NUM_THREADS:-8}
|
| 25 |
+
export TOKENIZERS_PARALLELISM=${TOKENIZERS_PARALLELISM:-false}
|
| 26 |
+
|
| 27 |
+
# ββ Triton / Flash-Attn compilation cache βββββββββββββββββββββββββββββββββββββ
|
| 28 |
+
# Persists JIT kernels across runs β avoids ~30 s recompile each launch.
|
| 29 |
+
export TRITON_CACHE_DIR=${TRITON_CACHE_DIR:-/tmp/triton_cache}
|
| 30 |
+
export FLASH_ATTENTION_SKIP_CUDA_BUILD=${FLASH_ATTENTION_SKIP_CUDA_BUILD:-FALSE}
|
| 31 |
+
|
| 32 |
+
# ββ HuggingFace hub robustness ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 33 |
+
export HF_HUB_DISABLE_XET=${HF_HUB_DISABLE_XET:-1}
|
| 34 |
+
export HF_HUB_ENABLE_HF_TRANSFER=${HF_HUB_ENABLE_HF_TRANSFER:-0}
|
| 35 |
+
export TRANSFORMERS_VERBOSITY=${TRANSFORMERS_VERBOSITY:-warning}
|
| 36 |
+
|
| 37 |
+
# ββ Python path βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 38 |
+
export PYTHONPATH="${PYTHONPATH:-}:$(pwd)"
|
| 39 |
+
|
| 40 |
+
# ββ Pre-flight: GPU info βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 41 |
+
if command -v nvidia-smi >/dev/null 2>&1; then
|
| 42 |
+
echo "βββ nvidia-smi βββββββββββββββββββββββββββββββββββββββββββββββββββ"
|
| 43 |
+
nvidia-smi --query-gpu=name,memory.total,memory.free,driver_version \
|
| 44 |
+
--format=csv,noheader || true
|
| 45 |
+
echo "ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ"
|
| 46 |
+
fi
|
| 47 |
+
|
| 48 |
+
# ββ Confirm attention backend βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 49 |
+
python - <<'PYEOF'
|
| 50 |
+
import sys; sys.path.insert(0, '.')
|
| 51 |
+
from src.utils.attn_backend import select_attn_implementation
|
| 52 |
+
impl = select_attn_implementation()
|
| 53 |
+
tag = {
|
| 54 |
+
"flash_attention_2": "FAST β Flash-Attn 2 active (O(T) memory, ~1.5-2Γ faster)",
|
| 55 |
+
"sdpa": "OK β SDPA active (install flash-attn for ~2Γ speedup)",
|
| 56 |
+
"eager": "SLOW β Eager fallback (install flash-attn for best speed)",
|
| 57 |
+
}.get(impl, impl)
|
| 58 |
+
print(f"[launch] attn_backend = {tag}")
|
| 59 |
+
PYEOF
|
| 60 |
+
|
| 61 |
+
# ββ Log tee βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 62 |
+
RUN_NAME="grpo_$(date +%Y%m%d_%H%M%S)"
|
| 63 |
+
LOG_DIR="logs/grpo"
|
| 64 |
+
mkdir -p "$LOG_DIR"
|
| 65 |
+
LOG_FILE="$LOG_DIR/${RUN_NAME}.log"
|
| 66 |
+
|
| 67 |
+
echo "[launch] run_name = $RUN_NAME"
|
| 68 |
+
echo "[launch] base_model = checkpoints/dual_task_v1"
|
| 69 |
+
echo "[launch] train_data = data/sft/gsm8k_sft.jsonl + data/math/math_numeric.jsonl"
|
| 70 |
+
echo "[launch] eval_data = data/sft/gsm8k_test.jsonl"
|
| 71 |
+
echo "[launch] log_file = $LOG_FILE"
|
| 72 |
+
echo "[launch] architecture = Two-phase self-play (K_q=2, K=10, N=20)"
|
| 73 |
+
echo "[launch] fixes_applied = min-warmupβ12, selfplay-gt-threshβ0.65, kl-coefβ0.06,"
|
| 74 |
+
echo "[launch] math-ramp-startβ18, group-sizeβ10, num-itersβ60"
|
| 75 |
+
echo "[launch] wall-time β 3.3 h (Flash active) / 4.5 h (SDPA fallback)"
|
| 76 |
+
|
| 77 |
+
# ββ Train βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 78 |
+
python -u scripts/run_grpo_training.py \
|
| 79 |
+
--base-model checkpoints/dual_task_v1 \
|
| 80 |
+
--output-dir checkpoints/grpo \
|
| 81 |
+
--gsm8k-data data/sft/gsm8k_sft.jsonl \
|
| 82 |
+
--eval-data-path data/sft/gsm8k_test.jsonl \
|
| 83 |
+
\
|
| 84 |
+
--num-iterations 60 \
|
| 85 |
+
--group-size 10 \
|
| 86 |
+
--q-group-size 2 \
|
| 87 |
+
--questions-per-iter 20 \
|
| 88 |
+
\
|
| 89 |
+
--learning-rate 5e-6 \
|
| 90 |
+
--max-new-tokens 1000 \
|
| 91 |
+
--temperature 0.8 \
|
| 92 |
+
--max-grad-norm 0.5 \
|
| 93 |
+
--clip-eps 0.2 \
|
| 94 |
+
--kl-coef 0.06 \
|
| 95 |
+
--warmup-iters 8 \
|
| 96 |
+
--min-lr-ratio 0.1 \
|
| 97 |
+
\
|
| 98 |
+
--difficulty-alpha 3.5 \
|
| 99 |
+
--self-play-ratio 0.70 \
|
| 100 |
+
\
|
| 101 |
+
--math-mix-ratio 0.30 \
|
| 102 |
+
--math-mix-ratio-late 0.50 \
|
| 103 |
+
--math-ramp-start 18 \
|
| 104 |
+
--math-max-difficulty 3 \
|
| 105 |
+
\
|
| 106 |
+
--overlong-filter \
|
| 107 |
+
--min-warmup 12 \
|
| 108 |
+
--selfplay-gt-thresh 0.65 \
|
| 109 |
+
--selfplay-grounded-thresh 0.65 \
|
| 110 |
+
--selfplay-step-thresh 0.68 \
|
| 111 |
+
--selfplay-ramp-iters 28 \
|
| 112 |
+
--grounded-floor 0.55 \
|
| 113 |
+
\
|
| 114 |
+
--extractor-model Qwen/Qwen2.5-0.5B-Instruct \
|
| 115 |
+
--extraction-cache data/extraction_cache.json \
|
| 116 |
+
\
|
| 117 |
+
--eval-every 5 \
|
| 118 |
+
--eval-max-samples 150 \
|
| 119 |
+
--eval-max-new-tokens 1000 \
|
| 120 |
+
--eval-pass-at-k 0 \
|
| 121 |
+
--save-every 5 \
|
| 122 |
+
--keep-last 4 \
|
| 123 |
+
\
|
| 124 |
+
--use-prm \
|
| 125 |
+
--prm-model Qwen/Qwen2.5-Math-PRM-7B \
|
| 126 |
+
--run-name "$RUN_NAME" \
|
| 127 |
+
"$@" 2>&1 | tee "$LOG_FILE"
|
scripts/plot_grpo_run.py
ADDED
|
@@ -0,0 +1,425 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Generate demo-quality plots from a completed (or in-progress) GRPO run.
|
| 4 |
+
|
| 5 |
+
Usage
|
| 6 |
+
-----
|
| 7 |
+
# from the run output directory
|
| 8 |
+
python scripts/plot_grpo_run.py checkpoints/grpo/<run_name>/metrics.jsonl
|
| 9 |
+
|
| 10 |
+
# auto-discover the latest run
|
| 11 |
+
python scripts/plot_grpo_run.py --latest
|
| 12 |
+
|
| 13 |
+
# custom output directory
|
| 14 |
+
python scripts/plot_grpo_run.py metrics.jsonl --out-dir plots/my_run
|
| 15 |
+
|
| 16 |
+
Output
|
| 17 |
+
------
|
| 18 |
+
Six PNG files saved next to the JSONL (or --out-dir if given):
|
| 19 |
+
|
| 20 |
+
01_training_objective.png β combined_score vs iteration (PRIMARY demo plot)
|
| 21 |
+
02_reward_components.png β 4-panel breakdown: correct / PRM / SymPy / format
|
| 22 |
+
03_training_dynamics.png β GRPO loss + batch reward + batch accuracy
|
| 23 |
+
04_reward_vs_eval.png β training reward vs eval score on same axis
|
| 24 |
+
05_component_area.png β stacked-area chart of the 4 weighted components
|
| 25 |
+
06_summary_card.png β single-panel card: all key metrics in one view
|
| 26 |
+
|
| 27 |
+
All figures use a clean dark-on-white academic style. They are saved at
|
| 28 |
+
300 dpi so they look sharp in slides and posters.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
from __future__ import annotations
|
| 32 |
+
|
| 33 |
+
import argparse
|
| 34 |
+
import json
|
| 35 |
+
import sys
|
| 36 |
+
from pathlib import Path
|
| 37 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 38 |
+
|
| 39 |
+
import matplotlib
|
| 40 |
+
matplotlib.use("Agg") # headless β no display needed on training servers
|
| 41 |
+
import matplotlib.pyplot as plt
|
| 42 |
+
import matplotlib.ticker as mtick
|
| 43 |
+
import numpy as np
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# ββ Style ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 47 |
+
|
| 48 |
+
# Fixed colour per metric series, shared by all six plots in this script.
# Keys match the metric families plotted below; values are hex RGB.
PALETTE = {
    "combined": "#2563EB",  # blue — training objective / combined score
    "correct": "#16A34A",  # green — correctness (gt_match)
    "prm": "#DC2626",  # red — PRM step quality
    "sympy": "#D97706",  # amber — SymPy verification
    "fmt": "#7C3AED",  # violet — format compliance
    "reward": "#0891B2",  # cyan — mean batch reward
    "loss": "#64748B",  # slate — GRPO loss
    "batch_acc": "#059669",  # emerald — batch accuracy
}
|
| 58 |
+
|
| 59 |
+
# Global Matplotlib defaults for every figure in this script: clean academic
# look (no top/right spines, faint dashed grid), 150 dpi on screen and
# 300 dpi when saving so figures stay sharp in slides and posters.
plt.rcParams.update({
    "figure.dpi": 150,
    "savefig.dpi": 300,
    "font.family": "DejaVu Sans",
    "axes.spines.top": False,
    "axes.spines.right": False,
    "axes.grid": True,
    "grid.alpha": 0.3,
    "grid.linestyle": "--",
    "axes.labelsize": 11,
    "axes.titlesize": 13,
    "legend.fontsize": 9,
    "xtick.labelsize": 9,
    "ytick.labelsize": 9,
})
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# ββ Data loading βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 77 |
+
|
| 78 |
+
def _load(path: Path) -> List[Dict[str, Any]]:
|
| 79 |
+
rows = []
|
| 80 |
+
with path.open(encoding="utf-8") as fh:
|
| 81 |
+
for line in fh:
|
| 82 |
+
line = line.strip()
|
| 83 |
+
if line:
|
| 84 |
+
rows.append(json.loads(line))
|
| 85 |
+
return rows
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _field(rows: List[Dict], key: str) -> Tuple[List[int], List[float]]:
|
| 89 |
+
"""Return (iterations, values) for rows that have a non-empty key."""
|
| 90 |
+
iters, vals = [], []
|
| 91 |
+
for r in rows:
|
| 92 |
+
v = r.get(key)
|
| 93 |
+
if v is not None and v != "" and not (isinstance(v, float) and np.isnan(v)):
|
| 94 |
+
try:
|
| 95 |
+
iters.append(int(r["iteration"]))
|
| 96 |
+
vals.append(float(v))
|
| 97 |
+
except (TypeError, ValueError):
|
| 98 |
+
pass
|
| 99 |
+
return iters, vals
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# ββ Individual plots βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 103 |
+
|
| 104 |
+
def plot_training_objective(rows: List[Dict], out: Path) -> None:
    """Plot 01: combined_score over iterations — the primary demo plot."""
    iters, scores = _field(rows, "combined_score")
    if not iters:
        return

    fig, ax = plt.subplots(figsize=(9, 5))
    ax.plot(iters, scores, color=PALETTE["combined"], linewidth=2.5,
            marker="o", markersize=5, label="Training-objective score")
    ax.fill_between(iters, scores, alpha=0.12, color=PALETTE["combined"])

    # Call out the first and last eval points so the delta reads at a glance.
    for idx in (0, -1):
        ax.annotate(f"{scores[idx]:.3f}", (iters[idx], scores[idx]),
                    textcoords="offset points", xytext=(8, 6),
                    fontsize=8, color=PALETTE["combined"])

    ax.set_xlabel("Iteration")
    ax.set_ylabel("Score (0 β 1)")
    ax.set_title(
        "GRPO Training β Combined Reward Score\n"
        "0.60 Γ correct + 0.15 Γ PRM + 0.15 Γ SymPy + 0.10 Γ format",
        fontsize=12,
    )
    ax.set_ylim(0, 1.05)
    ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))
    ax.legend(loc="lower right")
    fig.tight_layout()
    fig.savefig(out)
    plt.close(fig)
    print(f" saved {out.name}")
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def plot_reward_components(rows: List[Dict], out: Path) -> None:
    """Plot 02: four-panel breakdown of each reward component.

    Each panel shows one component's eval-time trajectory, its weight in the
    combined score, and the net change (Ξ) from first to last eval point.
    """
    specs = [
        ("correct_rate", "correct", "Correctness (gt_match)", "60 %"),
        ("prm_mean", "prm", "PRM Step Quality", "15 %"),
        ("sympy_mean", "sympy", "SymPy Verification", "15 %"),
        ("format_mean", "fmt", "Format Compliance", "10 %"),
    ]

    fig, axes = plt.subplots(2, 2, figsize=(12, 7), sharex=False)
    axes = axes.flatten()

    for ax, (key, pal, title, weight) in zip(axes, specs):
        xi, xv = _field(rows, key)
        if not xi:
            ax.set_visible(False)
            continue
        ax.plot(xi, xv, color=PALETTE[pal], linewidth=2,
                marker="o", markersize=4)
        ax.fill_between(xi, xv, alpha=0.12, color=PALETTE[pal])
        # BUGFIX: the old code prepended its own "+" on top of the sign that
        # the ":+.1%" format spec already emits, printing "Ξ=++12.3%" for
        # gains; it also called set_title twice. Set it once, with the delta.
        delta = xv[-1] - xv[0]
        ax.set_title(f"{title} (weight {weight}) Ξ={delta:+.1%}", fontsize=10)
        ax.set_xlabel("Iteration")
        ax.set_ylabel("Score")
        ax.set_ylim(0, 1.05)
        ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))

    fig.suptitle("Reward Component Breakdown over Training", fontsize=13, y=1.01)
    fig.tight_layout()
    fig.savefig(out, bbox_inches="tight")
    plt.close(fig)
    print(f" saved {out.name}")
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def plot_training_dynamics(rows: List[Dict], out: Path) -> None:
    """Plot 03: GRPO loss, mean batch reward and rollout accuracy, stacked."""
    fig, axes = plt.subplots(3, 1, figsize=(10, 8), sharex=True)

    # (metric key, colour, y-label, panel title, show as percentage axis)
    panels = [
        ("loss", PALETTE["loss"], "GRPO Loss", "Training Loss", False),
        ("mean_reward", PALETTE["reward"], "Reward", "Mean Batch Reward", True),
        ("batch_accuracy", PALETTE["batch_acc"], "Accuracy",
         "Batch Accuracy (training rollouts)", True),
    ]
    for ax, (key, color, ylabel, title, as_pct) in zip(axes, panels):
        xs, ys = _field(rows, key)
        if not ys:
            continue
        ax.plot(xs, ys, color=color, linewidth=1.8)
        ax.fill_between(xs, ys, alpha=0.1, color=color)
        ax.set_ylabel(ylabel)
        ax.set_title(title, fontsize=11)
        if as_pct:
            ax.set_ylim(0, 1.05)
            ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))
        else:
            # Zero reference makes sign flips in the GRPO loss easy to spot.
            ax.axhline(0, color="black", linewidth=0.8, linestyle="--", alpha=0.4)

    for ax in axes:
        ax.set_xlabel("Iteration")

    fig.suptitle("GRPO Training Dynamics", fontsize=13)
    fig.tight_layout()
    fig.savefig(out)
    plt.close(fig)
    print(f" saved {out.name}")
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def plot_reward_vs_eval(rows: List[Dict], out: Path) -> None:
    """Plot 04: overlay noisy training reward with the sparse held-out eval score."""
    train_it, train_r = _field(rows, "mean_reward")
    eval_it, eval_s = _field(rows, "combined_score")

    fig, ax = plt.subplots(figsize=(10, 5))

    if train_r:
        ax.plot(train_it, train_r, color=PALETTE["reward"], linewidth=1.4,
                alpha=0.7, label="Batch reward (training)")
        ax.fill_between(train_it, train_r, alpha=0.06, color=PALETTE["reward"])

    if eval_s:
        ax.plot(eval_it, eval_s, color=PALETTE["combined"], linewidth=2.5,
                marker="D", markersize=6, label="Eval score (held-out GSM8K)")
        # Label every eval point: they are sparse enough not to clutter.
        for x, y in zip(eval_it, eval_s):
            ax.annotate(f"{y:.3f}", (x, y), textcoords="offset points",
                        xytext=(0, 8), ha="center", fontsize=7,
                        color=PALETTE["combined"])

    ax.set_xlabel("Iteration")
    ax.set_ylabel("Score (0 β 1)")
    ax.set_ylim(0, 1.05)
    ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))
    ax.set_title("Training Reward vs Held-Out Eval Score", fontsize=12)
    ax.legend()
    fig.tight_layout()
    fig.savefig(out)
    plt.close(fig)
    print(f" saved {out.name}")
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def plot_component_area(rows: List[Dict], out: Path) -> None:
    """Plot 05: stacked area of the four weighted components vs combined score."""
    ei, ev_combined = _field(rows, "combined_score")
    if not ei:
        return

    weights = {"correct": 0.60, "prm": 0.15, "sympy": 0.15, "fmt": 0.10}
    keys = {"correct": "correct_rate", "prm": "prm_mean",
            "sympy": "sympy_mean", "fmt": "format_mean"}

    # Align weighted per-component series to the sorted eval iterations;
    # missing components contribute zero height in the stack.
    eval_iter_set = set(ei)
    eval_iters = sorted(eval_iter_set)
    row_by_iter: Dict[int, Dict] = {
        r["iteration"]: r for r in rows if r["iteration"] in eval_iter_set
    }

    aligned: Dict[str, List[float]] = {comp: [] for comp in ("correct", "prm", "sympy", "fmt")}
    for it in eval_iters:
        row = row_by_iter.get(it, {})
        for comp, field in keys.items():
            raw = row.get(field)
            if raw is not None and raw != "":
                aligned[comp].append(float(raw) * weights[comp])
            else:
                aligned[comp].append(0.0)

    x = np.array(eval_iters)
    stack = np.array([aligned[c] for c in ("correct", "prm", "sympy", "fmt")])

    fig, ax = plt.subplots(figsize=(10, 5))
    labels = ["Correct (Γ0.60)", "PRM (Γ0.15)", "SymPy (Γ0.15)", "Format (Γ0.10)"]
    colors = [PALETTE[c] for c in ("correct", "prm", "sympy", "fmt")]
    ax.stackplot(x, stack, labels=labels, colors=colors, alpha=0.75)

    # Dashed reference line: the stack should sum to the combined score.
    ax.plot(x, ev_combined, color="black", linewidth=1.5,
            linestyle="--", label="Combined score", zorder=5)

    ax.set_xlabel("Iteration")
    ax.set_ylabel("Weighted contribution to score")
    ax.set_ylim(0, 1.0)
    ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))
    ax.set_title("Contribution of Each Reward Component (Stacked)", fontsize=12)
    ax.legend(loc="lower right", ncol=2)
    fig.tight_layout()
    fig.savefig(out)
    plt.close(fig)
    print(f" saved {out.name}")
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def plot_summary_card(rows: List[Dict], run_name: str, out: Path) -> None:
    """Plot 06: six-panel overview card of every key metric (poster / slide).

    Parameters
    ----------
    rows : parsed metrics.jsonl records
    run_name : label shown in the suptitle
    out : destination PNG path
    """
    ei, ev = _field(rows, "combined_score")
    _, crv = _field(rows, "correct_rate")
    _, prmv = _field(rows, "prm_mean")
    _, syv = _field(rows, "sympy_mean")
    _, fmv = _field(rows, "format_mean")
    # Fix: "loss" was extracted twice and "mean_reward" was extracted but
    # never plotted; extract loss once and drop the unused series.
    li, lv = _field(rows, "loss")
    # NOTE(review): panels 1-4 reuse the combined-score iterations `ei`;
    # this assumes every eval row logs all components — confirm upstream.

    fig, axes = plt.subplots(2, 3, figsize=(15, 8))
    axes = axes.flatten()

    def _panel(ax, iters, vals, color, title, pct=True):
        # Hide panels whose metric was never logged so the grid stays tidy.
        if not iters:
            ax.set_visible(False)
            return
        ax.plot(iters, vals, color=color, linewidth=2, marker="o", markersize=4)
        ax.fill_between(iters, vals, alpha=0.12, color=color)
        ax.set_title(title, fontsize=11, fontweight="bold")
        ax.set_xlabel("Iteration", fontsize=9)
        if pct:
            ax.set_ylim(0, 1.05)
            ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))
        if vals:
            ax.annotate(f"{vals[-1]:.3f}", (iters[-1], vals[-1]),
                        textcoords="offset points", xytext=(6, 4),
                        fontsize=8, color=color)

    _panel(axes[0], ei, ev, PALETTE["combined"], "Training-Objective Score")
    _panel(axes[1], ei, crv, PALETTE["correct"], "Correctness Rate")
    _panel(axes[2], ei, prmv, PALETTE["prm"], "PRM Step Quality")
    _panel(axes[3], ei, syv, PALETTE["sympy"], "SymPy Verification")
    _panel(axes[4], ei, fmv, PALETTE["fmt"], "Format Compliance")
    _panel(axes[5], li, lv, PALETTE["loss"], "GRPO Loss", pct=False)

    fig.suptitle(f"GRPO Training Summary β {run_name}", fontsize=14, fontweight="bold")
    fig.tight_layout()
    fig.savefig(out, bbox_inches="tight")
    plt.close(fig)
    print(f" saved {out.name}")
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
# ββ CLI ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 347 |
+
|
| 348 |
+
def find_latest_metrics() -> Optional[Path]:
    """Return the newest metrics.jsonl under checkpoints/grpo/, or None."""
    root = Path("checkpoints/grpo")
    if not root.exists():
        return None
    by_age = sorted(root.rglob("metrics.jsonl"), key=lambda p: p.stat().st_mtime)
    if not by_age:
        return None
    return by_age[-1]
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def generate_plots(metrics_path: Path, out_dir: Optional[Path] = None) -> Path:
    """Render all six demo plots for *metrics_path*; return the output directory."""
    rows = _load(metrics_path)
    if not rows:
        print(f"[plot] No data in {metrics_path}", file=sys.stderr)
        return metrics_path.parent

    target = out_dir if out_dir is not None else metrics_path.parent / "plots"
    target.mkdir(parents=True, exist_ok=True)

    # The run directory's name doubles as the run label on the summary card.
    run_name = metrics_path.parent.name

    print(f"[plot] Generating plots for run '{run_name}' ({len(rows)} iterations)")
    print(f"[plot] Output β {target}")

    renderers = [
        (plot_training_objective, "01_training_objective.png"),
        (plot_reward_components, "02_reward_components.png"),
        (plot_training_dynamics, "03_training_dynamics.png"),
        (plot_reward_vs_eval, "04_reward_vs_eval.png"),
        (plot_component_area, "05_component_area.png"),
    ]
    for render, fname in renderers:
        render(rows, target / fname)
    plot_summary_card(rows, run_name, target / "06_summary_card.png")

    print(f"[plot] Done β {len(list(target.glob('*.png')))} PNGs in {target}")
    return target
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
def main() -> None:
    """CLI entry point: resolve the metrics path, then render all plots."""
    parser = argparse.ArgumentParser(
        description="Generate demo plots from a GRPO metrics.jsonl file."
    )
    parser.add_argument(
        "metrics_jsonl", nargs="?", type=Path, default=None,
        help="Path to metrics.jsonl produced by run_grpo_training.py",
    )
    parser.add_argument(
        "--latest", action="store_true",
        help="Auto-discover the most recent metrics.jsonl under checkpoints/grpo/",
    )
    parser.add_argument(
        "--out-dir", type=Path, default=None,
        help="Directory to write PNG files (default: <metrics_dir>/plots/)",
    )
    args = parser.parse_args()

    if args.latest:
        path = find_latest_metrics()
        if path is None:
            print("No metrics.jsonl found under checkpoints/grpo/", file=sys.stderr)
            sys.exit(1)
        print(f"[plot] Auto-selected {path}")
    else:
        path = args.metrics_jsonl
        if not path:
            parser.print_help()
            sys.exit(1)

    if not path.exists():
        print(f"File not found: {path}", file=sys.stderr)
        sys.exit(1)

    generate_plots(path, args.out_dir)
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
# Script entry point — guarded so the module can also be imported for its
# plotting helpers (e.g. generate_plots) without side effects.
if __name__ == "__main__":
    main()
|
scripts/plot_training_results.py
ADDED
|
@@ -0,0 +1,521 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
AxiomForgeAI β Training Results Plots
|
| 4 |
+
======================================
|
| 5 |
+
Reads the metrics CSV from a GRPO training run and generates five focused plots
|
| 6 |
+
that tell the story of what improved, how self-play was earned, and why step-level
|
| 7 |
+
reasoning quality matters as much as final-answer accuracy.
|
| 8 |
+
|
| 9 |
+
All plots are saved to images/ as high-resolution PNGs.
|
| 10 |
+
|
| 11 |
+
Usage
|
| 12 |
+
-----
|
| 13 |
+
python scripts/plot_training_results.py
|
| 14 |
+
python scripts/plot_training_results.py --metrics logs/grpo/grpo_20260426_032827/metrics.csv
|
| 15 |
+
python scripts/plot_training_results.py --out images/
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import argparse
|
| 21 |
+
import csv
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import Dict, List
|
| 24 |
+
|
| 25 |
+
import matplotlib
|
| 26 |
+
matplotlib.use("Agg")
|
| 27 |
+
import matplotlib.pyplot as plt
|
| 28 |
+
import matplotlib.patches as mpatches
|
| 29 |
+
import numpy as np
|
| 30 |
+
|
| 31 |
+
# ββ Style ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 32 |
+
# Dark-theme colour tokens used by every plot in this script (hex RGB).
PALETTE = {
    "indigo": "#6366f1",
    "pink": "#ec4899",
    "cyan": "#06b6d4",
    "amber": "#f59e0b",
    "emerald": "#10b981",
    "slate": "#94a3b8",
    "red": "#ef4444",
    "violet": "#8b5cf6",
    "white": "#f8fafc",
    "bg": "#0f172a",       # figure / axes background
    "bg2": "#1e293b",      # secondary surfaces (legend, caption boxes)
    "gridline": "#1e293b",
}

# Global Matplotlib defaults for the dark "dashboard" look.
plt.rcParams.update({
    "figure.facecolor": PALETTE["bg"],
    "axes.facecolor": PALETTE["bg"],
    "axes.edgecolor": PALETTE["slate"],
    "axes.labelcolor": PALETTE["white"],
    "axes.titlecolor": PALETTE["white"],
    "axes.titlesize": 13,
    "axes.labelsize": 11,
    "axes.grid": True,
    "grid.color": "#1e293b",
    "grid.linewidth": 0.8,
    "xtick.color": PALETTE["slate"],
    "ytick.color": PALETTE["slate"],
    "xtick.labelsize": 9,
    "ytick.labelsize": 9,
    "legend.facecolor": "#1e293b",
    "legend.edgecolor": PALETTE["slate"],
    "legend.labelcolor": PALETTE["white"],
    "legend.fontsize": 9,
    "text.color": PALETTE["white"],
    "font.family": "sans-serif",
    "lines.linewidth": 2.0,
})

# Per-phase (background-fill, edge/legend) hex pairs; the "20" suffix on the
# fill is an alpha channel, making the phase shading translucent.
PHASE_COLORS = {
    "GROUNDED_ONLY": ("#6366f120", "#6366f1"),
    "SELFPLAY_RAMP": ("#10b98120", "#10b981"),
}

# Output resolution for saved PNGs.
DPI = 160
# Default directory the plots are written to.
IMAGES_DIR = Path("images")

# Metrics CSV of the reference training run plotted when no --metrics is given.
DEFAULT_METRICS = (
    "logs/grpo/grpo_20260426_032827/metrics.csv"
)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# ββ Helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 85 |
+
|
| 86 |
+
def load_csv(path: str) -> List[Dict]:
    """Read a metrics CSV into a list of plain dicts (values stay strings)."""
    with open(path, encoding="utf-8") as handle:
        return [dict(record) for record in csv.DictReader(handle)]
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def f(row: Dict, key: str, default: float = float("nan")) -> float:
    """Parse row[key] as a float; return *default* for missing/blank/bad values."""
    raw = row.get(key, "")
    if raw == "":
        return default
    try:
        return float(raw)
    except (ValueError, TypeError):
        return default
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def moving_avg(values: List[float], w: int = 3) -> List[float]:
    """Trailing mean over a window of up to *w* points, ignoring NaNs.

    Positions whose whole window is NaN yield NaN.
    """
    out: List[float] = []
    for idx in range(len(values)):
        window = values[max(0, idx - w + 1): idx + 1]
        finite = [v for v in window if not np.isnan(v)]
        out.append(float(np.mean(finite)) if finite else float("nan"))
    return out
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def shade_phases(ax, iters, phases):
    """Draw translucent background rectangles for each training phase.

    Consecutive runs of the same phase label are merged into one rectangle;
    colours come from PHASE_COLORS with a neutral fallback.
    """
    if not iters:  # fix: previously raised IndexError on an empty run
        return
    prev_phase, start = None, iters[0]
    for it, ph in zip(iters, phases):
        if ph != prev_phase:
            # Close the previous phase's rectangle at the boundary.
            if prev_phase is not None:
                bg, _ = PHASE_COLORS.get(prev_phase, ("#ffffff10", "#ffffff"))
                ax.axvspan(start - 0.5, it - 0.5, facecolor=bg, linewidth=0, zorder=0)
            prev_phase, start = ph, it
    # Close the final open rectangle through the last iteration.
    if prev_phase is not None:
        bg, _ = PHASE_COLORS.get(prev_phase, ("#ffffff10", "#ffffff"))
        ax.axvspan(start - 0.5, iters[-1] + 0.5, facecolor=bg, linewidth=0, zorder=0)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def phase_legend_patches(phases):
    """Build one legend Patch per distinct phase, in first-seen order."""
    seen = []
    patches = []
    for ph in phases:
        if ph not in seen:
            seen.append(ph)
            # Second element of the colour pair is the opaque edge colour.
            _, edge = PHASE_COLORS.get(ph, ("#ffffff10", "#ffffff"))
            label = ph.replace("_", " ").title()
            # Appending "40" to the hex colour gives a ~25%-alpha fill.
            patches.append(mpatches.Patch(facecolor=edge + "40", edgecolor=edge,
                                          linewidth=1.2, label=label))
    return patches
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def annotate_transition(ax, x_iter, label, ypos=0.97, color="#94a3b8"):
    """Mark a phase transition with a dashed vertical line plus a text label.

    x_iter is an iteration number; the line is drawn half a tick before it.
    ypos is in axis-fraction coordinates (x stays in data coordinates).
    """
    ax.axvline(x=x_iter - 0.5, color=color, linewidth=1, linestyle="--", alpha=0.7)
    ax.text(x_iter, ypos, label, transform=ax.get_xaxis_transform(),
            fontsize=7.5, color=color, ha="left", va="top",
            bbox=dict(facecolor=PALETTE["bg2"], edgecolor="none", pad=2))
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def save(fig: plt.Figure, name: str, out: Path):
    """Save *fig* as out/name at module DPI, log the path, and close the figure."""
    out.mkdir(parents=True, exist_ok=True)
    path = out / name
    # Keep the figure's own background colour (dark theme) in the PNG.
    fig.savefig(path, dpi=DPI, bbox_inches="tight", facecolor=fig.get_facecolor())
    print(f" β {path}")
    plt.close(fig)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 154 |
+
# PLOT 1 β Hero: Reasoning quality at evaluation checkpoints
|
| 155 |
+
# Shows four signals together: GSM8K accuracy, combined score, step accuracy,
|
| 156 |
+
# and LCCP. The message: the model doesn't just get more answers right β
|
| 157 |
+
# every step of the reasoning chain gets better.
|
| 158 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 159 |
+
|
| 160 |
+
def plot_eval_quality(rows: List[Dict], out: Path):
    """Plot 1 (hero): evaluation-checkpoint quality.

    Draws GSM8K accuracy, combined score, step accuracy, LCCP and PRM mean
    at each iteration that ran an eval, and saves plot1_eval_quality.png.
    """
    # Only iterations that ran an eval have eval_combined populated.
    eval_rows = [r for r in rows if r.get("eval_combined", "") != ""]
    iters = [int(r["iteration"]) for r in eval_rows]

    # Metrics are stored as fractions in the CSV; scale to percentages.
    gsm8k_acc = [f(r, "eval_correct_rt") * 100 for r in eval_rows]
    combined = [f(r, "eval_combined") * 100 for r in eval_rows]
    step_acc = [f(r, "eval_step_acc") * 100 for r in eval_rows]
    lccp = [f(r, "eval_lccp") * 100 for r in eval_rows]
    prm = [f(r, "eval_prm") * 100 for r in eval_rows]

    fig, ax = plt.subplots(figsize=(9, 5))
    fig.suptitle("Evaluation Quality Over Training β AxiomForgeAI",
                 fontsize=14, fontweight="bold", color=PALETTE["white"], y=1.01)

    # --- lines (one marker style per signal; zorder keeps them above shading)
    ax.plot(iters, gsm8k_acc, "o-", color=PALETTE["pink"], label="GSM8K Accuracy (final answer)", ms=7, zorder=5)
    ax.plot(iters, combined, "s-", color=PALETTE["indigo"], label="Combined Score", ms=6, zorder=5)
    ax.plot(iters, step_acc, "^-", color=PALETTE["cyan"], label="Step Accuracy (reasoning chain)", ms=6, zorder=5)
    ax.plot(iters, lccp, "D-", color=PALETTE["emerald"], label="LCCP (chain integrity)", ms=6, zorder=5)
    ax.plot(iters, prm, "v--", color=PALETTE["amber"], label="PRM Mean Score", ms=5, alpha=0.8, zorder=4)

    # annotate best GSM8K checkpoint
    best_gsm = max(gsm8k_acc)
    bi = gsm8k_acc.index(best_gsm)
    ax.annotate(f" {best_gsm:.1f}%",
                xy=(iters[bi], best_gsm), fontsize=9, color=PALETTE["pink"],
                va="bottom", ha="left")

    # annotate best combined score
    best_c = max(combined)
    bci = combined.index(best_c)
    ax.annotate(f" {best_c:.1f}",
                xy=(iters[bci], best_c), fontsize=9, color=PALETTE["indigo"],
                va="top", ha="left")

    ax.set_xlabel("Training Iteration")
    ax.set_ylabel("Score (%)")
    ax.set_xticks(iters)
    # NOTE(review): y-limits are tuned to this run's score range (78-96%).
    ax.set_ylim(78, 96)
    ax.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter("%.0f%%"))
    ax.legend(loc="lower right", framealpha=0.8)
    ax.set_title(
        "Four angles on quality β answer correctness, holistic score, per-step reasoning, and chain integrity",
        fontsize=9, color=PALETTE["slate"], pad=6,
    )

    fig.tight_layout()
    save(fig, "plot1_eval_quality.png", out)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 211 |
+
# PLOT 2 β Training Journey: full 30-iteration timeline with phase shading
|
| 212 |
+
# Shows mean reward, GT match rate, and step accuracy over every iteration.
|
| 213 |
+
# Phase backgrounds show when self-play unlocked and the curriculum ramped.
|
| 214 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 215 |
+
|
| 216 |
+
def plot_training_journey(rows: List[Dict], out: Path):
    """Plot 2: full training timeline with phase shading.

    Shows mean reward, GT match rate and step accuracy per iteration (faint
    raw + bold 4-iter moving average) and saves plot2_training_journey.png.
    """
    iters = [int(r["iteration"]) for r in rows]
    phases = [r["training_phase"] for r in rows]
    # CSV stores fractions; scale to percentages.
    mean_r = [f(r, "mean_reward") * 100 for r in rows]
    gt_match = [f(r, "gt_match_rate") * 100 for r in rows]
    step_acc = [f(r, "step_accuracy") * 100 for r in rows]
    # NOTE(review): batch_acc is computed but never plotted below.
    batch_acc = [f(r, "batch_accuracy") * 100 for r in rows]

    # 4-iteration trailing moving averages for the bold "smooth" lines.
    ma_reward = moving_avg(mean_r, w=4)
    ma_gt = moving_avg(gt_match, w=4)
    ma_step = moving_avg(step_acc, w=4)

    fig, ax = plt.subplots(figsize=(11, 5))
    shade_phases(ax, iters, phases)

    # raw (faint)
    ax.plot(iters, mean_r, alpha=0.25, color=PALETTE["indigo"], linewidth=1)
    ax.plot(iters, gt_match, alpha=0.25, color=PALETTE["pink"], linewidth=1)
    ax.plot(iters, step_acc, alpha=0.25, color=PALETTE["cyan"], linewidth=1)

    # smoothed (bold)
    ax.plot(iters, ma_reward, color=PALETTE["indigo"], linewidth=2.5, label="Mean Reward (smooth)")
    ax.plot(iters, ma_gt, color=PALETTE["pink"], linewidth=2.5, label="GT Match Rate (smooth)")
    ax.plot(iters, ma_step, color=PALETTE["cyan"], linewidth=2.5, label="Step Accuracy (smooth)")

    # self-play transition annotation: first iteration in the SELFPLAY_RAMP
    # phase. NOTE(review): next(...) raises StopIteration if that phase never
    # occurs in this run's CSV.
    sp_start = next(i for i, p in enumerate(phases) if p == "SELFPLAY_RAMP")
    annotate_transition(ax, iters[sp_start], "Self-play\nunlocked", ypos=0.98,
                        color=PALETTE["emerald"])

    ax.set_xlabel("Training Iteration")
    ax.set_ylabel("Score (%)")
    ax.set_ylim(55, 105)
    ax.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter("%.0f%%"))
    ax.set_xticks(range(1, max(iters) + 1, 2))
    ax.set_title("30-Iteration GRPO Training Timeline | Faint = raw Β· Bold = 4-iter moving average",
                 fontsize=9, color=PALETTE["slate"], pad=6)
    fig.suptitle("Training Journey β Reward, GT Match & Step Accuracy",
                 fontsize=14, fontweight="bold", color=PALETTE["white"], y=1.01)

    # Combine line legend entries with the phase-colour patches.
    legend_patches = phase_legend_patches(phases)
    h, l = ax.get_legend_handles_labels()
    ax.legend(handles=h + legend_patches, loc="lower right", framealpha=0.8, ncol=2)

    fig.tight_layout()
    save(fig, "plot2_training_journey.png", out)
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 265 |
+
# PLOT 3 β Self-Play Success: the curriculum earning its right to generate
|
| 266 |
+
# Shows the self-play ratio ramping up while question quality stays high.
|
| 267 |
+
# The headline: by iteration 30 more than 60% of training is model-generated,
|
| 268 |
+
# and those questions are 95-100% solvable and genuinely novel.
|
| 269 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 270 |
+
|
| 271 |
+
def plot_selfplay_success(rows: List[Dict], out: Path):
    """Plot 3: self-play ratio ramp (left axis) vs generated-question quality
    metrics (right axis). Saves plot3_selfplay_success.png."""
    # Rows with a positive q_reward are the iterations where self-play ran.
    sp_rows = [r for r in rows if f(r, "q_reward") > 0]
    iters = [int(r["iteration"]) for r in sp_rows]
    sp_rat = [f(r, "sp_ratio") * 100 for r in sp_rows]
    q_sol = [f(r, "q_solvability") * 100 for r in sp_rows]
    q_nov = [f(r, "q_novelty") * 100 for r in sp_rows]
    q_rew = [f(r, "q_reward") * 100 for r in sp_rows]

    fig, ax1 = plt.subplots(figsize=(10, 5))
    # Second y-axis shares the x-axis for the question-quality lines.
    ax2 = ax1.twinx()
    ax2.tick_params(axis="y", labelcolor=PALETTE["slate"])
    ax2.spines["right"].set_color(PALETTE["slate"])

    # self-play ramp (left axis)
    ax1.fill_between(iters, sp_rat, alpha=0.18, color=PALETTE["emerald"])
    ax1.plot(iters, sp_rat, "o-", color=PALETTE["emerald"], ms=6,
             label="Self-play ratio", linewidth=2.5)
    ax1.set_ylabel("Self-play share of training (%)", color=PALETTE["emerald"])
    ax1.tick_params(axis="y", labelcolor=PALETTE["emerald"])
    ax1.set_ylim(0, 80)

    # question quality (right axis)
    ax2.plot(iters, q_sol, "s--", color=PALETTE["cyan"], ms=5, label="Solvability", linewidth=1.8)
    ax2.plot(iters, q_nov, "^--", color=PALETTE["amber"], ms=5, label="Novelty", linewidth=1.8)
    ax2.plot(iters, q_rew, "D--", color=PALETTE["pink"], ms=5, label="Q-Reward", linewidth=1.8)
    ax2.set_ylabel("Question quality score (%)", color=PALETTE["slate"])
    ax2.set_ylim(0, 115)

    # merge both axes' legend entries into a single box on ax1
    h1, l1 = ax1.get_legend_handles_labels()
    h2, l2 = ax2.get_legend_handles_labels()
    ax1.legend(h1 + h2, l1 + l2, loc="upper left", framealpha=0.8)

    ax1.set_xlabel("Training Iteration")
    ax1.set_xticks(iters)
    ax1.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter("%.0f%%"))
    ax2.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter("%.0f%%"))

    # annotate the final self-play ratio next to the last point
    ax1.annotate(f" {sp_rat[-1]:.0f}% self-play\n by iter {iters[-1]}",
                 xy=(iters[-1], sp_rat[-1]), fontsize=9, color=PALETTE["emerald"],
                 va="center", ha="left")

    fig.suptitle("Self-Play Curriculum β The Model Earns Its Own Training Data",
                 fontsize=14, fontweight="bold", color=PALETTE["white"], y=1.01)
    # NOTE(review): subtitle figures (61%, 93-100%) are hard-coded to this run.
    ax1.set_title(
        "Self-play ratio ramps from 0 β 61% Β· Generated questions stay 93-100% solvable throughout",
        fontsize=9, color=PALETTE["slate"], pad=6,
    )
    fig.tight_layout()
    save(fig, "plot3_selfplay_success.png", out)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 325 |
+
# PLOT 4 β Reward Signal Tightening: mean Β± std over 30 iterations
|
| 326 |
+
# As the policy learns what "good" looks like, the spread between the best
|
| 327 |
+
# and worst solutions in a group narrows. Lower variance = more consistent
|
| 328 |
+
# reasoning, not lucky guessing.
|
| 329 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββοΏ½οΏ½οΏ½βββββββββββββββββββββββββββββββ
|
| 330 |
+
|
| 331 |
+
def plot_reward_confidence(rows: List[Dict], out: Path):
    """Plot 4: mean reward with a ±1-std band (top panel) plus the skipped-group
    rate as bars (bottom panel). Saves plot4_reward_confidence.png."""
    iters = [int(r["iteration"]) for r in rows]
    phases = [r["training_phase"] for r in rows]
    mean_r = np.array([f(r, "mean_reward") for r in rows])
    std_r = np.array([f(r, "std_reward") for r in rows])
    skipped = np.array([f(r, "skipped_groups", 0) for r in rows])
    n_grps = np.array([f(r, "n_groups", 1) for r in rows])
    # np.maximum guards against division by zero when n_groups is missing/0.
    skip_rt = skipped / np.maximum(n_grps, 1) * 100

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(11, 7), sharex=True,
                                   gridspec_kw={"height_ratios": [3, 1.2]})
    fig.suptitle("Reward Confidence β Mean Β± Std & Skipped Groups Over 30 Iterations",
                 fontsize=14, fontweight="bold", color=PALETTE["white"], y=1.01)

    shade_phases(ax1, iters, phases)

    # shaded mean ± std band (rewards are fractions; *100 for percent)
    ax1.fill_between(iters, (mean_r - std_r) * 100, (mean_r + std_r) * 100,
                     alpha=0.20, color=PALETTE["indigo"])
    ax1.plot(iters, mean_r * 100, color=PALETTE["indigo"], linewidth=2.5, label="Mean reward")
    ax1.plot(iters, (mean_r - std_r) * 100, "--", color=PALETTE["slate"], linewidth=1,
             alpha=0.6, label="Β±1 std")
    ax1.plot(iters, (mean_r + std_r) * 100, "--", color=PALETTE["slate"], linewidth=1,
             alpha=0.6)

    # highlight the two tight-cluster peaks.
    # NOTE(review): iteration numbers and std values are hard-coded to this
    # run; iters.index() raises ValueError if those iterations are absent.
    for special_iter, label in [(11, "iter 11\nstd=0.098"), (22, "iter 22\nstd=0.124")]:
        si = iters.index(special_iter)
        ax1.annotate(label,
                     xy=(special_iter, (mean_r[si] + std_r[si]) * 100),
                     xytext=(special_iter + 1, (mean_r[si] + std_r[si]) * 100 + 2),
                     fontsize=8, color=PALETTE["amber"],
                     arrowprops=dict(arrowstyle="->", color=PALETTE["amber"], lw=1.2))

    ax1.set_ylabel("Reward (%)")
    ax1.set_ylim(55, 115)
    ax1.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter("%.0f%%"))
    h1, l1 = ax1.get_legend_handles_labels()
    ax1.legend(handles=h1 + phase_legend_patches(phases), framealpha=0.8, ncol=3)

    # skip-rate bar chart (bottom panel)
    shade_phases(ax2, iters, phases)
    ax2.bar(iters, skip_rt, color=PALETTE["red"], alpha=0.7, width=0.7, label="Skipped groups %")
    ax2.set_ylabel("Skipped\ngroups (%)")
    ax2.set_xlabel("Training Iteration")
    ax2.set_ylim(0, 75)
    ax2.set_xticks(range(1, max(iters) + 1, 2))
    ax2.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter("%.0f%%"))
    ax2.legend(loc="upper right", framealpha=0.8)

    fig.tight_layout()
    save(fig, "plot4_reward_confidence.png", out)
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 385 |
+
# PLOT 5 β Step-Level Reasoning Quality: train vs eval
|
| 386 |
+
# Breaks down the two signals that measure HOW the model thinks (not just
|
| 387 |
+
# whether it gets the final answer right): step accuracy and LCCP.
|
| 388 |
+
# Train lines are noisy; eval lines show clean upward trends.
|
| 389 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 390 |
+
|
| 391 |
+
def plot_reasoning_quality(rows: List[Dict], out: Path):
    """Plot 5: step-level reasoning quality, train vs eval.

    Left panel: step accuracy (+GT match); right panel: LCCP. Training curves
    are faint raw + bold smoothed; eval checkpoints overlay as white markers.
    Saves plot5_reasoning_quality.png.
    """
    iters = [int(r["iteration"]) for r in rows]
    phases = [r["training_phase"] for r in rows]

    # training metrics (fractions in CSV; *100 for percent)
    t_step = [f(r, "step_accuracy") * 100 for r in rows]
    t_lccp = [f(r, "lccp") * 100 for r in rows]
    t_gt = [f(r, "gt_match_rate") * 100 for r in rows]

    # eval metrics (only present at checkpoint iterations)
    eval_rows = [r for r in rows if r.get("eval_combined", "") != ""]
    e_iters = [int(r["iteration"]) for r in eval_rows]
    e_step = [f(r, "eval_step_acc") * 100 for r in eval_rows]
    e_lccp = [f(r, "eval_lccp") * 100 for r in eval_rows]

    # 4-iteration trailing moving averages
    ma_step = moving_avg(t_step, w=4)
    ma_lccp = moving_avg(t_lccp, w=4)
    ma_gt = moving_avg(t_gt, w=4)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 5.5))
    fig.suptitle("Step-Level Reasoning Quality β Training vs Held-Out Evaluation",
                 fontsize=14, fontweight="bold", color=PALETTE["white"], y=1.01)

    # ββ LEFT: step accuracy ββ
    shade_phases(ax1, iters, phases)
    ax1.plot(iters, t_step, alpha=0.2, color=PALETTE["cyan"], linewidth=1)
    ax1.plot(iters, ma_step, color=PALETTE["cyan"], linewidth=2.5, label="Train step acc (smooth)")
    ax1.plot(iters, t_gt, alpha=0.15, color=PALETTE["pink"], linewidth=1)
    ax1.plot(iters, ma_gt, color=PALETTE["pink"], linewidth=2.5, label="Train GT match (smooth)")
    ax1.plot(e_iters, e_step, "o-", color=PALETTE["white"], ms=8, linewidth=2,
             label="Eval step accuracy", zorder=6)

    # annotate first and last eval checkpoint values
    ax1.annotate(f"{e_step[0]:.1f}%", xy=(e_iters[0], e_step[0]),
                 xytext=(e_iters[0] - 0.3, e_step[0] - 1.2), fontsize=8.5,
                 color=PALETTE["white"], ha="right")
    ax1.annotate(f"{e_step[-1]:.1f}%", xy=(e_iters[-1], e_step[-1]),
                 xytext=(e_iters[-1] + 0.3, e_step[-1] + 0.5), fontsize=8.5,
                 color=PALETTE["white"])
    # curved arrow from first to last eval point (empty text = arrow only)
    ax1.annotate("", xy=(e_iters[-1], e_step[-1]),
                 xytext=(e_iters[0], e_step[0]),
                 arrowprops=dict(arrowstyle="->", color=PALETTE["cyan"], lw=1.5,
                                 connectionstyle="arc3,rad=-0.3"))

    ax1.set_title("Step Accuracy β Did each reasoning step hold up?",
                  fontsize=9.5, color=PALETTE["slate"], pad=5)
    ax1.set_xlabel("Training Iteration")
    ax1.set_ylabel("Score (%)")
    ax1.set_ylim(55, 105)
    ax1.set_xticks(range(1, max(iters) + 1, 3))
    ax1.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter("%.0f%%"))
    ax1.legend(handles=ax1.get_legend_handles_labels()[0] + phase_legend_patches(phases),
               framealpha=0.8, ncol=1, loc="lower right")

    # ββ RIGHT: LCCP ββ
    shade_phases(ax2, iters, phases)
    ax2.plot(iters, t_lccp, alpha=0.2, color=PALETTE["emerald"], linewidth=1)
    ax2.plot(iters, ma_lccp, color=PALETTE["emerald"], linewidth=2.5, label="Train LCCP (smooth)")
    ax2.plot(e_iters, e_lccp, "o-", color=PALETTE["white"], ms=8, linewidth=2,
             label="Eval LCCP", zorder=6)

    ax2.annotate(f"{e_lccp[0]:.1f}%", xy=(e_iters[0], e_lccp[0]),
                 xytext=(e_iters[0] - 0.3, e_lccp[0] - 1.5), fontsize=8.5,
                 color=PALETTE["white"], ha="right")
    ax2.annotate(f"{e_lccp[-1]:.1f}%", xy=(e_iters[-1], e_lccp[-1]),
                 xytext=(e_iters[-1] + 0.3, e_lccp[-1] + 0.5), fontsize=8.5,
                 color=PALETTE["white"])

    # show LCCP improvement (delta in percentage points) as a boxed note
    delta = e_lccp[-1] - e_lccp[0]
    ax2.text(0.97, 0.06,
             f"Eval LCCP Ξ = +{delta:.2f}pp\n(iter {e_iters[0]} β {e_iters[-1]})",
             transform=ax2.transAxes, ha="right", va="bottom",
             fontsize=8.5, color=PALETTE["emerald"],
             bbox=dict(facecolor=PALETTE["bg2"], edgecolor=PALETTE["emerald"],
                       linewidth=0.8, pad=5))

    ax2.set_title("LCCP β Did the chain of reasoning stay correct until the first error?",
                  fontsize=9.5, color=PALETTE["slate"], pad=5)
    ax2.set_xlabel("Training Iteration")
    ax2.set_ylabel("LCCP (%)")
    ax2.set_ylim(55, 100)
    ax2.set_xticks(range(1, max(iters) + 1, 3))
    ax2.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter("%.0f%%"))
    ax2.legend(handles=ax2.get_legend_handles_labels()[0] + phase_legend_patches(phases),
               framealpha=0.8, ncol=1, loc="lower right")

    fig.tight_layout()
    save(fig, "plot5_reasoning_quality.png", out)
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 484 |
+
# Main
|
| 485 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 486 |
+
|
| 487 |
+
def parse_args():
    """Build and parse the CLI arguments for the plotting script."""
    parser = argparse.ArgumentParser(description="Generate AxiomForgeAI training plots")
    parser.add_argument(
        "--metrics",
        default=DEFAULT_METRICS,
        help=f"Path to metrics.csv (default: {DEFAULT_METRICS})",
    )
    parser.add_argument(
        "--out",
        default="images",
        help="Output directory for PNGs (default: images/)",
    )
    return parser.parse_args()
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
def main():
    """Entry point: load the metrics CSV and render all five plots."""
    args = parse_args()
    out = Path(args.out)

    print(f"Loading metrics from : {args.metrics}")
    print(f"Saving plots to : {out}/")
    print()

    rows = load_csv(args.metrics)
    print(f"Loaded {len(rows)} iterations.\n")

    print("Generating plots β¦")
    plot_eval_quality(rows, out)
    plot_training_journey(rows, out)
    plot_selfplay_success(rows, out)
    plot_reward_confidence(rows, out)
    plot_reasoning_quality(rows, out)

    print(f"\nβ All 5 plots saved to {out}/")
    print("\nFiles:")
    # Summarise the generated files with their sizes in KB.
    for p in sorted(out.glob("plot*.png")):
        print(f" {p} ({p.stat().st_size // 1024} KB)")


if __name__ == "__main__":
    main()
|
scripts/precompute_extraction_cache.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Offline step-chain extraction cache builder.
|
| 3 |
+
|
| 4 |
+
Run this once before training to pre-extract structured step chains from all
|
| 5 |
+
grounded training data (GSM8K + MATH). The resulting cache file is passed to
|
| 6 |
+
run_grpo_training.py via --extraction-cache so the extractor LLM is never
|
| 7 |
+
called for fixed training examples β only novel self-play solutions require
|
| 8 |
+
live extraction during training.
|
| 9 |
+
|
| 10 |
+
Usage
|
| 11 |
+
-----
|
| 12 |
+
python scripts/precompute_extraction_cache.py \\
|
| 13 |
+
--gsm8k-data data/sft/gsm8k_sft.jsonl \\
|
| 14 |
+
--math-data data/sft/math_sft.jsonl \\
|
| 15 |
+
--output-cache data/extraction_cache.json \\
|
| 16 |
+
--extractor-model Qwen/Qwen2.5-0.5B-Instruct \\
|
| 17 |
+
--device cuda
|
| 18 |
+
|
| 19 |
+
Cache key: md5(question + "\\n" + solution) β keying on both prevents
|
| 20 |
+
collisions when two MATH problems share identical solution text.
|
| 21 |
+
Entries for solutions the extractor cannot parse are stored with
|
| 22 |
+
success=False so training never re-attempts and correctly penalises them.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
from __future__ import annotations
|
| 26 |
+
|
| 27 |
+
import argparse
|
| 28 |
+
import json
|
| 29 |
+
import logging
|
| 30 |
+
import pathlib
|
| 31 |
+
import sys
|
| 32 |
+
from typing import List, Tuple
|
| 33 |
+
|
| 34 |
+
# Log to stdout (not stderr) so messages interleave cleanly with other
# console output when the script is run under a training launcher.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)-8s %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def load_jsonl(path: str) -> list[dict]:
    """Read a JSONL file into a list of records.

    Blank lines and lines that are not valid JSON are skipped silently
    (best-effort loading; malformed rows are simply dropped).
    """
    out: list[dict] = []
    with open(path, encoding="utf-8") as handle:
        for raw in handle:
            raw = raw.strip()
            if not raw:
                continue
            try:
                parsed = json.loads(raw)
            except json.JSONDecodeError:
                continue
            out.append(parsed)
    return out
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def collect_qa_pairs(records: list[dict]) -> List[Tuple[str, str]]:
    """
    Extract (question, solution) pairs from dataset records.

    The solution is taken from the first non-empty of solution/output/response;
    the question from question/problem/input. Records whose solution is blank
    are dropped; a missing question falls back to the empty string.
    """
    pairs: List[Tuple[str, str]] = []
    for record in records:
        solution = (
            record.get("solution")
            or record.get("output")
            or record.get("response")
            or ""
        )
        question = (
            record.get("question")
            or record.get("problem")
            or record.get("input")
            or ""
        )
        if not solution.strip():
            continue
        pairs.append((question.strip(), solution.strip()))
    return pairs
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def main() -> None:
    """CLI entry point: load grounded data, build the extraction cache, save it."""
    parser = argparse.ArgumentParser(
        description="Pre-extract step chains for grounded training data."
    )
    parser.add_argument(
        "--gsm8k-data", required=True,
        help="Path to GSM8K training JSONL (e.g. data/sft/gsm8k_sft.jsonl).",
    )
    parser.add_argument(
        "--math-data", default=None,
        help="Optional path to MATH training JSONL. If provided, those solutions "
        "are also extracted and added to the cache.",
    )
    parser.add_argument(
        "--output-cache", required=True,
        help="Destination JSON file for the extraction cache.",
    )
    parser.add_argument(
        "--extractor-model", default="Qwen/Qwen2.5-0.5B-Instruct",
        help="HuggingFace model ID for the step chain extractor. Default Qwen/Qwen2.5-0.5B-Instruct.",
    )
    parser.add_argument(
        "--device", default="cuda",
        help="Device for the extractor model (default: cuda).",
    )
    parser.add_argument(
        "--batch-size", type=int, default=1,
        help="Reserved for future batched extraction. Currently always 1.",
    )
    args = parser.parse_args()

    # -- Load data ------------------------------------------------------------
    logger.info("Loading GSM8K data from: %s", args.gsm8k_data)
    gsm8k_records = load_jsonl(args.gsm8k_data)
    qa_pairs = collect_qa_pairs(gsm8k_records)
    logger.info("GSM8K: %d (question, solution) pairs", len(qa_pairs))

    if args.math_data:
        logger.info("Loading MATH data from: %s", args.math_data)
        math_records = load_jsonl(args.math_data)
        math_pairs = collect_qa_pairs(math_records)
        logger.info("MATH: %d (question, solution) pairs", len(math_pairs))
        qa_pairs += math_pairs

    if not qa_pairs:
        logger.error(
            "No solutions found in provided files. "
            "Check field names (question/problem/input + solution/output/response)."
        )
        sys.exit(1)

    # Deduplicate by (question, solution) content.
    # Two different MATH problems can have identical solution text but
    # different questions; keying on both keeps them distinct in the cache.
    seen: set = set()
    unique_pairs: List[Tuple[str, str]] = []
    for q, sol in qa_pairs:
        key = (q, sol)
        if key not in seen:
            seen.add(key)
            unique_pairs.append((q, sol))

    logger.info(
        "Total: %d pairs (%d unique after dedup)", len(qa_pairs), len(unique_pairs)
    )

    # -- Load extractor -------------------------------------------------------
    # Make the repo root importable so src.rl.unified_accuracy resolves
    # regardless of the working directory the script is launched from.
    sys.path.insert(0, str(pathlib.Path(__file__).parent.parent))
    from src.rl.unified_accuracy import StepChainExtractor

    extractor = StepChainExtractor(
        model_name=args.extractor_model,
        device=args.device,
        cache_path=args.output_cache,  # load existing cache if present (resume)
    )

    # -- Build cache ----------------------------------------------------------
    # NOTE(review): reads the extractor's private _cache attribute directly;
    # a public size accessor on StepChainExtractor would be cleaner.
    already_cached = len(extractor._cache)
    if already_cached:
        logger.info("Resuming: %d entries already in cache", already_cached)

    extractor.build_cache(unique_pairs)

    # -- Save -----------------------------------------------------------------
    extractor.save_cache()
    logger.info(
        "Done. Cache contains %d entries β %s",
        len(extractor._cache),
        args.output_cache,
    )


if __name__ == "__main__":
    main()
|
scripts/prepare_aqua_dataset.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Download Chinar/AQuA-RAT from HuggingFace and convert it to the same JSONL
|
| 4 |
+
format used by gsm8k_sft.jsonl so the GRPO training script can consume it
|
| 5 |
+
directly via --gsm8k-data.
|
| 6 |
+
|
| 7 |
+
Chinar/AQuA-RAT schema (processed version)
|
| 8 |
+
-------------------------------------------
|
| 9 |
+
prompt : str β the math question
|
| 10 |
+
completion : str β step-by-step reasoning ending with:
|
| 11 |
+
"The answer is X . Therefore, the correct answer is: <value>"
|
| 12 |
+
|
| 13 |
+
Output schema (messages format expected by load_gsm8k)
|
| 14 |
+
-------------------------------------------------------
|
| 15 |
+
{
|
| 16 |
+
"id": "aqua_<idx>",
|
| 17 |
+
"skill_id": "aqua_rat_algebra",
|
| 18 |
+
"source": "Chinar/AQuA-RAT",
|
| 19 |
+
"split": "train" | "validation",
|
| 20 |
+
"messages": [
|
| 21 |
+
{"role": "system", "content": SOLVER_SYSTEM_PROMPT},
|
| 22 |
+
{"role": "user", "content": "Solve ... Problem:\\n<question>"},
|
| 23 |
+
{"role": "assistant", "content": "Step 1: ...\\nFinal Answer: <value>"}
|
| 24 |
+
]
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
The dataset has only a 'train' split β we reserve the last 500 rows as
|
| 28 |
+
a validation set and use the rest for training.
|
| 29 |
+
|
| 30 |
+
Usage
|
| 31 |
+
-----
|
| 32 |
+
python scripts/prepare_aqua_dataset.py
|
| 33 |
+
python scripts/prepare_aqua_dataset.py --val-size 300 --dry-run
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
from __future__ import annotations
|
| 37 |
+
|
| 38 |
+
import argparse
|
| 39 |
+
import json
|
| 40 |
+
import re
|
| 41 |
+
import sys
|
| 42 |
+
from pathlib import Path
|
| 43 |
+
from typing import Any, Optional
|
| 44 |
+
|
| 45 |
+
# ---------------------------------------------------------------------------
|
| 46 |
+
# Prompt constants (kept in sync with src/config/prompts.py)
|
| 47 |
+
# ---------------------------------------------------------------------------
|
| 48 |
+
|
| 49 |
+
SOLVER_SYSTEM_PROMPT = (
|
| 50 |
+
"You are a step-by-step math solver. "
|
| 51 |
+
"Solve the given problem one step at a time. "
|
| 52 |
+
"Each step must be on its own line, starting with 'Step N:'. "
|
| 53 |
+
"End with a line starting with 'Final Answer:'. "
|
| 54 |
+
"Write every mathematical expression in Python/SymPy syntax "
|
| 55 |
+
"so it can be verified programmatically."
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
USER_WRAPPER = (
|
| 59 |
+
"Solve the following problem. Show your reasoning as numbered steps, "
|
| 60 |
+
"then give the final numeric answer on the last line.\n\nProblem:\n{question}"
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
# ---------------------------------------------------------------------------
|
| 64 |
+
# Answer extraction
|
| 65 |
+
# ---------------------------------------------------------------------------
|
| 66 |
+
|
| 67 |
+
# The completion always ends with a variant of:
|
| 68 |
+
# "The answer is E . Therefore, the correct answer is: 23"
|
| 69 |
+
_ANSWER_TAIL = re.compile(
|
| 70 |
+
r"(?:The answer is\s+[A-Ea-e]\s*[.\-]?\s*)?"
|
| 71 |
+
r"Therefore,?\s+the correct answer is\s*:?\s*(.+)$",
|
| 72 |
+
re.IGNORECASE,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _extract_answer_and_rationale(completion: str) -> Optional[tuple[str, str]]:
    """Split *completion* into (rationale, final_answer).

    The rationale is everything before the trailing answer marker, with any
    dangling "The answer is X ." sentence removed.  Returns None when the
    marker is missing or the answer cannot be reduced to a single number.
    """
    match = _ANSWER_TAIL.search(completion)
    if match is None:
        return None

    # Normalise the captured answer text first; bail out early when it is
    # not reducible to one clean number.
    final_answer = _normalise_answer(match.group(1).strip())
    if final_answer is None:
        return None

    # Rationale = text before the marker, minus any trailing letter-choice
    # sentence such as "The answer is E .".
    rationale = completion[: match.start()].strip()
    rationale = re.sub(
        r"\s*The answer is\s+[A-Ea-e]\s*[.\-]?\s*$",
        "",
        rationale,
        flags=re.IGNORECASE,
    ).strip()

    return rationale, final_answer
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _normalise_answer(raw: str) -> Optional[str]:
|
| 101 |
+
"""
|
| 102 |
+
Extract a single numeric value from an answer string.
|
| 103 |
+
|
| 104 |
+
"23" β "23"
|
| 105 |
+
"$ 1600" β "1600"
|
| 106 |
+
"8 seconds" β "8"
|
| 107 |
+
"5 and 1" β None (multi-value β skip)
|
| 108 |
+
"I and II" β None (non-numeric β skip)
|
| 109 |
+
"β 3 β€ x β€ 4" β None (inequality β skip)
|
| 110 |
+
"""
|
| 111 |
+
text = raw.strip()
|
| 112 |
+
|
| 113 |
+
# Remove currency / whitespace
|
| 114 |
+
text = text.replace("$", "").replace("Rs.", "").replace("Rs", "").replace(",", "").strip()
|
| 115 |
+
|
| 116 |
+
# Handle unicode minus
|
| 117 |
+
text = text.replace("\u2212", "-").replace("β", "-")
|
| 118 |
+
|
| 119 |
+
# Skip if "and" still present (multi-value like "5 and 1")
|
| 120 |
+
if re.search(r"\band\b", text, re.IGNORECASE):
|
| 121 |
+
return None
|
| 122 |
+
|
| 123 |
+
# Skip inequalities / expressions with variables
|
| 124 |
+
if re.search(r"[a-zA-Zβ€β₯<>]", text):
|
| 125 |
+
return None
|
| 126 |
+
|
| 127 |
+
# Single number (integer or decimal, optionally negative)
|
| 128 |
+
m = re.fullmatch(r"\s*(-?\d+(?:\.\d+)?)\s*(?:[a-zA-Z%Β°].*)?", text)
|
| 129 |
+
if m:
|
| 130 |
+
val_str = m.group(1)
|
| 131 |
+
try:
|
| 132 |
+
val = float(val_str)
|
| 133 |
+
return str(int(val)) if val == int(val) else val_str
|
| 134 |
+
except ValueError:
|
| 135 |
+
pass
|
| 136 |
+
|
| 137 |
+
return None
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
# ---------------------------------------------------------------------------
|
| 141 |
+
# Rationale β Step N: format
|
| 142 |
+
# ---------------------------------------------------------------------------
|
| 143 |
+
|
| 144 |
+
def _rationale_to_steps(rationale: str) -> list[str]:
|
| 145 |
+
lines: list[str] = []
|
| 146 |
+
for raw in rationale.splitlines():
|
| 147 |
+
line = raw.strip()
|
| 148 |
+
if line:
|
| 149 |
+
line = line.replace("^", "**")
|
| 150 |
+
lines.append(line)
|
| 151 |
+
if not lines and rationale.strip():
|
| 152 |
+
sentences = re.split(r"(?<=[.!?])\s+", rationale.strip())
|
| 153 |
+
lines = [s.strip() for s in sentences if s.strip()]
|
| 154 |
+
return lines
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def _build_assistant(rationale: str, final_answer: str) -> str:
    """Render the assistant turn: numbered steps followed by the answer line."""
    numbered = [
        f"Step {n}: {text}"
        for n, text in enumerate(_rationale_to_steps(rationale), start=1)
    ]
    answer_line = f"Final Answer: {final_answer}"
    if not numbered:
        return answer_line
    return "\n".join(numbered) + "\n" + answer_line
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
# ---------------------------------------------------------------------------
|
| 165 |
+
# Row conversion
|
| 166 |
+
# ---------------------------------------------------------------------------
|
| 167 |
+
|
| 168 |
+
def convert_row(row: dict[str, Any], idx: int, split: str) -> Optional[dict[str, Any]]:
    """Convert one raw AQuA-RAT row into the messages-format record.

    Returns None when the question or completion is empty, or when no clean
    numeric answer can be extracted from the completion.
    """
    question = (row.get("prompt") or "").strip()
    completion = (row.get("completion") or "").strip()
    if not (question and completion):
        return None

    extracted = _extract_answer_and_rationale(completion)
    if extracted is None:
        return None
    rationale, final_answer = extracted

    return {
        "id": f"aqua_{split}_{idx}",
        "skill_id": "aqua_rat_algebra",
        "source": "Chinar/AQuA-RAT",
        "split": split,
        "messages": [
            {"role": "system", "content": SOLVER_SYSTEM_PROMPT},
            {"role": "user", "content": USER_WRAPPER.format(question=question)},
            {"role": "assistant", "content": _build_assistant(rationale, final_answer)},
        ],
    }
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
# ---------------------------------------------------------------------------
|
| 196 |
+
# Main
|
| 197 |
+
# ---------------------------------------------------------------------------
|
| 198 |
+
|
| 199 |
+
def main() -> None:
    """CLI entry point: download AQuA-RAT, convert rows, and write JSONL splits."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--output-dir", default="data/sft")
    parser.add_argument("--val-size", type=int, default=500,
                        help="How many rows from the end of the dataset to use as validation.")
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--max-samples", type=int, default=None)
    args = parser.parse_args()

    try:
        from datasets import load_dataset
    except ImportError:
        print("ERROR: pip install datasets", file=sys.stderr)
        sys.exit(1)

    print("Downloading Chinar/AQuA-RAT …")
    ds = load_dataset("Chinar/AQuA-RAT")
    all_rows = list(ds["train"])
    total = len(all_rows)
    print(f"  Total rows: {total:,}")

    # FIX: with --val-size 0 the previous slicing (all_rows[-0:] and
    # all_rows[:-0]) put EVERY row into validation and NONE into train.
    # Guard the zero/negative case explicitly.
    if args.val_size > 0:
        val_rows = all_rows[-args.val_size:]
        train_rows = all_rows[: -args.val_size]
    else:
        val_rows = []
        train_rows = all_rows

    splits = {
        "train": train_rows,
        "validation": val_rows,
    }

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    for split, rows in splits.items():
        if args.max_samples:
            rows = rows[: args.max_samples]

        records: list[dict] = []
        skipped = 0
        for idx, row in enumerate(rows):
            rec = convert_row(row, idx, split)
            if rec is None:
                skipped += 1
            else:
                records.append(rec)

        # Percentage of rows dropped (guard against empty split).
        skip_pct = 100.0 * skipped / max(1, len(rows))

        if args.dry_run:
            # Dry-run: show stats and a few sample records, write nothing.
            print(f"\n── {split}: {len(records)} valid / {skipped} skipped ({skip_pct:.1f}%) ──")
            for rec in records[:3]:
                print(json.dumps(rec, indent=2))
            continue

        out_path = out_dir / f"aqua_{split}.jsonl"
        with out_path.open("w", encoding="utf-8") as f:
            for rec in records:
                f.write(json.dumps(rec, ensure_ascii=False) + "\n")

        print(f"  [{split:12s}] {len(records):6,d} valid {skipped:5,d} skipped ({skip_pct:.1f}%) → {out_path}")

    if not args.dry_run:
        print("\nDone. Launch continuation training with:")
        print("  bash launch_grpo_aqua.sh")
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
if __name__ == "__main__":
|
| 265 |
+
main()
|
scripts/prepare_combined_dataset.py
ADDED
|
@@ -0,0 +1,711 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Combined dataset pipeline β NuminaMath-CoT + OpenMathInstruct-2
|
| 4 |
+
================================================================
|
| 5 |
+
Downloads, filters, normalises, and merges two large math datasets into a single
|
| 6 |
+
JSONL file (train / val / test) that the GRPO training script can consume directly
|
| 7 |
+
via --gsm8k-data.
|
| 8 |
+
|
| 9 |
+
Why these two datasets
|
| 10 |
+
----------------------
|
| 11 |
+
NuminaMath-CoT (AI-MO/NuminaMath-CoT)
|
| 12 |
+
860 K problems. Clean \\boxed{} answers. 7 rich topic categories that map
|
| 13 |
+
directly to ZPD skill_ids. Sources span AMC, AIME, Chinese HS, olympiads,
|
| 14 |
+
and synthetic β giving natural difficulty diversity.
|
| 15 |
+
|
| 16 |
+
OpenMathInstruct-2 (nvidia/OpenMathInstruct-2)
|
| 17 |
+
14 M synthetic problems with step-level CoT. `expected_answer` is pre-verified.
|
| 18 |
+
Diverse surface forms prevent pattern memorisation. We skip any row whose
|
| 19 |
+
problem_source is "gsm8k" (already in prior training).
|
| 20 |
+
|
| 21 |
+
Output schema (identical to gsm8k_sft.jsonl / aqua_train.jsonl)
|
| 22 |
+
---------------------------------------------------------------
|
| 23 |
+
{
|
| 24 |
+
"id": "<source>_<split>_<idx>",
|
| 25 |
+
"skill_id": "<topic_slug>", β used by ZPD CurriculumManager
|
| 26 |
+
"source": "<hf_dataset_name>",
|
| 27 |
+
"split": "train" | "val" | "test",
|
| 28 |
+
"difficulty": 1 | 2 | 3, β 1=easy 2=medium 3=hard (for ZPD)
|
| 29 |
+
"task_type": "solve",
|
| 30 |
+
"messages": [
|
| 31 |
+
{"role": "system", "content": SOLVER_SYSTEM_PROMPT},
|
| 32 |
+
{"role": "user", "content": "Solve ... Problem:\\n<question>"},
|
| 33 |
+
{"role": "assistant", "content": "Step 1: ...\\nFinal Answer: <answer>"}
|
| 34 |
+
]
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
Usage
|
| 38 |
+
-----
|
| 39 |
+
# Quick test (no download, just show stats)
|
| 40 |
+
python scripts/prepare_combined_dataset.py --dry-run
|
| 41 |
+
|
| 42 |
+
# Full pipeline (default caps: 20 K numina + 15 K openmath)
|
| 43 |
+
python scripts/prepare_combined_dataset.py
|
| 44 |
+
|
| 45 |
+
# Larger run
|
| 46 |
+
python scripts/prepare_combined_dataset.py --max-numina 40000 --max-openmath 30000
|
| 47 |
+
|
| 48 |
+
# Only one source
|
| 49 |
+
python scripts/prepare_combined_dataset.py --skip-openmath
|
| 50 |
+
python scripts/prepare_combined_dataset.py --skip-numina
|
| 51 |
+
|
| 52 |
+
# Custom output dir
|
| 53 |
+
python scripts/prepare_combined_dataset.py --output-dir data/sft/combined
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
from __future__ import annotations
|
| 57 |
+
|
| 58 |
+
import argparse
|
| 59 |
+
import hashlib
|
| 60 |
+
import json
|
| 61 |
+
import logging
|
| 62 |
+
import math
|
| 63 |
+
import random
|
| 64 |
+
import re
|
| 65 |
+
import sys
|
| 66 |
+
from collections import Counter, defaultdict
|
| 67 |
+
from pathlib import Path
|
| 68 |
+
from typing import Any, Dict, Iterator, List, Optional, Tuple
|
| 69 |
+
|
| 70 |
+
logging.basicConfig(
|
| 71 |
+
level=logging.INFO,
|
| 72 |
+
format="%(asctime)s %(levelname)-8s %(message)s",
|
| 73 |
+
datefmt="%H:%M:%S",
|
| 74 |
+
)
|
| 75 |
+
log = logging.getLogger(__name__)
|
| 76 |
+
|
| 77 |
+
# ---------------------------------------------------------------------------
|
| 78 |
+
# Constants β kept in sync with src/config/prompts.py
|
| 79 |
+
# ---------------------------------------------------------------------------
|
| 80 |
+
|
| 81 |
+
SOLVER_SYSTEM_PROMPT = (
|
| 82 |
+
"You are a step-by-step math solver. "
|
| 83 |
+
"Solve the given problem one step at a time. "
|
| 84 |
+
"Each step must be on its own line, starting with 'Step N:'. "
|
| 85 |
+
"End with a line starting with 'Final Answer:'. "
|
| 86 |
+
"Write every mathematical expression in Python/SymPy syntax "
|
| 87 |
+
"so it can be verified programmatically."
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
USER_WRAPPER = (
|
| 91 |
+
"Solve the following problem. Show your reasoning as numbered steps, "
|
| 92 |
+
"then give the final numeric answer on the last line.\n\nProblem:\n{question}"
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
# ---------------------------------------------------------------------------
|
| 96 |
+
# Skill-ID mappings (drives ZPD CurriculumManager per-topic mastery)
|
| 97 |
+
# ---------------------------------------------------------------------------
|
| 98 |
+
|
| 99 |
+
# NuminaMath-CoT `type` field β skill_id
|
| 100 |
+
NUMINA_TYPE_TO_SKILL: Dict[str, str] = {
|
| 101 |
+
"algebra": "numina_algebra",
|
| 102 |
+
"intermediate_algebra": "numina_algebra",
|
| 103 |
+
"prealgebra": "numina_prealgebra",
|
| 104 |
+
"number_theory": "numina_number_theory",
|
| 105 |
+
"geometry": "numina_geometry",
|
| 106 |
+
"counting_and_probability": "numina_combinatorics",
|
| 107 |
+
"precalculus": "numina_calculus",
|
| 108 |
+
"calculus": "numina_calculus",
|
| 109 |
+
"statistics": "numina_statistics",
|
| 110 |
+
"probability": "numina_statistics",
|
| 111 |
+
# competition-source buckets (fallback when type not in map above)
|
| 112 |
+
"cn_k12": "numina_algebra",
|
| 113 |
+
"olympiads": "numina_olympiad",
|
| 114 |
+
"amc_aime": "numina_competition",
|
| 115 |
+
"synthetic_math": "numina_synthetic",
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
# NuminaMath source β approximate difficulty (1=easy 2=medium 3=hard)
|
| 119 |
+
NUMINA_SOURCE_DIFFICULTY: Dict[str, int] = {
|
| 120 |
+
"cn_k12": 1,
|
| 121 |
+
"synthetic_math": 2,
|
| 122 |
+
"amc_aime": 2,
|
| 123 |
+
"olympiads": 3,
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
# OpenMathInstruct-2 problem_source β skill_id / difficulty
|
| 127 |
+
OPENMATH_SOURCE_TO_SKILL: Dict[str, str] = {
|
| 128 |
+
"math": "openmath_algebra", # overridden per-row by subject
|
| 129 |
+
"amc_aime_1983_2024": "openmath_competition",
|
| 130 |
+
"synthetic_math": "openmath_synthetic",
|
| 131 |
+
"number_theory": "openmath_number_theory",
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
OPENMATH_SOURCE_DIFFICULTY: Dict[str, int] = {
|
| 135 |
+
"math": 2,
|
| 136 |
+
"amc_aime_1983_2024": 3,
|
| 137 |
+
"synthetic_math": 1,
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
# OpenMathInstruct MATH-subject β skill_id (when problem_source == "math")
|
| 141 |
+
OPENMATH_MATH_SUBJECT_SKILL: Dict[str, str] = {
|
| 142 |
+
"Algebra": "openmath_algebra",
|
| 143 |
+
"Number Theory": "openmath_number_theory",
|
| 144 |
+
"Geometry": "openmath_geometry",
|
| 145 |
+
"Counting & Probability": "openmath_combinatorics",
|
| 146 |
+
"Intermediate Algebra": "openmath_algebra",
|
| 147 |
+
"Prealgebra": "openmath_prealgebra",
|
| 148 |
+
"Precalculus": "openmath_calculus",
|
| 149 |
+
"Calculus": "openmath_calculus",
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
# ---------------------------------------------------------------------------
|
| 153 |
+
# Answer normalisation
|
| 154 |
+
# ---------------------------------------------------------------------------
|
| 155 |
+
|
| 156 |
+
_BOXED_RE = re.compile(r"\\boxed\{((?:[^{}]|\{[^{}]*\})*)\}")
|
| 157 |
+
_LATEX_FRAC = re.compile(r"\\frac\{(\d+)\}\{(\d+)\}")
|
| 158 |
+
_PLAIN_FRAC = re.compile(r"^(-?\d+)\s*/\s*(\d+)$")
|
| 159 |
+
_CURRENCY = re.compile(r"(?:Rs\.?|USD|\$|β¬|Β£)\s*", re.IGNORECASE)
|
| 160 |
+
_UNICODE_MINUS = str.maketrans({"\u2212": "-", "β": "-"})
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def extract_boxed(text: str) -> Optional[str]:
    """Return the contents of the LAST \\boxed{} in *text*, or None if absent."""
    found = _BOXED_RE.findall(text)
    if not found:
        return None
    return found[-1].strip()
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def normalise_numeric(raw: str) -> Optional[str]:
    """
    Convert a raw answer string to a clean numeric string.

    Returns None for:
      - multi-value answers ("3 and 5")
      - symbolic expressions ("3\\sqrt{2}", "x+1")
      - inequalities / ranges
      - fractions with a zero denominator

    Integral values are rendered without a decimal point; fraction and
    percentage results are rendered with 4 decimal places.
    """
    text = raw.strip()

    # Remove currency symbols and commas, map unicode minus to ASCII '-'.
    text = _CURRENCY.sub("", text)
    text = text.replace(",", "").translate(_UNICODE_MINUS).strip()

    # Skip answers that are words / multi-value ("3 and 5", "no solution").
    if re.search(r"\b(and|or|none|no solution|undefined)\b", text, re.IGNORECASE):
        return None

    # Skip anything containing letters (symbolic expressions).
    # NOTE(review): because this rejects ANY letter, the "trailing unit like
    # 'km'" case mentioned in the final branch below can never actually carry
    # a letter unit -- only non-letter suffixes survive to that point.
    if re.search(r"[a-zA-Z]", text):
        return None

    # Skip inequalities / ranges.
    if re.search(r"[≤≥<>]", text):
        return None

    # LaTeX fraction: \frac{3}{4} -> 0.75 (or exact int when it divides evenly).
    m = _LATEX_FRAC.fullmatch(text)
    if m:
        num, den = int(m.group(1)), int(m.group(2))
        if den:
            v = num / den
            return str(int(v)) if v == int(v) else f"{v:.4f}"
        return None

    # Plain fraction: 3/4.
    m = _PLAIN_FRAC.match(text)
    if m:
        num, den = int(m.group(1)), int(m.group(2))
        if den:
            v = num / den
            return str(int(v)) if v == int(v) else f"{v:.4f}"
        return None

    # Percentage: returns the bare number ("50%" -> "50").  It is NOT divided
    # by 100 -- downstream answer comparison treats percent answers verbatim.
    pct = re.fullmatch(r"(-?\d+(?:\.\d+)?)\s*%", text)
    if pct:
        v = float(pct.group(1))
        return str(int(v)) if v == int(v) else f"{v:.4f}"

    # Plain integer or decimal (possibly negative, possibly with a trailing
    # non-letter suffix -- see the NOTE above about letter units).
    m = re.match(r"^\s*(-?\d+(?:\.\d+)?)\s*(?:[^0-9.\s].*)?\s*$", text)
    if m:
        val_str = m.group(1)
        try:
            v = float(val_str)
            return str(int(v)) if v == int(v) else val_str
        except ValueError:
            pass

    return None
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
# ---------------------------------------------------------------------------
|
| 235 |
+
# Solution β Step N: format
|
| 236 |
+
# ---------------------------------------------------------------------------
|
| 237 |
+
|
| 238 |
+
_SKIP_LINE_RE = re.compile(
|
| 239 |
+
r"^\s*("
|
| 240 |
+
r"\\boxed\{|"
|
| 241 |
+
r"(Therefore|Thus|Hence|So),?\s+(the\s+)?(final\s+)?answer\s+is|"
|
| 242 |
+
r"The\s+(final\s+)?answer\s+is|"
|
| 243 |
+
r"Answer\s*[:=]"
|
| 244 |
+
r")",
|
| 245 |
+
re.IGNORECASE,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def solution_to_steps(solution: str, final_answer: str, max_steps: int = 18) -> str:
    """Re-number an arbitrary CoT solution into the pipeline's "Step N:" format.

    Blank lines and lines that merely announce the final answer are dropped
    (the explicit "Final Answer:" line replaces them), any pre-existing
    "Step N:" prefixes are stripped to avoid double numbering, and the body
    is capped at *max_steps* lines to keep the token count bounded.
    """
    kept: List[str] = []
    for raw in solution.split("\n"):
        candidate = raw.strip()
        if not candidate or _SKIP_LINE_RE.match(candidate):
            continue
        # Drop any existing numbering so we can re-number cleanly below.
        candidate = re.sub(r"^Step\s*\d+\s*[:.)]\s*", "", candidate)
        if candidate:
            kept.append(candidate)

    # Cap to max_steps to keep the completion length reasonable.
    kept = kept[:max_steps]

    answer_line = f"Final Answer: {final_answer}"
    if not kept:
        return answer_line

    numbered = "\n".join(f"Step {n}: {text}" for n, text in enumerate(kept, 1))
    return f"{numbered}\n{answer_line}"
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
# ---------------------------------------------------------------------------
|
| 282 |
+
# Record builders
|
| 283 |
+
# ---------------------------------------------------------------------------
|
| 284 |
+
|
| 285 |
+
def build_record(
    idx: int,
    split: str,
    source_name: str,
    skill_id: str,
    difficulty: int,
    question: str,
    solution_text: str,
    final_answer: str,
) -> Dict[str, Any]:
    """Assemble one training record in the shared messages JSONL schema."""
    # IDs must be filesystem/JSON friendly, so '/' in HF dataset names is
    # replaced with '_'.
    record_id = f"{source_name.replace('/', '_')}_{split}_{idx}"
    conversation = [
        {"role": "system", "content": SOLVER_SYSTEM_PROMPT},
        {"role": "user", "content": USER_WRAPPER.format(question=question.strip())},
        {"role": "assistant", "content": solution_to_steps(solution_text, final_answer)},
    ]
    return {
        "id": record_id,
        "skill_id": skill_id,
        "source": source_name,
        "split": split,
        "difficulty": difficulty,
        "task_type": "solve",
        "messages": conversation,
    }
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
# ---------------------------------------------------------------------------
|
| 312 |
+
# Deduplication
|
| 313 |
+
# ---------------------------------------------------------------------------
|
| 314 |
+
|
| 315 |
+
def problem_hash(text: str) -> str:
    """16-hex-char MD5 of the whitespace-collapsed, lower-cased problem text.

    Used only for exact-match dedup after normalisation -- not for security.
    """
    canonical = re.sub(r"\s+", " ", text.strip().lower())
    digest = hashlib.md5(canonical.encode()).hexdigest()
    return digest[:16]
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
# ---------------------------------------------------------------------------
|
| 322 |
+
# NuminaMath-CoT processing
|
| 323 |
+
# ---------------------------------------------------------------------------
|
| 324 |
+
|
| 325 |
+
def _numina_skill_and_difficulty(row: Dict) -> Tuple[str, int]:
    """Map a NuminaMath row to (skill_id, difficulty).

    The topic ("type" field) takes priority; when it is unmapped we fall back
    to the competition-source bucket, then to "numina_general".
    """
    topic = (row.get("type") or "").lower().strip()
    origin = (row.get("source") or "").lower().strip()

    skill = NUMINA_TYPE_TO_SKILL.get(topic)
    if skill is None:
        skill = NUMINA_TYPE_TO_SKILL.get(origin, "numina_general")

    return skill, NUMINA_SOURCE_DIFFICULTY.get(origin, 2)
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def iter_numina(
    max_samples: int,
    per_skill_cap: int,
    skip_olympiad: bool,
    seed: int,
) -> Iterator[Dict[str, Any]]:
    """
    Stream NuminaMath-CoT from HuggingFace and yield cleaned records.

    Rows are kept only when they carry a \\boxed{} answer that normalises to
    a single number.  A per-skill quota guarantees topic diversity and an
    exact-match hash dedups repeated problems.

    Note: ``seed`` is kept for interface stability, but no sampling
    randomness is used while streaming (the previous version constructed an
    unused ``random.Random(seed)``, which has been removed).
    """
    try:
        from datasets import load_dataset  # type: ignore
    except ImportError:
        log.error("pip install datasets huggingface_hub")
        sys.exit(1)

    log.info("Streaming AI-MO/NuminaMath-CoT …")
    ds = load_dataset("AI-MO/NuminaMath-CoT", split="train", streaming=True,
                      trust_remote_code=True)

    skill_counts: Counter = Counter()
    seen_hashes: set = set()
    total_yielded = 0

    for row in ds:
        if total_yielded >= max_samples:
            break

        problem = (row.get("problem") or "").strip()
        solution = (row.get("solution") or "").strip()
        if not problem or not solution:
            continue

        # Extract and normalise the answer from the last \boxed{}.
        raw_answer = extract_boxed(solution)
        if raw_answer is None:
            continue
        final_answer = normalise_numeric(raw_answer)
        if final_answer is None:
            continue

        skill, difficulty = _numina_skill_and_difficulty(row)

        # Optionally skip very hard olympiad problems.
        if skip_olympiad and skill == "numina_olympiad":
            continue

        # Per-skill cap to guarantee topic diversity.
        if skill_counts[skill] >= per_skill_cap:
            continue

        # Exact-match dedup on normalised problem text.
        h = problem_hash(problem)
        if h in seen_hashes:
            continue
        seen_hashes.add(h)

        skill_counts[skill] += 1
        total_yielded += 1

        yield build_record(
            idx=total_yielded,
            split="__assign__",
            source_name="AI-MO/NuminaMath-CoT",
            skill_id=skill,
            difficulty=difficulty,
            question=problem,
            solution_text=solution,
            final_answer=final_answer,
        )

    log.info("NuminaMath-CoT: yielded %d records | skill dist: %s",
             total_yielded, dict(skill_counts.most_common()))
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
# ---------------------------------------------------------------------------
|
| 415 |
+
# OpenMathInstruct-2 processing
|
| 416 |
+
# ---------------------------------------------------------------------------
|
| 417 |
+
|
| 418 |
+
def _openmath_skill_and_difficulty(row: Dict) -> Tuple[str, int]:
    """Map an OpenMathInstruct-2 row to (skill_id, difficulty)."""
    origin = (row.get("problem_source") or "").lower().strip()
    subject = (row.get("subject") or "").strip()

    if origin == "math" and subject:
        # MATH rows carry a subject label that maps to a finer skill bucket.
        skill = OPENMATH_MATH_SUBJECT_SKILL.get(subject, "openmath_algebra")
    else:
        skill = OPENMATH_SOURCE_TO_SKILL.get(origin, "openmath_general")

    return skill, OPENMATH_SOURCE_DIFFICULTY.get(origin, 2)
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
def iter_openmath(
    max_samples: int,
    per_skill_cap: int,
    skip_gsm8k: bool,
    seed: int,
) -> Iterator[Dict[str, Any]]:
    """
    Stream OpenMathInstruct-2 from HuggingFace and yield cleaned records.
    Only yields rows where `is_correct_solution` is True (pre-verified by NVIDIA).

    Parameters
    ----------
    max_samples : stop after this many records have been yielded.
    per_skill_cap : hard cap on yielded records per skill_id (topic diversity).
    skip_gsm8k : drop rows whose problem_source mentions gsm8k
        (eval-set contamination risk).
    seed : unused in this iterator (the HF stream order is fixed); kept,
        presumably, for signature parity with the NuminaMath iterator.
    """
    try:
        from datasets import load_dataset  # type: ignore
    except ImportError:
        log.error("pip install datasets huggingface_hub")
        sys.exit(1)

    log.info("Streaming nvidia/OpenMathInstruct-2 (this may take a moment) β¦")
    # streaming=True avoids materialising the full dataset before filtering.
    ds = load_dataset(
        "nvidia/OpenMathInstruct-2",
        split="train",
        streaming=True,
        trust_remote_code=True,
    )

    skill_counts: Counter = Counter()   # yielded records per skill_id
    seen_hashes: set = set()            # normalised problem hashes for dedup
    total_yielded = 0

    for row in ds:
        if total_yielded >= max_samples:
            break

        # Filter: skip gsm8k (contamination risk)
        problem_src = (row.get("problem_source") or "").lower()
        if skip_gsm8k and "gsm8k" in problem_src:
            continue

        # Filter: only verified correct solutions.
        # NOTE(review): a row with the field missing passes (default True) —
        # confirm that is intended.
        if not row.get("is_correct_solution", True):
            continue

        problem = (row.get("problem") or "").strip()
        solution = (row.get("generated_solution") or "").strip()
        expected = (row.get("expected_answer") or "").strip()

        if not problem or not solution or not expected:
            continue

        # Normalise the pre-extracted answer; rows whose answer cannot be
        # normalised to a number are dropped.
        final_answer = normalise_numeric(expected)
        if final_answer is None:
            continue

        skill, difficulty = _openmath_skill_and_difficulty(row)

        # Per-skill cap
        if skill_counts[skill] >= per_skill_cap:
            continue

        # Dedup
        h = problem_hash(problem)
        if h in seen_hashes:
            continue
        seen_hashes.add(h)

        skill_counts[skill] += 1
        total_yielded += 1

        # split is a placeholder here; the stratified splitter assigns the
        # real train/val/test label later.
        yield build_record(
            idx=total_yielded,
            split="__assign__",
            source_name="nvidia/OpenMathInstruct-2",
            skill_id=skill,
            difficulty=difficulty,
            question=problem,
            solution_text=solution,
            final_answer=final_answer,
        )

    log.info("OpenMathInstruct-2: yielded %d records | skill dist: %s",
             total_yielded, dict(skill_counts.most_common()))
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
# ---------------------------------------------------------------------------
|
| 515 |
+
# Dataset stats printer
|
| 516 |
+
# ---------------------------------------------------------------------------
|
| 517 |
+
|
| 518 |
+
def print_stats(records: List[Dict], label: str) -> None:
    """Log a breakdown of *records* by split, source, difficulty and skill."""
    by_split = Counter(r["split"] for r in records)
    by_source = Counter(r["source"] for r in records)
    by_difficulty = Counter(r["difficulty"] for r in records)
    by_skill = Counter(r["skill_id"] for r in records)

    log.info("βββ %s (%d records) βββββββββββββββββββββββββββββββ", label, len(records))
    log.info(" by split: %s", dict(by_split))
    log.info(" by source: %s", dict(by_source))
    log.info(" by difficulty: %s", dict(sorted(by_difficulty.items())))
    log.info(" by skill_id:")
    for skill_id, count in by_skill.most_common():
        log.info(" %-40s %5d", skill_id, count)
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
# ---------------------------------------------------------------------------
|
| 534 |
+
# Write JSONL
|
| 535 |
+
# ---------------------------------------------------------------------------
|
| 536 |
+
|
| 537 |
+
def write_jsonl(records: List[Dict], path: Path) -> None:
    """Serialise *records* to *path* as UTF-8 JSON Lines, creating parent dirs."""
    path.parent.mkdir(parents=True, exist_ok=True)
    lines = (json.dumps(record, ensure_ascii=False) + "\n" for record in records)
    with path.open("w", encoding="utf-8") as fh:
        fh.writelines(lines)
    log.info("Wrote %d records β %s", len(records), path)
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
# ---------------------------------------------------------------------------
|
| 546 |
+
# Train / val / test split (stratified by skill_id)
|
| 547 |
+
# ---------------------------------------------------------------------------
|
| 548 |
+
|
| 549 |
+
def stratified_split(
    records: List[Dict],
    train_frac: float = 0.85,
    val_frac: float = 0.10,
    seed: int = 42,
) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """
    Split *records* into (train, val, test), stratified by skill_id so every
    skill appears in all three sets.

    Each skill bucket is shuffled, then sliced: ``floor(n * train_frac)``
    records go to train, ``floor(n * val_frac)`` to val, and whatever is
    left to test.  Every record's ``"split"`` field is rewritten in place,
    and each output list is shuffled so skills interleave during training.
    """
    rng = random.Random(seed)

    grouped: Dict[str, List[Dict]] = defaultdict(list)
    for record in records:
        grouped[record["skill_id"]].append(record)

    buckets: Dict[str, List[Dict]] = {"train": [], "val": [], "test": []}
    for items in grouped.values():
        rng.shuffle(items)
        n_train = math.floor(len(items) * train_frac)
        n_val = math.floor(len(items) * val_frac)
        buckets["train"] += items[:n_train]
        buckets["val"] += items[n_train:n_train + n_val]
        buckets["test"] += items[n_train + n_val:]

    # Label every record, then shuffle each split.
    for name, bucket in buckets.items():
        for record in bucket:
            record["split"] = name
    for bucket in buckets.values():
        rng.shuffle(bucket)

    return buckets["train"], buckets["val"], buckets["test"]
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
# ---------------------------------------------------------------------------
|
| 588 |
+
# Main
|
| 589 |
+
# ---------------------------------------------------------------------------
|
| 590 |
+
|
| 591 |
+
def parse_args() -> argparse.Namespace:
    """Parse CLI options for the combined NuminaMath + OpenMathInstruct-2 build."""
    p = argparse.ArgumentParser(
        description="Build combined NuminaMath + OpenMathInstruct-2 training data."
    )
    p.add_argument("--output-dir", default="data/sft",
                   help="Directory for output JSONL files.")
    p.add_argument("--max-numina", type=int, default=20_000,
                   help="Max records from NuminaMath-CoT (default 20 000).")
    p.add_argument("--max-openmath", type=int, default=15_000,
                   help="Max records from OpenMathInstruct-2 (default 15 000).")
    p.add_argument("--per-skill-cap", type=int, default=4_000,
                   help="Max records per skill_id to guarantee topic diversity.")
    p.add_argument("--skip-numina", action="store_true",
                   help="Skip NuminaMath-CoT entirely.")
    p.add_argument("--skip-openmath", action="store_true",
                   help="Skip OpenMathInstruct-2 entirely.")
    # Paired flags sharing one dest: olympiad problems are excluded by
    # default and can be re-enabled with --no-skip-olympiad.
    p.add_argument("--skip-olympiad", action="store_true", default=True,
                   help="Skip numina_olympiad problems (too hard for 1.5B; default: True).")
    p.add_argument("--no-skip-olympiad", dest="skip_olympiad", action="store_false",
                   help="Include olympiad-level problems.")
    p.add_argument("--train-frac", type=float, default=0.85)
    p.add_argument("--val-frac", type=float, default=0.10)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--dry-run", action="store_true",
                   help="Process only 500 rows from each source and show stats (no write).")
    return p.parse_args()
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
def main() -> None:
    """Collect, deduplicate, split and write the combined SFT dataset.

    Pipeline: stream NuminaMath-CoT and/or OpenMathInstruct-2, dedupe
    across sources on the question text, stratified-split by skill_id,
    then write three JSONL files (or print stats only in --dry-run mode).

    Fix: the previous version created an unused module RNG
    (``random.Random(args.seed)``); the seed is already threaded
    explicitly into the iterators and the splitter, so it is removed.
    """
    args = parse_args()

    if args.dry_run:
        args.max_numina = min(args.max_numina, 500)
        args.max_openmath = min(args.max_openmath, 500)
        log.info("DRY RUN β capped at 500 samples per source, nothing written to disk.")

    all_records: List[Dict] = []

    # ββ NuminaMath-CoT ββββββββββββββββββββββββββββββββββββββββββββββββββββ
    if not args.skip_numina:
        numina_recs = list(iter_numina(
            max_samples=args.max_numina,
            per_skill_cap=args.per_skill_cap,
            skip_olympiad=args.skip_olympiad,
            seed=args.seed,
        ))
        all_records.extend(numina_recs)
        log.info("NuminaMath-CoT collected: %d records", len(numina_recs))
    else:
        log.info("Skipping NuminaMath-CoT (--skip-numina).")

    # ββ OpenMathInstruct-2 ββββββββββββββββββββββββββββββββββββββββββββββββ
    if not args.skip_openmath:
        openmath_recs = list(iter_openmath(
            max_samples=args.max_openmath,
            per_skill_cap=args.per_skill_cap,
            skip_gsm8k=True,
            seed=args.seed,
        ))
        all_records.extend(openmath_recs)
        log.info("OpenMathInstruct-2 collected: %d records", len(openmath_recs))
    else:
        log.info("Skipping OpenMathInstruct-2 (--skip-openmath).")

    if not all_records:
        log.error("No records collected β check dataset availability.")
        sys.exit(1)

    # ββ Deduplicate across sources βββββββββββββββββββββββββββββββββββββββββ
    # Keyed on the user-message question text so the same problem pulled
    # from both sources survives only once.
    seen: set = set()
    deduped: List[Dict] = []
    for r in all_records:
        question = r["messages"][1]["content"]
        h = problem_hash(question)
        if h not in seen:
            seen.add(h)
            deduped.append(r)

    log.info("After cross-source dedup: %d β %d records (removed %d dupes)",
             len(all_records), len(deduped), len(all_records) - len(deduped))

    # ββ Stratified split ββββββββββββββββββββββββββββββββββββββββββββββββββ
    train_recs, val_recs, test_recs = stratified_split(
        deduped, args.train_frac, args.val_frac, args.seed
    )

    print_stats(train_recs + val_recs + test_recs, "COMBINED DATASET")

    # ββ Write outputs βββββββββββββββββββββββββββββββββββββββββββββββββββββ
    if args.dry_run:
        log.info("DRY RUN complete β no files written.")
        log.info(" would write: combined_train.jsonl (%d rows)", len(train_recs))
        log.info(" would write: combined_val.jsonl (%d rows)", len(val_recs))
        log.info(" would write: combined_test.jsonl (%d rows)", len(test_recs))
        log.info("Sample record:")
        print(json.dumps(train_recs[0], indent=2, ensure_ascii=False))
        return

    out = Path(args.output_dir)
    write_jsonl(train_recs, out / "combined_train.jsonl")
    write_jsonl(val_recs, out / "combined_val.jsonl")
    write_jsonl(test_recs, out / "combined_test.jsonl")

    log.info("")
    log.info("ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ")
    log.info("β Pipeline complete. Next step: β")
    log.info("β bash launch_grpo_combined.sh β")
    log.info("ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ")
    log.info(" train : %6d rows β %s/combined_train.jsonl", len(train_recs), out)
    log.info(" val : %6d rows β %s/combined_val.jsonl", len(val_recs), out)
    log.info(" test : %6d rows β %s/combined_test.jsonl", len(test_recs), out)
    log.info("")
    log.info("Skill coverage (for ZPD CurriculumManager):")
    skill_c = Counter(r["skill_id"] for r in train_recs)
    for sk, cnt in sorted(skill_c.items()):
        log.info(" %-40s %5d train samples", sk, cnt)
|
| 708 |
+
|
| 709 |
+
|
| 710 |
+
# Script entry point: run the full dataset-build pipeline.
if __name__ == "__main__":
    main()
|
scripts/run_grpo_training.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
scripts/run_inference.py
ADDED
|
@@ -0,0 +1,502 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Inference pipeline: Base Qwen2.5-Math-1.5B-Instruct vs RL fine-tuned checkpoint.
|
| 4 |
+
|
| 5 |
+
For each sampled GSM8K question, both models generate a step-by-step solution.
|
| 6 |
+
Results are saved to reports/<run_name>/ as JSON files for the Gradio demo.
|
| 7 |
+
|
| 8 |
+
Usage
|
| 9 |
+
-----
|
| 10 |
+
# Full run (50 questions, both models):
|
| 11 |
+
python scripts/run_inference.py \\
|
| 12 |
+
--checkpoint checkpoints/grpo_run_v1 \\
|
| 13 |
+
--num-questions 50 \\
|
| 14 |
+
--run-name comparison_v1
|
| 15 |
+
|
| 16 |
+
# Quick smoke test (10 questions, no RL model):
|
| 17 |
+
python scripts/run_inference.py \\
|
| 18 |
+
--num-questions 10 \\
|
| 19 |
+
--base-only \\
|
| 20 |
+
--run-name smoke
|
| 21 |
+
|
| 22 |
+
# Custom data source:
|
| 23 |
+
python scripts/run_inference.py \\
|
| 24 |
+
--checkpoint checkpoints/grpo_run_v1 \\
|
| 25 |
+
--data data/sft/gsm8k_test.jsonl \\
|
| 26 |
+
--num-questions 30
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
from __future__ import annotations
|
| 30 |
+
|
| 31 |
+
import argparse
|
| 32 |
+
import json
|
| 33 |
+
import logging
|
| 34 |
+
import random
|
| 35 |
+
import sys
|
| 36 |
+
import time
|
| 37 |
+
from datetime import datetime
|
| 38 |
+
from pathlib import Path
|
| 39 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 40 |
+
|
| 41 |
+
import torch
|
| 42 |
+
from tqdm.auto import tqdm
|
| 43 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 44 |
+
|
| 45 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
| 46 |
+
|
| 47 |
+
from src.config.prompts import create_solver_messages
|
| 48 |
+
from src.sft.solution_format import extract_final_answer_numeric_str
|
| 49 |
+
from src.utils.attn_backend import select_attn_implementation
|
| 50 |
+
|
| 51 |
+
logging.basicConfig(
|
| 52 |
+
level=logging.INFO,
|
| 53 |
+
format="%(asctime)s %(levelname)-8s %(name)s - %(message)s",
|
| 54 |
+
)
|
| 55 |
+
logger = logging.getLogger(__name__)
|
| 56 |
+
|
| 57 |
+
BASE_MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
|
| 58 |
+
REPORTS_DIR = Path("reports")
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# ββ Data loading ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 62 |
+
|
| 63 |
+
def load_gsm8k_questions(
    data_path: Optional[str],
    num_questions: int,
    seed: int = 42,
) -> List[Dict[str, str]]:
    """
    Load GSM8K questions from a local JSONL file or fall back to HuggingFace.

    Parameters
    ----------
    data_path : explicit JSONL path tried before the default local locations.
    num_questions : number of questions to sample (capped at dataset size).
    seed : RNG seed for reproducible sampling.

    Each returned record has keys: ``question``, ``gold_final``, ``answer``.

    Raises
    ------
    RuntimeError
        If no local file exists and the HuggingFace download fails.
    """
    # ββ Try local JSONL first ββββββββββββββββββββββββββββββββββββββββββββββββ
    candidates = [data_path] if data_path else []
    candidates += [
        "data/sft/gsm8k_test.jsonl",
        "data/sft/gsm8k_sft.jsonl",
    ]

    for path in candidates:
        if path and Path(path).exists():
            logger.info("Loading GSM8K from local file: %s", path)
            rows: List[Dict] = []
            with open(path, encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if line:
                        rows.append(json.loads(line))
            sample = _sample_questions(rows, num_questions, seed)
            logger.info("Sampled %d / %d questions.", len(sample), len(rows))
            return sample

    # ββ Fall back to HuggingFace datasets ββββββββββββββββββββββββββββββββββββ
    logger.info("No local file found β downloading GSM8K from HuggingFaceβ¦")
    try:
        from datasets import load_dataset
        ds = load_dataset("openai/gsm8k", "main", split="test")
    except Exception as e:
        raise RuntimeError(
            "Could not load GSM8K. Provide --data or install datasets: pip install datasets"
        ) from e

    rows = []
    for item in ds:
        q = item["question"].strip()
        a = item["answer"].strip()
        # GSM8K answers end with "#### <number>"
        gold = a.split("####")[-1].strip() if "####" in a else ""
        rows.append({"question": q, "gold_final": gold, "answer": a})

    sample = _sample_questions(rows, num_questions, seed)
    logger.info("Sampled %d questions from HF GSM8K test split.", len(sample))
    return sample


def _sample_questions(rows: List[Dict], num_questions: int, seed: int) -> List[Dict]:
    """Deterministically sample up to *num_questions* rows (shared by both paths)."""
    rng = random.Random(seed)
    return rng.sample(rows, min(num_questions, len(rows)))
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
# ββ Model loading βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 119 |
+
|
| 120 |
+
def load_base_model(
    device: torch.device,
    attn_impl: str,
) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load the untouched base model and tokenizer for the comparison run.

    Parameters
    ----------
    device : device the weights are placed on (via ``device_map``).
    attn_impl : attention backend string passed to ``from_pretrained``.

    Returns the ``(model, tokenizer)`` pair with the model in eval mode
    (bfloat16 weights) and the tokenizer set up for left padding.
    """
    logger.info("Loading base model: %s", BASE_MODEL_ID)
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Fall back to EOS when the checkpoint defines no dedicated pad token.
        tokenizer.pad_token = tokenizer.eos_token
    # Left padding is the correct side for batched decoder-only generation.
    tokenizer.padding_side = "left"

    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map={"": device},
        trust_remote_code=True,
        attn_implementation=attn_impl,
    )
    model.eval()
    logger.info("Base model loaded.")
    return model, tokenizer
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def load_rl_model(
    checkpoint: str,
    base_model: AutoModelForCausalLM,
    base_tokenizer: AutoTokenizer,
    device: torch.device,
    attn_impl: str,
) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    """
    Load the RL fine-tuned checkpoint for comparison against the raw base model.

    Two checkpoint formats are supported:

    PEFT / LoRA adapter (has adapter_config.json)
        The already-loaded base model weights are deep-copied in CPU memory,
        the adapter is applied on top, then merged and unloaded.
        This avoids downloading the 1.5B base weights from HuggingFace a
        second time β the base model is downloaded only once per run.

    Full saved model (has config.json, no adapter_config.json)
        Loaded directly from disk with from_pretrained.

    Raises
    ------
    FileNotFoundError
        If *checkpoint* does not exist on disk.
    """
    import copy

    ckpt_path = Path(checkpoint)
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint}")

    # adapter_config.json is the marker PEFT writes for adapter checkpoints.
    is_peft = (ckpt_path / "adapter_config.json").exists()

    if is_peft:
        logger.info(
            "Loading PEFT adapter from %s (reusing base weights β no second HF download)",
            checkpoint,
        )
        from peft import PeftModel

        # Deep-copy the already-loaded base model so the base remains untouched
        # for side-by-side comparison. For a 1.5B bfloat16 model this takes
        # ~1-2 s and avoids re-downloading ~3 GB from HuggingFace.
        base_copy = copy.deepcopy(base_model)
        model = PeftModel.from_pretrained(base_copy, checkpoint)
        # Merge the adapter weights so inference needs no PEFT wrapper.
        model = model.merge_and_unload()
        model = model.to(device)
    else:
        logger.info("Loading full model checkpoint from %s", checkpoint)
        model = AutoModelForCausalLM.from_pretrained(
            checkpoint,
            torch_dtype=torch.bfloat16,
            device_map={"": device},
            trust_remote_code=True,
            attn_implementation=attn_impl,
        )

    # Prefer the checkpoint's own tokenizer when it shipped one, otherwise
    # load the base tokenizer ID; patch chat_template from the base
    # tokenizer if the loaded one is missing it.
    tokenizer = AutoTokenizer.from_pretrained(
        checkpoint if (ckpt_path / "tokenizer_config.json").exists() else BASE_MODEL_ID,
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    if tokenizer.chat_template is None and base_tokenizer.chat_template:
        tokenizer.chat_template = base_tokenizer.chat_template

    model.eval()
    logger.info("RL model loaded.")
    return model, tokenizer
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
# ββ Inference βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 212 |
+
|
| 213 |
+
def generate_solution(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    question: str,
    device: torch.device,
    max_new_tokens: int = 512,
    temperature: float = 0.1,
) -> Tuple[str, float]:
    """
    Generate a step-by-step solution for ``question``.

    Returns ``(solution_text, elapsed_seconds)``.
    Low temperature (0.1) for deterministic, greedy-like output during eval.
    Temperatures at or below 0.05 switch to pure greedy decoding
    (``do_sample=False``); above that threshold nucleus sampling
    (``top_p=0.95``) is used.
    """
    messages = create_solver_messages(question)
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # Prompts are truncated to 1024 tokens.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[1]

    # Stop on EOS and, when the vocabulary has it, the <|im_end|> turn marker.
    stop_ids = [tokenizer.eos_token_id]
    im_end = tokenizer.convert_tokens_to_ids("<|im_end|>")
    if isinstance(im_end, int) and im_end not in stop_ids:
        stop_ids.append(im_end)

    t0 = time.time()
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=temperature > 0.05,
            temperature=temperature if temperature > 0.05 else None,
            top_p=0.95 if temperature > 0.05 else None,
            eos_token_id=stop_ids,
            pad_token_id=tokenizer.pad_token_id,
            use_cache=True,
        )
    elapsed = time.time() - t0

    # Decode only the newly generated tail, not the echoed prompt.
    response_ids = output[0][prompt_len:]
    solution = tokenizer.decode(response_ids, skip_special_tokens=True).strip()
    return solution, elapsed
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def score_answer(solution: str, gold_final: str) -> Dict[str, Any]:
    """
    Extract the predicted final answer from *solution* and compare it with
    *gold_final*.

    Returns a dict with ``predicted``, ``gold``, ``correct``, ``match_type``
    where match_type is one of ``no_answer_found`` / ``exact`` / ``numeric``
    / ``wrong``.
    """
    def _verdict(predicted: Any, correct: bool, match_type: str) -> Dict[str, Any]:
        return {
            "predicted": predicted,
            "gold": gold_final,
            "correct": correct,
            "match_type": match_type,
        }

    predicted_raw = extract_final_answer_numeric_str(solution)
    if predicted_raw is None:
        return _verdict(None, False, "no_answer_found")

    # Normalise: strip whitespace, drop thousands separators, trailing
    # dots, and case before comparing.
    def _norm(s: str) -> str:
        return s.strip().replace(",", "").rstrip(".").lower()

    pred_n, gold_n = _norm(predicted_raw), _norm(gold_final)

    # Direct string match.
    if pred_n == gold_n:
        return _verdict(predicted_raw, True, "exact")

    # Numeric match (handles float/int equivalence).
    try:
        if abs(float(pred_n) - float(gold_n)) < 1e-6:
            return _verdict(predicted_raw, True, "numeric")
    except (ValueError, TypeError):
        pass

    return _verdict(predicted_raw, False, "wrong")
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
# ββ Report serialisation ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 313 |
+
|
| 314 |
+
def save_question_report(
    report_dir: Path,
    idx: int,
    question: str,
    gold_final: str,
    base_result: Dict[str, Any],
    rl_result: Optional[Dict[str, Any]],
) -> Path:
    """Write one per-question comparison record to ``report_dir/q_<idx>.json``.

    Returns the path of the file that was written.
    """
    destination = report_dir / f"q_{idx:04d}.json"
    payload = json.dumps(
        {
            "idx": idx,
            "question": question,
            "gold_final": gold_final,
            "base_model": base_result,
            "rl_model": rl_result,
        },
        indent=2,
        ensure_ascii=False,
    )
    destination.write_text(payload, encoding="utf-8")
    return destination
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def save_summary(
    report_dir: Path,
    run_name: str,
    checkpoint: Optional[str],
    base_correct: int,
    rl_correct: Optional[int],
    total: int,
    total_time_s: float,
    args_dict: Dict,
) -> None:
    """Write the run-level accuracy summary to ``report_dir/summary.json``.

    ``rl_accuracy`` is ``None`` when no RL model was evaluated; accuracies
    fall back to 0/None when *total* is zero to avoid division by zero.
    """
    have_rl = rl_correct is not None and total
    summary = {
        "run_name": run_name,
        "timestamp": datetime.now().isoformat(),
        "base_model": BASE_MODEL_ID,
        "rl_checkpoint": checkpoint,
        "num_questions": total,
        "base_accuracy": round(base_correct / total, 4) if total else 0,
        "rl_accuracy": round(rl_correct / total, 4) if have_rl else None,
        "base_correct": base_correct,
        "rl_correct": rl_correct,
        "total_time_s": round(total_time_s, 1),
        "args": args_dict,
    }
    destination = report_dir / "summary.json"
    destination.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    logger.info("Summary saved β %s", destination)
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
# ββ Main ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 363 |
+
|
| 364 |
+
def parse_args() -> argparse.Namespace:
    """Parse CLI options for the base-vs-RL GSM8K inference comparison."""
    p = argparse.ArgumentParser(description="Run inference: base vs RL model on GSM8K")
    p.add_argument("--checkpoint", type=str, default=None,
                   help="Path to RL fine-tuned model or PEFT adapter. "
                        "If omitted, only the base model is run.")
    p.add_argument("--data", type=str, default=None,
                   help="Path to local GSM8K JSONL file. "
                        "Defaults to data/sft/gsm8k_test.jsonl or HuggingFace.")
    p.add_argument("--num-questions", type=int, default=50)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--max-new-tokens", type=int, default=512)
    p.add_argument("--temperature", type=float, default=0.1)
    p.add_argument("--run-name", type=str, default=None,
                   help="Report sub-folder name. Defaults to timestamp.")
    p.add_argument("--base-only", action="store_true",
                   help="Skip RL model; only run the base model.")
    p.add_argument("--reports-dir", type=str, default="reports")
    return p.parse_args()
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def main() -> None:
    """Entry point: evaluate the base model (and optionally an RL checkpoint)
    on GSM8K questions.

    Loads the evaluation questions, generates a solution per question with the
    base model and — when a checkpoint is given and ``--base-only`` is not set —
    with the RL model, scores each solution against the gold final answer,
    writes per-question reports plus a summary JSON into the report directory,
    and logs a final accuracy comparison.
    """
    args = parse_args()

    # -- Report directory ------------------------------------------------------
    run_name = args.run_name or f"run_{datetime.now():%Y%m%d_%H%M%S}"
    report_dir = Path(args.reports_dir) / run_name
    report_dir.mkdir(parents=True, exist_ok=True)
    logger.info("Reports β %s", report_dir)

    # -- Device ----------------------------------------------------------------
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    attn_impl = select_attn_implementation()
    logger.info("Device: %s | attn: %s", device, attn_impl)
    if torch.cuda.is_available():
        g = torch.cuda.get_device_properties(0)
        logger.info("GPU: %s | %.1f GB", g.name, g.total_memory / 1e9)

    # -- Data ------------------------------------------------------------------
    questions = load_gsm8k_questions(args.data, args.num_questions, args.seed)
    if not questions:
        # Fix: guard the empty-dataset case. The accuracy computations below
        # divide by len(questions) and would raise ZeroDivisionError.
        logger.error("No questions loaded; nothing to evaluate.")
        return

    # -- Models ----------------------------------------------------------------
    base_model, base_tokenizer = load_base_model(device, attn_impl)

    rl_model, rl_tokenizer = None, None
    if not args.base_only and args.checkpoint:
        rl_model, rl_tokenizer = load_rl_model(
            args.checkpoint, base_model, base_tokenizer, device, attn_impl
        )
    elif not args.base_only and not args.checkpoint:
        logger.warning("No --checkpoint provided. Running base model only.")

    # -- Inference loop --------------------------------------------------------
    base_correct = 0
    rl_correct = 0 if rl_model else None  # None signals "RL model not run"
    t_total_start = time.time()

    for idx, row in enumerate(tqdm(questions, desc="Inference")):
        question = row["question"]
        gold_final = row.get("gold_final", "").strip()

        # Base model
        base_solution, base_time = generate_solution(
            base_model, base_tokenizer, question, device,
            args.max_new_tokens, args.temperature,
        )
        base_score = score_answer(base_solution, gold_final)
        if base_score["correct"]:
            base_correct += 1

        base_result = {
            "solution": base_solution,
            "predicted": base_score["predicted"],
            "correct": base_score["correct"],
            "match_type": base_score["match_type"],
            "time_s": round(base_time, 2),
            "num_tokens": len(base_tokenizer.encode(base_solution)),
        }

        # RL model (only when a checkpoint was loaded)
        rl_result = None
        if rl_model is not None:
            rl_solution, rl_time = generate_solution(
                rl_model, rl_tokenizer, question, device,
                args.max_new_tokens, args.temperature,
            )
            rl_score = score_answer(rl_solution, gold_final)
            if rl_score["correct"]:
                rl_correct += 1

            rl_result = {
                "solution": rl_solution,
                "predicted": rl_score["predicted"],
                "correct": rl_score["correct"],
                "match_type": rl_score["match_type"],
                "time_s": round(rl_time, 2),
                "num_tokens": len(rl_tokenizer.encode(rl_solution)),
            }

        save_question_report(report_dir, idx, question, gold_final, base_result, rl_result)

        # Live progress log every 10 questions (and on the final question)
        if (idx + 1) % 10 == 0 or idx == len(questions) - 1:
            done = idx + 1
            b_acc = base_correct / done
            log_str = f"[{done}/{len(questions)}] Base acc: {b_acc:.1%}"
            if rl_correct is not None:
                log_str += f" | RL acc: {rl_correct / done:.1%}"
            logger.info(log_str)

    total_time = time.time() - t_total_start

    # -- Summary ---------------------------------------------------------------
    save_summary(
        report_dir=report_dir,
        run_name=run_name,
        checkpoint=args.checkpoint,
        base_correct=base_correct,
        rl_correct=rl_correct,
        total=len(questions),
        total_time_s=total_time,
        args_dict=vars(args),
    )

    logger.info("=" * 60)
    logger.info("Run complete: %s", run_name)
    logger.info("Base accuracy : %d / %d = %.1f%%",
                base_correct, len(questions), 100 * base_correct / len(questions))
    if rl_correct is not None:
        logger.info("RL accuracy : %d / %d = %.1f%%",
                    rl_correct, len(questions), 100 * rl_correct / len(questions))
        delta = rl_correct - base_correct
        sign = "+" if delta >= 0 else ""
        logger.info("Delta : %s%d questions (%s%.1f%%)",
                    sign, delta, sign, 100 * delta / len(questions))
    logger.info("Reports : %s", report_dir)
    logger.info("=" * 60)


if __name__ == "__main__":
    main()
|