jampuramprem commited on
Commit
ec4ae03
Β·
0 Parent(s):

Initial Space deployment

Browse files
This view is limited to 50 files because it contains too many changes. Β  See raw diff
Files changed (50) hide show
  1. .gitignore +73 -0
  2. Dockerfile +130 -0
  3. README.md +132 -0
  4. __init__.py +16 -0
  5. blog.md +94 -0
  6. client.py +76 -0
  7. docs/environment-overview.puml +69 -0
  8. docs/reward-system.puml +51 -0
  9. docs/training-phases.puml +27 -0
  10. images/axiomforgeai_scenes/scene_01.svg +52 -0
  11. images/axiomforgeai_scenes/scene_02.svg +72 -0
  12. images/axiomforgeai_scenes/scene_03.svg +67 -0
  13. images/axiomforgeai_scenes/scene_04.svg +78 -0
  14. images/axiomforgeai_scenes/scene_05.svg +66 -0
  15. images/axiomforgeai_scenes/scene_06.svg +79 -0
  16. images/axiomforgeai_scenes/scene_07.svg +66 -0
  17. images/axiomforgeai_scenes/scene_08.svg +74 -0
  18. images/axiomforgeai_scenes/scene_09.svg +61 -0
  19. images/axiomforgeai_scenes/scene_10.svg +86 -0
  20. images/blog_flow/architecture.svg +50 -0
  21. images/blog_flow/grading.svg +45 -0
  22. images/blog_flow/grpo-loop.svg +44 -0
  23. images/blog_flow/task-sources.svg +35 -0
  24. images/environment_overview.svg +0 -0
  25. images/training_phases.svg +1 -0
  26. logs/grpo/grpo_20260426_024029.log +44 -0
  27. logs/grpo/grpo_20260426_032827.log +0 -0
  28. logs/grpo/grpo_20260426_032827/config.json +44 -0
  29. logs/grpo/grpo_20260426_032827/console_output.log +0 -0
  30. logs/grpo/grpo_20260426_032827/metrics.csv +31 -0
  31. logs/metrics.jsonl +31 -0
  32. models.py +67 -0
  33. openenv.yaml +7 -0
  34. pyproject.toml +55 -0
  35. requirements.txt +160 -0
  36. scripts/__init__.py +1 -0
  37. scripts/convert_gsm8k_to_sft.py +193 -0
  38. scripts/create_dual_task_dataset.py +321 -0
  39. scripts/demo_before_after.py +591 -0
  40. scripts/dual_task_sft_pipeline.py +390 -0
  41. scripts/eval_sft_inference.py +565 -0
  42. scripts/gsm8k_sft_pipeline.py +475 -0
  43. scripts/launch_grpo.sh +127 -0
  44. scripts/plot_grpo_run.py +425 -0
  45. scripts/plot_training_results.py +521 -0
  46. scripts/precompute_extraction_cache.py +174 -0
  47. scripts/prepare_aqua_dataset.py +265 -0
  48. scripts/prepare_combined_dataset.py +711 -0
  49. scripts/run_grpo_training.py +0 -0
  50. scripts/run_inference.py +502 -0
.gitignore ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ dist/
13
+ downloads/
14
+ eggs/
15
+ .eggs/
16
+ lib/
17
+ lib64/
18
+ parts/
19
+ sdist/
20
+ var/
21
+ wheels/
22
+ share/python-wheels/
23
+ *.egg-info/
24
+ .installed.cfg
25
+ *.egg
26
+ MANIFEST
27
+
28
+ # Installer logs
29
+ pip-log.txt
30
+ pip-delete-this-directory.txt
31
+
32
+ # Unit test / coverage reports
33
+ .pytest_cache/
34
+ .coverage
35
+ .coverage.*
36
+ htmlcov/
37
+ .tox/
38
+ .nox/
39
+ coverage.xml
40
+ *.cover
41
+ *.py,cover
42
+
43
+ # Type checkers / static analyzers
44
+ .mypy_cache/
45
+ .pyre/
46
+ .ruff_cache/
47
+ .pytype/
48
+
49
+ # Virtual environments
50
+ .venv/
51
+ venv/
52
+ env/
53
+ ENV/
54
+
55
+ # Local environment files
56
+ .env
57
+ .env.*
58
+ *.local
59
+
60
+ # IDE / editor files
61
+ .vscode/
62
+ .idea/
63
+ *.swp
64
+ *.swo
65
+ *~
66
+
67
+ # OS files
68
+ .DS_Store
69
+ Thumbs.db
70
+ data/
71
+
72
+ */ui
73
+ images/
Dockerfile ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AxiomForgeAI β€” GRPO Training Image
2
+ # ─────────────────────────────────────────────────────────────────────────────
3
+ # Hardware target : 1Γ— A100 PCIE 80 GB | AMD EPYC 7V13 | NVMe 300 GB
4
+ #
5
+ # CUDA driver : >= 13.0 (enforced at container start via entrypoint)
6
+ # CUDA toolkit : 12.4.1 (backward-compatible with driver 13.x)
7
+ # PyTorch : 2.5.1+cu124 (pinned in requirements.txt)
8
+ # Flash-Attn : 2.8.3 (pinned in requirements.txt)
9
+ #
10
+ # All Python package versions are taken exclusively from requirements.txt.
11
+ # No versions are hard-coded in this file.
12
+ #
13
+ # ── Build ─────────────────────────────────────────────────────────────────────
14
+ # docker build -t axiomforgeai-train:latest .
15
+ #
16
+ # ── Interactive shell ─────────────────────────────────────────────────────────
17
+ # docker run --gpus all --ipc=host --ulimit memlock=-1 \
18
+ # -v $(pwd)/data:/workspace/data \
19
+ # -v $(pwd)/checkpoints:/workspace/checkpoints \
20
+ # -v $(pwd)/logs:/workspace/logs \
21
+ # -it axiomforgeai-train:latest bash
22
+ #
23
+ # ── GRPO training (one-shot) ──────────────────────────────────────────────────
24
+ # docker run --gpus all --ipc=host --ulimit memlock=-1 \
25
+ # -v $(pwd)/data:/workspace/data \
26
+ # -v $(pwd)/checkpoints:/workspace/checkpoints \
27
+ # -v $(pwd)/logs:/workspace/logs \
28
+ # axiomforgeai-train:latest \
29
+ # python scripts/run_grpo_training.py \
30
+ # --base-model checkpoints/dual_task_v1 \
31
+ # --gsm8k-data data/sft/gsm8k_sft.jsonl \
32
+ # --num-iterations 30 --group-size 8 --questions-per-iter 16
33
+ # ─────────────────────────────────────────────────────────────────────────────
34
+
35
+ # CUDA toolkit 12.4.1 β€” matches the cu124 wheels in requirements.txt and is
36
+ # fully compatible with the A100's CUDA 13.2 driver (driver is always β‰₯ toolkit).
37
+ FROM nvidia/cuda:12.4.1-devel-ubuntu22.04
38
+
39
+ LABEL org.opencontainers.image.title="AxiomForgeAI Training" \
40
+ cuda.driver.minimum="13.0" \
41
+ cuda.toolkit="12.4.1" \
42
+ torch.version="2.5.1+cu124" \
43
+ flash_attn.version="2.8.3"
44
+
45
+ # ── System packages ────────────────────────────────────────────────────────────
46
+ ENV DEBIAN_FRONTEND=noninteractive
47
+ RUN apt-get update && apt-get install -y --no-install-recommends \
48
+ python3.11 \
49
+ python3.11-dev \
50
+ python3-pip \
51
+ python3.11-venv \
52
+ git \
53
+ git-lfs \
54
+ curl \
55
+ wget \
56
+ build-essential \
57
+ ninja-build \
58
+ pkg-config \
59
+ libssl-dev \
60
+ libffi-dev \
61
+ ca-certificates \
62
+ && ln -sf /usr/bin/python3.11 /usr/bin/python3 \
63
+ && ln -sf /usr/bin/python3 /usr/bin/python \
64
+ && rm -rf /var/lib/apt/lists/*
65
+
66
+ # ── Upgrade pip + build tooling ───────────────────────────────────────────────
67
+ RUN python -m pip install --upgrade --no-cache-dir pip setuptools wheel
68
+
69
+ # ── PyTorch (CUDA 12.4 wheels) ────────────────────────────────────────────────
70
+ # Must be installed before flash-attn because flash-attn runs a torch version
71
+ # check at install time. The cu124 index is also used for all CUDA-linked wheels.
72
+ # Version is taken from requirements.txt β€” the --constraint flag keeps pip from
73
+ # re-resolving to a different version when requirements.txt is processed next.
74
+ RUN pip install --no-cache-dir \
75
+ --extra-index-url https://download.pytorch.org/whl/cu124 \
76
+ "torch==2.5.1" "torchvision==0.20.1" "torchaudio==2.5.1"
77
+
78
+ # ── All remaining pinned requirements (from requirements.txt) ─────────────────
79
+ # flash-attn, xformers, vllm, triton, bitsandbytes, transformers, accelerate,
80
+ # peft, ray, sympy, scipy, numpy, openenv-core, fastapi, uvicorn, … are all
81
+ # installed here at the exact versions pinned in requirements.txt.
82
+ # The cu124 index is provided so CUDA-linked wheels resolve correctly.
83
+ COPY requirements.txt /tmp/requirements.txt
84
+ RUN pip install --no-cache-dir \
85
+ --extra-index-url https://download.pytorch.org/whl/cu124 \
86
+ -r /tmp/requirements.txt
87
+
88
+ # ── Project source ───────────��────────────────────────────────────────────────
89
+ WORKDIR /workspace
90
+ COPY . /workspace/
91
+
92
+ # ── Environment variables ─────────────────────────────────────────────────────
93
+ # Repo root on PYTHONPATH so `from src.rl.X import Y` works without editable install
94
+ ENV PYTHONPATH="/workspace:$PYTHONPATH"
95
+
96
+ # HuggingFace model cache β€” mount a host path here to persist model downloads:
97
+ # -v /host/hf_cache:/workspace/.hf_cache
98
+ ENV HF_HOME="/workspace/.hf_cache"
99
+ ENV TRANSFORMERS_CACHE="/workspace/.hf_cache"
100
+
101
+ # A100 CUDA / NCCL tuning
102
+ ENV CUDA_DEVICE_MAX_CONNECTIONS=1
103
+ ENV NCCL_P2P_DISABLE=0
104
+ ENV NCCL_IB_DISABLE=0
105
+ # Required for Flash-Attn 2 with bfloat16 on Ampere
106
+ ENV TORCH_CUDNN_V8_API_ENABLED=1
107
+
108
+ # ── Runtime entrypoint: enforce CUDA driver >= 13.0 ──────────────────────────
109
+ # nvidia-smi is injected at runtime via --gpus, so this check runs when the
110
+ # container starts, not at build time.
111
+ RUN printf '%s\n' \
112
+ '#!/bin/sh' \
113
+ 'if command -v nvidia-smi >/dev/null 2>&1; then' \
114
+ ' CUDA_VER=$(nvidia-smi 2>/dev/null | grep -oP "CUDA Version: \K[0-9.]+" || echo "0.0")' \
115
+ ' MAJOR=$(echo "$CUDA_VER" | cut -d. -f1)' \
116
+ ' echo "[AxiomForgeAI] CUDA driver reports toolkit: $CUDA_VER"' \
117
+ ' if [ "${MAJOR:-0}" -lt 13 ] 2>/dev/null; then' \
118
+ ' echo "[ERROR] CUDA driver >= 13.0 required; detected $CUDA_VER. Upgrade your NVIDIA driver."' \
119
+ ' exit 1' \
120
+ ' fi' \
121
+ ' echo "[AxiomForgeAI] CUDA $CUDA_VER >= 13.0 β€” OK"' \
122
+ 'else' \
123
+ ' echo "[WARNING] nvidia-smi not found β€” CUDA driver version check skipped."' \
124
+ 'fi' \
125
+ 'exec "$@"' \
126
+ > /usr/local/bin/entrypoint.sh \
127
+ && chmod +x /usr/local/bin/entrypoint.sh
128
+
129
+ ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
130
+ CMD ["bash"]
README.md ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: AxiomForgeAI Environment Server
3
+ emoji: 🌌
4
+ colorFrom: indigo
5
+ colorTo: pink
6
+ sdk: docker
7
+ pinned: false
8
+ app_port: 8000
9
+ base_path: /web
10
+ tags:
11
+ - openenv
12
+ ---
13
+
14
+ # AxiomForgeAI
15
+
16
+ [![OpenEnv](https://img.shields.io/badge/Powered%20by-OpenEnv-blue)](https://github.com/meta-pytorch/OpenEnv)
17
+
18
+ *A self-improving math environment where a model practices on verified problems, generates new challenges when ready, and learns from solution attempts whose reasoning steps and final answers agree.*
19
+
20
+ ## The Problem
21
+
22
+ Math reasoning models can fail in two different ways. Sometimes the setup, arithmetic, and algebraic steps look reasonable, but the final answer is wrong. Sometimes the final answer is right, but the reasoning that produced it is incomplete, inconsistent, or hard to trust.
23
+
24
+ For a math user, both failures matter. Checking only the final answer misses where the solution went off track. Checking only the steps misses whether the work actually reaches the right result. The useful signal is the agreement between the reasoning path and the final answer.
25
+
26
+ This project builds a practice loop around that signal. The model first works on problems with known answers, gets feedback on both the chain of reasoning and the final result, and only then starts generating new challenges for itself. The constraint is intentionally small: a 1.5B math model.
27
+
28
+ ## The Environment
29
+
30
+ The environment is a practice loop for math reasoning. Each training group starts with one problem, asks the model for multiple solution attempts, scores those attempts from several angles, and uses GRPO to reinforce the attempts that are stronger than the rest of the group.
31
+
32
+ ![AxiomForgeAI environment overview](images/environment_overview.svg)
33
+
34
+ The environment has two task sources:
35
+
36
+ - **Grounded source:** A dataset problem from GSM8K / MATH comes with a known final answer. This gives the environment a reliable anchor for checking whether the model actually reached the right result.
37
+ - **Self-play source:** The curriculum selects a target skill and difficulty. The model writes a new question, then samples multiple solutions to that question. This adds practice beyond static datasets, but only after the grounded signal is stable enough.
38
+
39
+ Both sources feed the same scoring and update loop. For every selected problem, the model samples `K` candidate solutions. The environment checks final-answer correctness when a gold answer exists, scores reasoning quality with a PRM, checks chain consistency and symbolic arithmetic where possible, checks answer formatting, and scores self-generated questions for clarity, novelty, difficulty fit, and solvability.
40
+
41
+ GRPO then compares the `K` attempts against each other. The model is not rewarded for a solution in isolation; the strongest attempt in the group becomes the direction for learning. Training starts grounded-only, gradually mixes in self-play groups, and falls back to grounded practice if generated-question quality or answer correctness drops.
42
+
43
+ ## How Self-Improvement Works
44
+
45
+ Self-improvement comes from turning each problem into a small comparison. The model does not produce one solution and move on; the environment samples several attempts, scores each attempt, and asks which reasoning path was strongest.
46
+
47
+ GRPO uses that within-group comparison as the learning signal. Attempts with correct answers, stronger reasoning chains, and cleaner final-answer format are reinforced. Attempts with broken chains or unsupported answers become weaker examples.
48
+
49
+ ```text
50
+ practice -> sample attempts -> verify steps and answer -> compare -> reinforce -> adjust difficulty
51
+ ```
52
+
53
+ ## Reward System
54
+
55
+ The reward is designed to avoid a common math-training failure: optimizing for either the final answer or the reasoning trace alone. A good solution should reach the right answer, explain the path clearly, and keep the final result consistent with the steps that produced it.
56
+
57
+ | Signal | What it checks | Why it matters |
58
+ | --- | --- | --- |
59
+ | Final answer | Matches the gold answer when one exists | Keeps grounded problems tied to objective correctness |
60
+ | Process score | PRM score over the reasoning steps | Rewards clear mathematical progress, not just the last line |
61
+ | Chain consistency | Correct-prefix and step-answer consistency signals | Gives partial learning signal when a solution goes wrong midway |
62
+ | Format | Parseable final answer and clean response structure | Makes automatic grading reliable |
63
+ | Question quality | Topic fit, difficulty fit, clarity, novelty, and solvability | Keeps self-play from generating vague or useless practice tasks |
64
+
65
+ Grounded problems use the gold answer as the anchor. Self-play problems add a question-quality score before the solution reward is trusted. Both paths produce one combined score for each sampled attempt, and GRPO uses those scores only in comparison with the other attempts from the same problem.
66
+
67
+ ```text
68
+ grounded: answer correctness + process score + chain consistency + format
69
+ self-play: question quality + solution quality
70
+ both -> one combined score per attempt -> GRPO compares attempts within the group
71
+ ```
72
+
73
+ ## Training Phases
74
+
75
+ Training follows a simple three-phase schedule. It starts with grounded-only practice so the model learns to keep answers and reasoning stable on problems with known solutions. Self-play is then introduced gradually, while grounded questions remain as an anchor. Once both are stable, training continues with a mixed task source and falls back to grounded-only batches if answer quality drops.
76
+
77
+ ![Training phases overview](images/training_phases.svg)
78
+
79
+ ## Training Script
80
+
81
+ The GRPO training loop is available in two forms:
82
+
83
+ - [`scripts/launch_grpo.sh`](scripts/launch_grpo.sh) β€” the primary launch script; sets CUDA/threading env vars, verifies Flash-Attention, and calls `run_grpo_training.py` with the full parameter set.
84
+
85
+ ```bash
86
+ bash scripts/launch_grpo.sh
87
+ ```
88
+ - [`train_grpo.ipynb`](train_grpo.ipynb) β€” notebook version with the same parameters, structured around `env.reset / env.step / env.state / env.close` for interactive inspection.
89
+
90
+
91
+ ## Results
92
+
93
+ These plots come from a single GPU training run and focus on the core question: did the model get better at making its reasoning and final answer agree?
94
+
95
+ ### Evaluation Quality Over Training
96
+
97
+ ![Evaluation quality over training](images/plot1_eval_quality.png)
98
+
99
+ The environment tracks final correctness, solution quality, step validity, and how long the reasoning chain stays correct. All four move upward together, which suggests the model is not just finding better final answers. It is also producing reasoning that holds up longer.
100
+
101
+ ### Training Journey
102
+
103
+ ![Training journey across all 30 iterations](images/plot2_training_journey.png)
104
+
105
+ Training starts with grounded practice on problems with known answers. Self-play is introduced only after the grounded signal is stable, so the model does not train on its own generated problems too early. The transition is conditional, not just a timer.
106
+
107
+ ### Self-Play Curriculum
108
+
109
+ ![Self-play curriculum ramp and question quality](images/plot3_selfplay_success.png)
110
+
111
+ By the end of training, most practice came from self-play. The important part is that generated problems stayed solvable and novel even after self-play became a larger share of training. That makes the ramp meaningful: self-play added useful practice instead of recycled noise.
112
+
113
+ ### Reward Confidence
114
+
115
+ ![Reward confidence and skipped groups](images/plot4_reward_confidence.png)
116
+
117
+ The reward spread shows how much contrast exists between the model's best and worst attempts. Wide spread gives GRPO something to learn from. Skipped groups are cases where attempts are too similar to compare usefully. That rate falls as harder material enters the curriculum, which suggests the comparison signal stays useful.
118
+
119
+ ### Step-Level Reasoning Quality
120
+
121
+ ![Step accuracy and LCCP across training](images/plot5_reasoning_quality.png)
122
+
123
+ Step accuracy checks whether each line of reasoning is valid. Chain integrity checks whether those valid steps form an unbroken path to the answer. Both improve together, which means the model is building solutions that hold together more often instead of only producing better-looking outputs.
124
+
125
+ ## Why It Matters
126
+
127
+ Reliable math reasoning needs more than fluent explanations or lucky final answers. A system that can separate correct reasoning from unsupported answers gives the model a better training target: not just "get the number," but build a chain of logic that reaches the number.
128
+
129
+ AxiomForgeAI matters because it turns that target into an environment. The same pattern can extend beyond math to other verifiable domains where attempts can be checked, compared, and improved: code, logic, structured data transformations, and scientific problem solving.
130
+
131
+ ---
132
+ *Engineered for the OpenEnv Hackathon India 2026*
__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Axiomforgeai Environment."""
8
+
9
+ from .client import AxiomforgeaiEnv
10
+ from .models import AxiomforgeaiAction, AxiomforgeaiObservation
11
+
12
+ __all__ = [
13
+ "AxiomforgeaiAction",
14
+ "AxiomforgeaiObservation",
15
+ "AxiomforgeaiEnv",
16
+ ]
blog.md ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AxiomForgeAI: Self-Improving Math Models Need More Than the Final Answer
2
+
3
+ Math models have a strange failure mode.
4
+
5
+ They can write a solution that looks careful, step-by-step, and confident, then end with the wrong answer. They can also produce the right final number with reasoning that is incomplete, inconsistent, or impossible to trust.
6
+
7
+ For math, that gap matters. The final answer is not enough. A proof, derivation, or word-problem solution only becomes useful when the path and the answer support each other.
8
+
9
+ AxiomForgeAI is built around that idea.
10
+
11
+ Instead of treating math reasoning as a one-shot generation problem, AxiomForgeAI turns it into a practice environment. The model does not simply answer a question and move on. It attempts the same problem multiple ways, receives feedback on both the reasoning path and the final answer, and learns from the attempts where the two agree.
12
+
13
+ ## The Architecture
14
+
15
+ ![AxiomForgeAI architecture](./images/blog_flow/architecture.svg)
16
+
17
+ AxiomForgeAI is a training loop around one simple idea: a math solution should be judged by whether the reasoning path and the final answer support each other.
18
+
19
+ The environment first selects one task. It can come from a grounded dataset problem with a known answer, or from a self-play question written from a curriculum target. Only after that task is selected does the model sample `K` candidate solutions. The environment scores each attempt, and GRPO compares the attempts within that same problem group.
20
+
21
+ That is the important part. The model is not rewarded for sounding fluent. It is rewarded when the chain of reasoning and the final answer line up.
22
+
23
+ ## Where Practice Comes From
24
+
25
+ ![Task sources](./images/blog_flow/task-sources.svg)
26
+
27
+ The environment uses two sources of problems.
28
+
29
+ Grounded practice starts with dataset problems from sources like GSM8K or MATH. These problems come with known final answers, so the environment has a reliable anchor for correctness.
30
+
31
+ Self-play starts later. The curriculum selects a skill and difficulty, and the model writes a new question. That question is only useful if it is clear, solvable, on-topic, and appropriately difficult. This keeps self-play from becoming random problem generation.
32
+
33
+ Both sources eventually become the same interface: one selected problem. From there, the model samples multiple candidate solutions and the environment compares the resulting reasoning paths.
34
+
35
+ ## What Gets Checked
36
+
37
+ ![Grading signals](./images/blog_flow/grading.svg)
38
+
39
+ AxiomForgeAI does not rely on a single reward signal. A final answer check is useful, but it is not enough. A process score is useful, but it is also not enough. The environment combines several signals so that a polished but wrong solution does not look good, and a lucky answer with weak reasoning does not look good either.
40
+
41
+ For grounded problems, the gold answer anchors correctness. For all attempts, the environment also looks at reasoning quality, chain consistency, symbolic arithmetic where possible, and whether the answer can be parsed cleanly. For self-play, the generated question itself is scored before the solution reward is trusted.
42
+
43
+ The result is one score per attempt. That score is not the end of training. It becomes useful because there are other attempts for the same problem.
44
+
45
+ ## Why GRPO Fits
46
+
47
+ ![GRPO loop](./images/blog_flow/grpo-loop.svg)
48
+
49
+ GRPO turns a problem into a small comparison game. The model samples several attempts for the same prompt. Some are wrong, some are partially right, and one may be clearly better because the answer follows from the steps.
50
+
51
+ Instead of asking whether an attempt is good in isolation, GRPO asks which attempts are stronger relative to the rest of the group. That relative signal is exactly what this project needs. The model learns from contrast: this reasoning path held together better than the others.
52
+
53
+ After the update, the improved model goes back into the environment for the next batch. The curriculum can keep it grounded, introduce more self-play, or fall back to grounded-only practice if quality drops.
54
+
55
+ ## Why the 1.5B Constraint Matters
56
+
57
+ AxiomForgeAI is intentionally built around a compact math model.
58
+
59
+ That constraint makes the loop easier to see. A smaller model cannot hide every reasoning mistake behind scale. If the setup is wrong, if the arithmetic drifts, or if the final answer does not follow from the steps, the environment has to catch it and turn it into feedback.
60
+
61
+ The point is not that a compact model magically solves math. The point is that improvement has to come from better practice, better verification, and better selection of reasoning paths.
62
+
63
+ ## What the Model Learns From
64
+
65
+ AxiomForgeAI rewards attempts that are mathematically useful, not just polished.
66
+
67
+ The model learns to solve problems with reasoning that supports the answer. It also learns, during self-play, to generate practice problems that are worth solving. A useful self-generated problem should be clear, solvable, on-topic, appropriately difficult, and not just a duplicate of what the model has already seen.
68
+
69
+ That makes the loop different from ordinary fine-tuning. The model is not only seeing more answers. It is practicing, being checked, and learning from the solution paths that survived verification.
70
+
71
+ ## Where Examples Will Go
72
+
73
+ This section will include real model responses from the run.
74
+
75
+ - an example where the model had good steps but a wrong final answer
76
+ - an example where the model guessed correctly but the reasoning was weak
77
+ - an example after training where the reasoning chain and final answer agree
78
+ - a self-generated problem that passed the quality checks
79
+
80
+ These examples are important because the project is not only about a metric. The clearest evidence is seeing the model become better at making the path and the answer line up.
81
+
82
+ ## Why This Matters
83
+
84
+ Math is a good starting point because mistakes are often checkable. Arithmetic can be verified. Final answers can be compared. Reasoning steps can be scored. That makes math a clean domain for building self-improvement loops.
85
+
86
+ But the pattern is bigger than math.
87
+
88
+ Many useful AI tasks have the same structure. Generate an attempt, check it, compare it against alternatives, and reinforce the better path. Code, logic, structured data transformation, and scientific problem solving all benefit from environments where progress can be verified.
89
+
90
+ AxiomForgeAI is one version of that pattern. It asks a simple question.
91
+
92
+ > What if a model could practice until its reasoning and answers agreed?
93
+
94
+ That is the loop this project builds.
client.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """AxiomForgeAI Math RL Environment Client."""
8
+
9
+ from typing import Any, Dict, Optional
10
+
11
+ from openenv.core import EnvClient
12
+ from openenv.core.client_types import StepResult
13
+ from openenv.core.env_server.types import State
14
+
15
+ from .models import AxiomforgeaiAction, AxiomforgeaiObservation
16
+
17
+
18
+ class AxiomforgeaiEnv(
19
+ EnvClient[AxiomforgeaiAction, AxiomforgeaiObservation, State]
20
+ ):
21
+ """
22
+ Client for the AxiomForgeAI math RL environment.
23
+
24
+ Maintains a persistent WebSocket connection to the environment server.
25
+ Each client instance gets its own session with independent episode state.
26
+
27
+ Episode flow::
28
+
29
+ with AxiomforgeaiEnv(base_url="http://localhost:8000") as env:
30
+ # 1. Reset β€” receive a math question
31
+ result = env.reset()
32
+ question = result.observation.question
33
+
34
+ # 2. Step β€” submit a solution, receive reward + feedback
35
+ solution = "Step 1: ... Final Answer: 42"
36
+ result = env.step(AxiomforgeaiAction(solution=solution))
37
+ print(result.reward, result.observation.feedback)
38
+
39
+ Example with Docker::
40
+
41
+ client = AxiomforgeaiEnv.from_docker_image("axiomforgeai-env:latest")
42
+ try:
43
+ result = client.reset()
44
+ result = client.step(AxiomforgeaiAction(solution="Final Answer: 17"))
45
+ finally:
46
+ client.close()
47
+ """
48
+
49
+ def _step_payload(self, action: AxiomforgeaiAction) -> Dict[str, Any]:
50
+ """Convert AxiomforgeaiAction to JSON payload for the step endpoint."""
51
+ return {"solution": action.solution}
52
+
53
+ def _parse_result(self, payload: Dict[str, Any]) -> StepResult[AxiomforgeaiObservation]:
54
+ """Parse the server's step response into a StepResult."""
55
+ obs_data: Dict[str, Any] = payload.get("observation", {})
56
+ observation = AxiomforgeaiObservation(
57
+ question=obs_data.get("question", ""),
58
+ topic=obs_data.get("topic", ""),
59
+ difficulty=float(obs_data.get("difficulty", 0.5)),
60
+ feedback=obs_data.get("feedback", ""),
61
+ done=payload.get("done", False),
62
+ reward=payload.get("reward"),
63
+ metadata=obs_data.get("metadata"),
64
+ )
65
+ return StepResult(
66
+ observation=observation,
67
+ reward=payload.get("reward"),
68
+ done=payload.get("done", False),
69
+ )
70
+
71
+ def _parse_state(self, payload: Dict[str, Any]) -> State:
72
+ """Parse the server's state response into a State object."""
73
+ return State(
74
+ episode_id=payload.get("episode_id"),
75
+ step_count=payload.get("step_count", 0),
76
+ )
docs/environment-overview.puml ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ @startuml environment_overview
2
+ !theme plain
3
+ top to bottom direction
4
+ skinparam backgroundColor #FEFEFE
5
+ skinparam defaultFontName Arial
6
+ skinparam defaultFontSize 14
7
+ skinparam ArrowColor #334155
8
+ skinparam RectangleBorderColor #64748B
9
+ skinparam RectangleFontColor #0F172A
10
+ skinparam roundcorner 10
11
+ skinparam linetype ortho
12
+ skinparam packageStyle rectangle
13
+ skinparam nodesep 42
14
+ skinparam ranksep 42
15
+
16
+ title AxiomForgeAI - Phase-Controlled Math Reasoning Loop
17
+
18
+ rectangle "Small Math Model\n1.5B parameters" as MODEL #DBEAFE
19
+
20
+ rectangle "Phase Controller\nwarmup: grounded only\nramp: gradual self-play\ncontinuous: capped mix + fallback" as PHASE #E2E8F0
21
+
22
+ rectangle "Task Source\nfor each GRPO group" as SELECT #E2E8F0
23
+
24
+ rectangle "Grounded Source\nKnown-answer practice" as GLANE #ECFDF5 {
25
+ rectangle "Dataset problem\nGSM8K / MATH" as GQ #CCFBF1
26
+ rectangle "Gold answer\navailable" as GOLD #CCFBF1
27
+ rectangle "Model samples\nK solutions" as GSOL #CCFBF1
28
+ }
29
+
30
+ rectangle "Self-Play Source\nModel-made challenges" as SLANE #EEF2FF {
31
+ rectangle "Curriculum picks\nskill + difficulty" as CURRIC #E0E7FF
32
+ rectangle "Model writes\na new question" as SQ #E0E7FF
33
+ rectangle "Model samples\nK solutions" as SSOL #E0E7FF
34
+ }
35
+
36
+ rectangle "Shared Grading\nanswer, steps, arithmetic, format\n+ question quality for self-play" as GRADERS #F1F5F9
37
+
38
+ rectangle "Group Comparison\nWhich attempts worked best?" as COMPARE #EDE9FE
39
+ rectangle "GRPO Update\nReinforce stronger reasoning" as GRPO #DDD6FE
40
+ rectangle "Improved Model\nfor the next round" as NEXT #DBEAFE
41
+
42
+ MODEL -down-> PHASE
43
+ PHASE -down-> SELECT
44
+
45
+ note right of PHASE
46
+ sets mix
47
+ end note
48
+
49
+ SELECT -left-> GQ : grounded slot
50
+ GQ --> GOLD
51
+ GOLD --> GSOL
52
+
53
+ SELECT -right-> CURRIC : self-play slot
54
+ CURRIC --> SQ
55
+ SQ --> SSOL
56
+
57
+ GSOL -down-> GRADERS
58
+ SSOL -down-> GRADERS
59
+ GRADERS -right-> COMPARE
60
+ COMPARE -right-> GRPO
61
+ GRPO -right-> NEXT
62
+ NEXT -up-> MODEL : repeat
63
+
64
+ note bottom of SELECT
65
+ Each batch is randomly interleaved.
66
+ Phase 1 uses grounded only.
67
+ Later phases add self-play slots by ratio.
68
+ end note
69
+ @enduml
docs/reward-system.puml ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ @startuml reward_system
2
+ !theme plain
3
+ top to bottom direction
4
+ skinparam backgroundColor #FEFEFE
5
+ skinparam defaultFontName Arial
6
+ skinparam defaultFontSize 14
7
+ skinparam ArrowColor #334155
8
+ skinparam RectangleBorderColor #64748B
9
+ skinparam RectangleFontColor #0F172A
10
+ skinparam roundcorner 10
11
+ skinparam linetype ortho
12
+ skinparam packageStyle rectangle
13
+ skinparam nodesep 54
14
+ skinparam ranksep 60
15
+
16
+ title AxiomForgeAI - Reward System
17
+
18
+ rectangle "Sampled Solution Attempt" as ATTEMPT #DBEAFE
19
+
20
+ rectangle "Grounded Reward\nknown-answer problem" as GROUNDED #ECFDF5 {
21
+ rectangle "Final answer\nmatches gold" as GOLD #CCFBF1
22
+ rectangle "PRM process score\nreasoning quality" as GPRM #CCFBF1
23
+ rectangle "Chain consistency\ncorrect prefix + final check" as GCHAIN #CCFBF1
24
+ rectangle "Format score\nparseable final answer" as GFORMAT #CCFBF1
25
+ }
26
+
27
+ rectangle "Self-Play Reward\ngenerated challenge" as SELFPLAY #EEF2FF {
28
+ rectangle "Question quality\nclarity, novelty, solvability" as QUALITY #E0E7FF
29
+ rectangle "Solution quality\nPRM + chain checks" as SOLUTION #E0E7FF
30
+ rectangle "Format score\nparseable final answer" as SFORMAT #E0E7FF
31
+ }
32
+
33
+ rectangle "Combined Reward\none score per attempt" as SCORE #F1F5F9
34
+ rectangle "GRPO Group Comparison\nrank attempts within the same problem" as COMPARE #EDE9FE
35
+ rectangle "Step-Answer Alignment\nreward paths where reasoning supports the result" as ALIGN #DDD6FE
36
+
37
+ ATTEMPT -left-> GROUNDED : grounded
38
+ ATTEMPT -right-> SELFPLAY : self-play
39
+
40
+ GOLD --> GPRM
41
+ GPRM --> GCHAIN
42
+ GCHAIN --> GFORMAT
43
+
44
+ QUALITY --> SOLUTION
45
+ SOLUTION --> SFORMAT
46
+
47
+ GFORMAT -down-> SCORE
48
+ SFORMAT -down-> SCORE
49
+ SCORE -right-> COMPARE
50
+ COMPARE -right-> ALIGN
51
+ @enduml
docs/training-phases.puml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ @startuml training_phases
2
+ !theme plain
3
+ left to right direction
4
+ skinparam backgroundColor #FEFEFE
5
+ skinparam defaultFontName Arial
6
+ skinparam defaultFontSize 14
7
+ skinparam ArrowColor #334155
8
+ skinparam RectangleBorderColor #64748B
9
+ skinparam RectangleFontColor #0F172A
10
+ skinparam roundcorner 10
11
+ skinparam linetype ortho
12
+ skinparam packageStyle rectangle
13
+ skinparam nodesep 42
14
+ skinparam ranksep 42
15
+
16
+ title AxiomForgeAI - Training Phases
17
+
18
+ rectangle "Phase 1\nGrounded Only" as Warmup #ECFDF5
19
+ rectangle "Phase 2\nSelf-Play Ramp" as Ramp #EEF2FF
20
+ rectangle "Phase 3\nMixed Training" as Improve #F1F5F9
21
+ rectangle "Fallback\nGrounded Recovery" as Fallback #EDE9FE
22
+
23
+ Warmup --> Ramp
24
+ Ramp --> Improve
25
+ Improve --> Fallback : if quality drops
26
+ Fallback --> Improve : recover
27
+ @enduml
images/axiomforgeai_scenes/scene_01.svg ADDED
images/axiomforgeai_scenes/scene_02.svg ADDED
images/axiomforgeai_scenes/scene_03.svg ADDED
images/axiomforgeai_scenes/scene_04.svg ADDED
images/axiomforgeai_scenes/scene_05.svg ADDED
images/axiomforgeai_scenes/scene_06.svg ADDED
images/axiomforgeai_scenes/scene_07.svg ADDED
images/axiomforgeai_scenes/scene_08.svg ADDED
images/axiomforgeai_scenes/scene_09.svg ADDED
images/axiomforgeai_scenes/scene_10.svg ADDED
images/blog_flow/architecture.svg ADDED
images/blog_flow/grading.svg ADDED
images/blog_flow/grpo-loop.svg ADDED
images/blog_flow/task-sources.svg ADDED
images/environment_overview.svg ADDED
images/training_phases.svg ADDED
logs/grpo/grpo_20260426_024029.log ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2026-04-26 02:40:33,617 INFO __main__ - ======================================================================
2
+ 2026-04-26 02:40:33,617 INFO __main__ - GRPO run: grpo_20260426_024029
3
+ 2026-04-26 02:40:33,617 INFO __main__ - Checkpoints : checkpoints/grpo/grpo_20260426_024029
4
+ 2026-04-26 02:40:33,618 INFO __main__ - Logs : logs/grpo/grpo_20260426_024029
5
+ 2026-04-26 02:40:33,618 INFO __main__ - Console log : logs/grpo/grpo_20260426_024029/console_output.log
6
+ 2026-04-26 02:40:33,618 INFO __main__ - ======================================================================
7
+ 2026-04-26 02:40:33,736 INFO src.utils.attn_backend - Attention backend selected: flash_attention_2
8
+ 2026-04-26 02:40:33,736 INFO __main__ - Device: cuda:0 | attn: flash_attention_2
9
+ 2026-04-26 02:40:33,753 INFO __main__ - GPU: NVIDIA A100 80GB PCIe | 85.1 GB VRAM | capability sm_80
10
+ 2026-04-26 02:40:33,753 INFO __main__ - Run config: K=8 K_q=2 N=16 lr=5.0e-06 T=0.80 max_new=800 | clip_eps=0.20 kl_coef=0.0400 warmup=6 | diff_alpha=3.0 | self_play=70% grounded=30% | math_mix=30% math_maxdiff=3 | overlong_filter=True | eval_every=5 eval_N=100 | grad_clip=0.50 save_every=5 keep_last=3 | question_GRPO=ENABLED (K_q=2)
11
+ 2026-04-26 02:40:33,753 INFO __main__ - Loading model from checkpoints/dual_task_v1 ...
12
+ 2026-04-26 02:40:34,405 INFO __main__ - Tokenizer has no chat_template; loading from base model Qwen/Qwen2.5-Math-1.5B-Instruct
13
+ 2026-04-26 02:40:34,731 INFO __main__ - Chat template loaded successfully.
14
+ 2026-04-26 02:40:34,731 INFO __main__ - Detected PEFT adapter β€” loading base Qwen/Qwen2.5-Math-1.5B-Instruct then merging checkpoints/dual_task_v1
15
+ 2026-04-26 02:40:36,242 WARNING __main__ - All parameters were frozen on load (PEFT merge_and_unload bug). Re-enabled requires_grad β€” any prior frozen runs were training nothing.
16
+ 2026-04-26 02:40:36,242 INFO __main__ - Flash-Attn 2 active β€” gradient checkpointing OFF (Flash already gives O(T) attention memory).
17
+ 2026-04-26 02:40:36,243 INFO __main__ - Trainable parameters: 1,543,714,304 / 1,543,714,304 (100.0%)
18
+ 2026-04-26 02:40:36,244 INFO __main__ - Creating frozen reference policy (kl_coef=0.0400, ~3.1 GB VRAM)...
19
+ 2026-04-26 02:40:36,305 INFO __main__ - Reference policy ready.
20
+ 2026-04-26 02:40:36,306 INFO __main__ - LR schedule: 5.0e-06 warmup(6 iters) β†’ cosine decay(24 iters, min=5.0e-07)
21
+ 2026-04-26 02:40:36,415 INFO __main__ - Loaded 8792 QA pairs from data/sft/gsm8k_sft.jsonl
22
+ 2026-04-26 02:40:36,424 INFO __main__ - Loaded 4072 MATH pairs from data/math/math_numeric.jsonl
23
+ 2026-04-26 02:40:36,424 INFO __main__ - MATH mixing: 30% MATH (4072 problems) + 70% GSM8K (8792 problems)
24
+ 2026-04-26 02:40:36,424 INFO src.rl.prm_scorer - Loading PRM Qwen/Qwen2.5-Math-PRM-7B (4-bit=True, dtype=torch.bfloat16) on cuda:0 …
25
+
26
+ Some weights of the model checkpoint at Qwen/Qwen2.5-Math-PRM-7B were not used when initializing Qwen2ForProcessRewardModel: ['lm_head.weight']
27
+ - This IS expected if you are initializing Qwen2ForProcessRewardModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
28
+ - This IS NOT expected if you are initializing Qwen2ForProcessRewardModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
29
+ 2026-04-26 02:40:40,150 INFO src.rl.prm_scorer - PRM ready. GPU memory allocated: 9.97 GB step_sep_id=151651
30
+ 2026-04-26 02:40:40,151 INFO __main__ - PRM loaded: Qwen/Qwen2.5-Math-PRM-7B (4-bit)
31
+ 2026-04-26 02:40:40,154 INFO src.rl.unified_accuracy - Extraction cache not found at data/extraction_cache.json β€” will build on first use
32
+ 2026-04-26 02:40:40,154 INFO __main__ - Unified accuracy calculator ready (extractor=Qwen/Qwen2.5-0.5B-Instruct, cache=data/extraction_cache.json)
33
+ 2026-04-26 02:40:40,154 INFO __main__ - Warming up step-chain extractor (eager load)...
34
+ 2026-04-26 02:40:40,154 INFO src.rl.unified_accuracy - Loading step chain extractor: Qwen/Qwen2.5-0.5B-Instruct
35
+ 2026-04-26 02:40:41,033 INFO src.rl.unified_accuracy - Step chain extractor loaded
36
+ 2026-04-26 02:40:41,034 INFO __main__ - Extractor warmup complete
37
+ 2026-04-26 02:40:41,034 INFO src.rl.llm_question_classifier - LLMQuestionClassifier ready (model=Qwen2ForCausalLM, cache=10000, topics=24)
38
+ 2026-04-26 02:40:42,571 INFO __main__ - Detected structured dataset (8792 records) β€” bootstrapping curriculum from skill_ids instead of keyword classifier.
39
+ 2026-04-26 02:40:42,575 INFO src.rl.curriculum_manager - Curriculum bootstrapped from 8792 records across 1 topics
40
+ 2026-04-26 02:40:42,575 INFO __main__ - ======================================================================
41
+ 2026-04-26 02:40:42,575 INFO __main__ - INITIAL EVALUATION (Iteration 0)
42
+ 2026-04-26 02:40:42,575 INFO __main__ - ======================================================================
43
+
44
+
logs/grpo/grpo_20260426_032827.log ADDED
The diff for this file is too large to render. See raw diff
 
logs/grpo/grpo_20260426_032827/config.json ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "base_model": "checkpoints/dual_task_v1",
3
+ "output_dir": "checkpoints/grpo",
4
+ "gsm8k_data": "data/sft/gsm8k_sft.jsonl",
5
+ "eval_data_path": "data/sft/gsm8k_test.jsonl",
6
+ "num_iterations": 60,
7
+ "group_size": 10,
8
+ "q_group_size": 2,
9
+ "questions_per_iter": 20,
10
+ "learning_rate": 5e-06,
11
+ "max_new_tokens": 1000,
12
+ "temperature": 0.8,
13
+ "eval_every": 5,
14
+ "eval_max_samples": 150,
15
+ "eval_max_new_tokens": 1000,
16
+ "eval_pass_at_k": 0,
17
+ "use_prm": true,
18
+ "prm_model": "Qwen/Qwen2.5-Math-PRM-7B",
19
+ "skip_initial_eval": false,
20
+ "run_name": "grpo_20260426_032827",
21
+ "max_grad_norm": 0.5,
22
+ "kl_coef": 0.06,
23
+ "math_data": null,
24
+ "math_mix_ratio": 0.3,
25
+ "math_mix_ratio_late": 0.5,
26
+ "math_ramp_start": 18,
27
+ "math_max_difficulty": 3,
28
+ "clip_eps": 0.2,
29
+ "warmup_iters": 8,
30
+ "min_lr_ratio": 0.1,
31
+ "difficulty_alpha": 3.5,
32
+ "overlong_filter": true,
33
+ "save_every": 5,
34
+ "keep_last": 4,
35
+ "self_play_ratio": 0.7,
36
+ "min_warmup": 12,
37
+ "selfplay_gt_thresh": 0.65,
38
+ "selfplay_grounded_thresh": 0.65,
39
+ "selfplay_step_thresh": 0.68,
40
+ "selfplay_ramp_iters": 28,
41
+ "grounded_floor": 0.55,
42
+ "extractor_model": "Qwen/Qwen2.5-0.5B-Instruct",
43
+ "extraction_cache": "data/extraction_cache.json"
44
+ }
logs/grpo/grpo_20260426_032827/console_output.log ADDED
The diff for this file is too large to render. See raw diff
 
logs/grpo/grpo_20260426_032827/metrics.csv ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ iteration,timestamp,loss,mean_reward,std_reward,batch_accuracy,grounded_acc,gt_match_rate,step_accuracy,lccp,n_groups,skipped_groups,n_sp_groups,sp_ratio,sp_suspended,training_phase,learning_rate,iter_time_s,q_reward,q_valid_rate,q_novelty,q_solvability,chain_prm_corr,chain_scoring_on,eval_combined,eval_correct_rt,eval_prm,eval_step_acc,eval_lccp,eval_format,eval_n_scored,eval_final_ans
2
+ 1,2026-04-26T03:38:38,0.000610,0.914309,0.163605,0.960000,0.960000,0.780000,0.894861,0.814111,12,8,0,0.000000,0,GROUNDED_ONLY,0.000001,127.637996,0.000000,0.000000,0.000000,0.000000,0.000000,0,,,,,,,,
3
+ 2,2026-04-26T03:41:58,-0.000034,0.847892,0.216018,0.914141,0.914141,0.651500,0.866692,0.765381,18,2,0,0.000000,0,GROUNDED_ONLY,0.000002,199.518393,0.000000,0.000000,0.000000,0.000000,0.000000,0,,,,,,,,
4
+ 3,2026-04-26T03:45:08,0.000366,0.896391,0.170699,0.954545,0.954545,0.707100,0.876898,0.765238,12,8,0,0.000000,0,GROUNDED_ONLY,0.000002,189.836063,0.000000,0.000000,0.000000,0.000000,0.000000,0,,,,,,,,
5
+ 4,2026-04-26T03:48:10,0.000942,0.865431,0.218756,0.893939,0.893939,0.732300,0.858504,0.764982,11,9,0,0.000000,0,GROUNDED_ONLY,0.000003,182.125475,0.000000,0.000000,0.000000,0.000000,0.000000,0,,,,,,,,
6
+ 5,2026-04-26T03:59:39,0.000081,0.856875,0.239487,0.884422,0.884422,0.693500,0.918500,0.843100,16,4,0,0.000000,0,GROUNDED_ONLY,0.000003,201.679190,0.000000,0.000000,0.000000,0.000000,0.000000,0,0.919200,0.793300,0.903500,0.918500,0.843100,0.997700,150,0.793333
7
+ 6,2026-04-26T04:02:52,-0.000063,0.879253,0.215318,0.909548,0.909548,0.748700,0.884646,0.805897,12,8,0,0.000000,0,GROUNDED_ONLY,0.000004,193.350312,0.000000,0.000000,0.000000,0.000000,0.000000,0,,,,,,,,
8
+ 7,2026-04-26T04:06:20,0.001071,0.837888,0.223356,0.883249,0.883249,0.639600,0.813073,0.658069,14,6,0,0.000000,0,GROUNDED_ONLY,0.000004,208.223944,0.000000,0.000000,0.000000,0.000000,0.000000,0,,,,,,,,
9
+ 8,2026-04-26T04:09:11,-0.000257,0.875536,0.200109,0.895000,0.895000,0.690000,0.864722,0.747928,13,7,0,0.000000,0,GROUNDED_ONLY,0.000005,170.595953,0.000000,0.000000,0.000000,0.000000,0.000000,0,,,,,,,,
10
+ 9,2026-04-26T04:12:52,0.000060,0.906506,0.176914,0.964646,0.964646,0.803000,0.893573,0.817532,15,5,0,0.000000,0,GROUNDED_ONLY,0.000005,221.350669,0.000000,0.000000,0.000000,0.000000,0.000000,0,,,,,,,,
11
+ 10,2026-04-26T04:24:49,0.000425,0.880765,0.175501,0.954774,0.954774,0.683400,0.920500,0.842600,14,6,0,0.000000,0,GROUNDED_ONLY,0.000005,188.981772,0.000000,0.000000,0.000000,0.000000,0.000000,0,0.919900,0.793300,0.906600,0.920500,0.842600,0.998000,150,0.793333
12
+ 11,2026-04-26T04:27:11,-0.000557,0.969814,0.098322,0.985000,0.985000,0.930000,0.966268,0.921810,8,12,0,0.000000,0,GROUNDED_ONLY,0.000005,141.966778,0.000000,0.000000,0.000000,0.000000,0.000000,0,,,,,,,,
13
+ 12,2026-04-26T04:30:09,0.000073,0.849274,0.212864,0.900000,0.900000,0.650000,0.820526,0.687272,14,6,0,0.000000,0,SELFPLAY_RAMP,0.000005,177.954757,0.000000,0.000000,0.000000,0.000000,0.000000,0,,,,,,,,
14
+ 13,2026-04-26T04:39:26,0.000268,0.898824,0.185992,0.930000,0.930000,0.780000,0.870960,0.788730,14,6,0,0.000000,0,SELFPLAY_RAMP,0.000005,556.185637,0.000000,0.000000,0.000000,0.000000,-0.040000,0,,,,,,,,
15
+ 14,2026-04-26T04:48:54,0.000496,0.855832,0.208499,0.952381,0.947368,0.673700,0.857607,0.747807,18,3,1,0.036000,0,SELFPLAY_RAMP,0.000005,568.400518,0.763000,1.000000,0.428900,1.000000,0.209000,0,,,,,,,,
16
+ 15,2026-04-26T05:06:28,0.000023,0.927972,0.167187,0.937799,0.931217,0.836000,0.924200,0.842400,12,9,1,0.071000,0,SELFPLAY_RAMP,0.000005,550.143772,0.721800,1.000000,0.458000,1.000000,0.079000,0,0.926200,0.800000,0.907200,0.924200,0.842400,1.000000,150,0.800000
17
+ 16,2026-04-26T05:16:04,0.000330,0.914605,0.172733,0.949772,0.938547,0.832400,0.895523,0.843899,15,7,2,0.107000,0,SELFPLAY_RAMP,0.000005,575.528946,0.787800,1.000000,0.447500,0.960000,0.089000,0,,,,,,,,
18
+ 17,2026-04-26T05:26:20,-0.000137,0.888123,0.195006,0.938326,0.916168,0.700600,0.855796,0.768235,20,3,3,0.143000,0,SELFPLAY_RAMP,0.000005,616.018573,0.798200,1.000000,0.461600,1.000000,-0.191000,0,,,,,,,,
19
+ 18,2026-04-26T05:35:30,0.000079,0.866401,0.178010,0.953975,0.943396,0.591200,0.830780,0.692011,19,5,4,0.179000,0,SELFPLAY_RAMP,0.000005,550.572628,0.739400,1.000000,0.452000,0.976200,0.021000,0,,,,,,,,
20
+ 19,2026-04-26T05:44:13,0.000151,0.891281,0.172665,0.953586,0.949045,0.764300,0.851398,0.756874,16,8,4,0.214000,0,SELFPLAY_RAMP,0.000005,522.428960,0.733100,1.000000,0.456400,0.972500,0.075000,0,,,,,,,,
21
+ 20,2026-04-26T06:02:54,0.000244,0.896291,0.177842,0.927711,0.906040,0.798700,0.925300,0.842800,18,7,5,0.250000,0,SELFPLAY_RAMP,0.000004,619.886349,0.770000,1.000000,0.474100,0.945000,-0.118000,0,0.923400,0.800000,0.905600,0.925300,0.842800,1.000000,150,0.800000
22
+ 21,2026-04-26T06:11:04,0.000192,0.841732,0.187981,0.923077,0.914286,0.735700,0.819504,0.693061,21,5,6,0.286000,0,SELFPLAY_RAMP,0.000004,490.366938,0.697200,1.000000,0.449300,0.962500,0.209000,0,,,,,,,,
23
+ 22,2026-04-26T06:21:16,0.000579,0.917519,0.124242,0.984314,0.985294,0.904400,0.964735,0.928489,20,6,6,0.321000,0,SELFPLAY_RAMP,0.000004,611.872286,0.699800,1.000000,0.457100,0.979000,0.145000,0,,,,,,,,
24
+ 23,2026-04-26T06:28:41,0.000614,0.920698,0.147419,0.977011,0.950820,0.803300,0.907500,0.847631,18,9,7,0.357000,0,SELFPLAY_RAMP,0.000004,444.320885,0.726000,1.000000,0.441200,0.988500,0.143000,0,,,,,,,,
25
+ 24,2026-04-26T06:36:32,-0.000213,0.879590,0.173313,0.935714,0.933333,0.791700,0.898819,0.812292,20,8,8,0.393000,0,SELFPLAY_RAMP,0.000004,471.698962,0.662100,1.000000,0.440800,0.968800,0.082000,0,,,,,,,,
26
+ 25,2026-04-26T06:53:36,0.000344,0.844528,0.208658,0.927336,0.853211,0.605500,0.919800,0.846800,28,1,9,0.429000,0,SELFPLAY_RAMP,0.000004,524.655717,0.647100,1.000000,0.439400,0.967200,0.127000,0,0.922100,0.793300,0.903400,0.919800,0.846800,1.000000,150,0.793333
27
+ 26,2026-04-26T07:02:06,0.000421,0.866649,0.179636,0.920415,0.926606,0.789000,0.889846,0.794302,26,3,9,0.464000,0,SELFPLAY_RAMP,0.000004,509.677450,0.679200,1.000000,0.448800,0.931700,0.065000,0,,,,,,,,
28
+ 27,2026-04-26T07:12:03,-0.000227,0.877934,0.162866,0.956376,0.939394,0.686900,0.861628,0.740657,25,5,10,0.500000,0,SELFPLAY_RAMP,0.000004,597.521238,0.683100,1.000000,0.458400,0.975900,0.067000,0,,,,,,,,
29
+ 28,2026-04-26T07:22:06,0.000042,0.869600,0.159154,0.941935,0.877778,0.655600,0.833443,0.618623,29,2,11,0.536000,0,SELFPLAY_RAMP,0.000004,603.099793,0.669300,1.000000,0.448900,0.983600,0.047000,0,,,,,,,,
30
+ 29,2026-04-26T07:31:46,0.000377,0.867441,0.170826,0.947020,0.892857,0.726200,0.867407,0.760394,28,3,11,0.571000,0,SELFPLAY_RAMP,0.000003,579.690467,0.649600,1.000000,0.442500,0.973900,0.123000,0,,,,,,,,
31
+ 30,2026-04-26T07:48:26,-0.000299,0.870581,0.160260,0.965517,0.950000,0.800000,0.923200,0.850000,27,5,12,0.607000,0,SELFPLAY_RAMP,0.000003,503.087982,0.676400,1.000000,0.456600,0.969900,0.099000,0,0.920400,0.793300,0.904400,0.923200,0.850000,1.000000,150,0.793333
logs/metrics.jsonl ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"iteration": 0, "accuracy": 0.9162, "combined_score": 0.9162, "step_accuracy": 0.9111, "lccp": 0.8392, "correct_rate": 0.7867, "prm_mean": 0.8988, "prm_final": 0.9275, "sympy_mean": 0.0, "format_mean": 1.0, "n_scored": 150, "total": 150, "final_answer_correct": 118, "final_answer_accuracy": 0.7866666666666666}
2
+ {"iteration": 1, "loss": 0.0006103356778718686, "mean_reward": 0.914308755129325, "std_reward": 0.1636050993381563, "batch_accuracy": 0.96, "grounded_accuracy": 0.96, "gt_match_rate": 0.78, "step_accuracy": 0.8948611111111111, "lccp": 0.8141111111111111, "n_groups": 12, "skipped_groups": 8, "learning_rate": 1.0625000000000002e-06, "iter_time_s": 127.63799649500288, "training_phase": "GROUNDED_ONLY", "effective_sp_ratio": 0.0, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.0, "extraction_success_rate": 0.0, "chain_scoring_active": 0, "n_self_play_groups": 0, "q_gen_attempts": 0, "q_gen_valid": 0, "q_gen_valid_rate": 0.0, "mean_question_reward": 0.0, "q_quality_rate": 0.0, "q_topic_match": 0.0, "q_difficulty_fit": 0.0, "q_clarity": 0.0, "q_novelty": 0.0, "q_solvability": 0.0}
3
+ {"iteration": 2, "loss": -3.432962815471304e-05, "mean_reward": 0.8478923191518654, "std_reward": 0.2160182166583165, "batch_accuracy": 0.9141414141414141, "grounded_accuracy": 0.9141414141414141, "gt_match_rate": 0.6515, "step_accuracy": 0.8666916416916417, "lccp": 0.7653809153809155, "n_groups": 18, "skipped_groups": 2, "learning_rate": 1.6250000000000001e-06, "iter_time_s": 199.5183933188673, "training_phase": "GROUNDED_ONLY", "effective_sp_ratio": 0.0, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.0, "extraction_success_rate": 0.0, "chain_scoring_active": 0, "n_self_play_groups": 0, "q_gen_attempts": 0, "q_gen_valid": 0, "q_gen_valid_rate": 0.0, "mean_question_reward": 0.0, "q_quality_rate": 0.0, "q_topic_match": 0.0, "q_difficulty_fit": 0.0, "q_clarity": 0.0, "q_novelty": 0.0, "q_solvability": 0.0}
4
+ {"iteration": 3, "loss": 0.0003658987698145211, "mean_reward": 0.8963912433066207, "std_reward": 0.17069859725714537, "batch_accuracy": 0.9545454545454546, "grounded_accuracy": 0.9545454545454546, "gt_match_rate": 0.7071, "step_accuracy": 0.876897947731281, "lccp": 0.765237694404361, "n_groups": 12, "skipped_groups": 8, "learning_rate": 2.1875000000000002e-06, "iter_time_s": 189.83606291818433, "training_phase": "GROUNDED_ONLY", "effective_sp_ratio": 0.0, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.0, "extraction_success_rate": 0.0, "chain_scoring_active": 0, "n_self_play_groups": 0, "q_gen_attempts": 0, "q_gen_valid": 0, "q_gen_valid_rate": 0.0, "mean_question_reward": 0.0, "q_quality_rate": 0.0, "q_topic_match": 0.0, "q_difficulty_fit": 0.0, "q_clarity": 0.0, "q_novelty": 0.0, "q_solvability": 0.0}
5
+ {"iteration": 4, "loss": 0.0009415318305731158, "mean_reward": 0.8654313890820613, "std_reward": 0.21875612713334075, "batch_accuracy": 0.8939393939393939, "grounded_accuracy": 0.8939393939393939, "gt_match_rate": 0.7323, "step_accuracy": 0.8585036876703543, "lccp": 0.7649821628988295, "n_groups": 11, "skipped_groups": 9, "learning_rate": 2.7500000000000004e-06, "iter_time_s": 182.12547484994866, "training_phase": "GROUNDED_ONLY", "effective_sp_ratio": 0.0, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.0, "extraction_success_rate": 0.0, "chain_scoring_active": 0, "n_self_play_groups": 0, "q_gen_attempts": 0, "q_gen_valid": 0, "q_gen_valid_rate": 0.0, "mean_question_reward": 0.0, "q_quality_rate": 0.0, "q_topic_match": 0.0, "q_difficulty_fit": 0.0, "q_clarity": 0.0, "q_novelty": 0.0, "q_solvability": 0.0}
6
+ {"iteration": 5, "loss": 8.118284122815567e-05, "mean_reward": 0.8568747993989829, "std_reward": 0.23948718740823036, "batch_accuracy": 0.8844221105527639, "grounded_accuracy": 0.8844221105527639, "gt_match_rate": 0.6935, "step_accuracy": 0.9185, "lccp": 0.8431, "n_groups": 16, "skipped_groups": 4, "learning_rate": 3.3125000000000005e-06, "iter_time_s": 201.67919013393112, "training_phase": "GROUNDED_ONLY", "effective_sp_ratio": 0.0, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.0, "extraction_success_rate": 0.0, "chain_scoring_active": 0, "n_self_play_groups": 0, "q_gen_attempts": 0, "q_gen_valid": 0, "q_gen_valid_rate": 0.0, "mean_question_reward": 0.0, "q_quality_rate": 0.0, "q_topic_match": 0.0, "q_difficulty_fit": 0.0, "q_clarity": 0.0, "q_novelty": 0.0, "q_solvability": 0.0, "accuracy": 0.9192, "combined_score": 0.9192, "correct_rate": 0.7933, "prm_mean": 0.9035, "prm_final": 0.9305, "sympy_mean": 0.0, "format_mean": 0.9977, "n_scored": 150, "total": 150, "final_answer_correct": 119, "final_answer_accuracy": 0.7933333333333333}
7
+ {"iteration": 6, "loss": -6.271734067316477e-05, "mean_reward": 0.8792530329566163, "std_reward": 0.21531797453446344, "batch_accuracy": 0.9095477386934674, "grounded_accuracy": 0.9095477386934674, "gt_match_rate": 0.7487, "step_accuracy": 0.8846455219822055, "lccp": 0.8058971263242619, "n_groups": 12, "skipped_groups": 8, "learning_rate": 3.875e-06, "iter_time_s": 193.35031225602143, "training_phase": "GROUNDED_ONLY", "effective_sp_ratio": 0.0, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.0, "extraction_success_rate": 0.0, "chain_scoring_active": 0, "n_self_play_groups": 0, "q_gen_attempts": 0, "q_gen_valid": 0, "q_gen_valid_rate": 0.0, "mean_question_reward": 0.0, "q_quality_rate": 0.0, "q_topic_match": 0.0, "q_difficulty_fit": 0.0, "q_clarity": 0.0, "q_novelty": 0.0, "q_solvability": 0.0}
8
+ {"iteration": 7, "loss": 0.0010708057920315436, "mean_reward": 0.8378877251545859, "std_reward": 0.2233563664223874, "batch_accuracy": 0.883248730964467, "grounded_accuracy": 0.883248730964467, "gt_match_rate": 0.6396, "step_accuracy": 0.8130725309659319, "lccp": 0.6580686304671076, "n_groups": 14, "skipped_groups": 6, "learning_rate": 4.4375e-06, "iter_time_s": 208.22394350194372, "training_phase": "GROUNDED_ONLY", "effective_sp_ratio": 0.0, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.0, "extraction_success_rate": 0.0, "chain_scoring_active": 0, "n_self_play_groups": 0, "q_gen_attempts": 0, "q_gen_valid": 0, "q_gen_valid_rate": 0.0, "mean_question_reward": 0.0, "q_quality_rate": 0.0, "q_topic_match": 0.0, "q_difficulty_fit": 0.0, "q_clarity": 0.0, "q_novelty": 0.0, "q_solvability": 0.0}
9
+ {"iteration": 8, "loss": -0.0002566667799678376, "mean_reward": 0.8755362041151912, "std_reward": 0.20010863742401203, "batch_accuracy": 0.895, "grounded_accuracy": 0.895, "gt_match_rate": 0.69, "step_accuracy": 0.8647215007215007, "lccp": 0.7479280303030303, "n_groups": 13, "skipped_groups": 7, "learning_rate": 5e-06, "iter_time_s": 170.59595341305248, "training_phase": "GROUNDED_ONLY", "effective_sp_ratio": 0.0, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.0, "extraction_success_rate": 0.0, "chain_scoring_active": 0, "n_self_play_groups": 0, "q_gen_attempts": 0, "q_gen_valid": 0, "q_gen_valid_rate": 0.0, "mean_question_reward": 0.0, "q_quality_rate": 0.0, "q_topic_match": 0.0, "q_difficulty_fit": 0.0, "q_clarity": 0.0, "q_novelty": 0.0, "q_solvability": 0.0}
10
+ {"iteration": 9, "loss": 5.9516330460004004e-05, "mean_reward": 0.906506146327221, "std_reward": 0.1769136401553803, "batch_accuracy": 0.9646464646464646, "grounded_accuracy": 0.9646464646464646, "gt_match_rate": 0.803, "step_accuracy": 0.8935726310726311, "lccp": 0.8175324675324676, "n_groups": 15, "skipped_groups": 5, "learning_rate": 4.995894997002465e-06, "iter_time_s": 221.35066892812029, "training_phase": "GROUNDED_ONLY", "effective_sp_ratio": 0.0, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.0, "extraction_success_rate": 0.0, "chain_scoring_active": 0, "n_self_play_groups": 0, "q_gen_attempts": 0, "q_gen_valid": 0, "q_gen_valid_rate": 0.0, "mean_question_reward": 0.0, "q_quality_rate": 0.0, "q_topic_match": 0.0, "q_difficulty_fit": 0.0, "q_clarity": 0.0, "q_novelty": 0.0, "q_solvability": 0.0}
11
+ {"iteration": 10, "loss": 0.0004252615440886335, "mean_reward": 0.8807654454859567, "std_reward": 0.17550108931309533, "batch_accuracy": 0.9547738693467337, "grounded_accuracy": 0.9547738693467337, "gt_match_rate": 0.6834, "step_accuracy": 0.9205, "lccp": 0.8426, "n_groups": 14, "skipped_groups": 6, "learning_rate": 4.983594966720622e-06, "iter_time_s": 188.98177218902856, "training_phase": "GROUNDED_ONLY", "effective_sp_ratio": 0.0, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.0, "extraction_success_rate": 0.0, "chain_scoring_active": 0, "n_self_play_groups": 0, "q_gen_attempts": 0, "q_gen_valid": 0, "q_gen_valid_rate": 0.0, "mean_question_reward": 0.0, "q_quality_rate": 0.0, "q_topic_match": 0.0, "q_difficulty_fit": 0.0, "q_clarity": 0.0, "q_novelty": 0.0, "q_solvability": 0.0, "accuracy": 0.9199, "combined_score": 0.9199, "correct_rate": 0.7933, "prm_mean": 0.9066, "prm_final": 0.9408, "sympy_mean": 0.0, "format_mean": 0.998, "n_scored": 150, "total": 150, "final_answer_correct": 119, "final_answer_accuracy": 0.7933333333333333}
12
+ {"iteration": 11, "loss": -0.0005566358695432427, "mean_reward": 0.9698135460130081, "std_reward": 0.0983216960471261, "batch_accuracy": 0.985, "grounded_accuracy": 0.985, "gt_match_rate": 0.93, "step_accuracy": 0.9662678571428571, "lccp": 0.9218095238095237, "n_groups": 8, "skipped_groups": 12, "learning_rate": 4.963144790631074e-06, "iter_time_s": 141.96677790791728, "training_phase": "GROUNDED_ONLY", "effective_sp_ratio": 0.0, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.0, "extraction_success_rate": 0.0, "chain_scoring_active": 0, "n_self_play_groups": 0, "q_gen_attempts": 0, "q_gen_valid": 0, "q_gen_valid_rate": 0.0, "mean_question_reward": 0.0, "q_quality_rate": 0.0, "q_topic_match": 0.0, "q_difficulty_fit": 0.0, "q_clarity": 0.0, "q_novelty": 0.0, "q_solvability": 0.0}
13
+ {"iteration": 12, "loss": 7.270745637859883e-05, "mean_reward": 0.8492740230597824, "std_reward": 0.2128636238290247, "batch_accuracy": 0.9, "grounded_accuracy": 0.9, "gt_match_rate": 0.65, "step_accuracy": 0.8205257936507937, "lccp": 0.6872718253968253, "n_groups": 14, "skipped_groups": 6, "learning_rate": 4.934619089208618e-06, "iter_time_s": 177.9547567779664, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.0, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.0, "extraction_success_rate": 0.0, "chain_scoring_active": 0, "n_self_play_groups": 0, "q_gen_attempts": 0, "q_gen_valid": 0, "q_gen_valid_rate": 0.0, "mean_question_reward": 0.0, "q_quality_rate": 0.0, "q_topic_match": 0.0, "q_difficulty_fit": 0.0, "q_clarity": 0.0, "q_novelty": 0.0, "q_solvability": 0.0}
14
+ {"iteration": 13, "loss": 0.00026773045517204864, "mean_reward": 0.8988236995312778, "std_reward": 0.18599151493605476, "batch_accuracy": 0.93, "grounded_accuracy": 0.93, "gt_match_rate": 0.78, "step_accuracy": 0.8709603174603174, "lccp": 0.7887301587301587, "n_groups": 14, "skipped_groups": 6, "learning_rate": 4.898121949644228e-06, "iter_time_s": 556.1856374200433, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.0, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": -0.04, "extraction_success_rate": 1.0, "chain_scoring_active": 0, "n_self_play_groups": 0, "q_gen_attempts": 0, "q_gen_valid": 0, "q_gen_valid_rate": 0.0, "mean_question_reward": 0.0, "q_quality_rate": 0.0, "q_topic_match": 0.0, "q_difficulty_fit": 0.0, "q_clarity": 0.0, "q_novelty": 0.0, "q_solvability": 0.0}
15
+ {"iteration": 14, "loss": 0.0004961729192069066, "mean_reward": 0.8558324048863098, "std_reward": 0.20849902292009304, "batch_accuracy": 0.9523809523809523, "grounded_accuracy": 0.9473684210526315, "gt_match_rate": 0.6737, "step_accuracy": 0.8576065162907268, "lccp": 0.7478070175438597, "n_groups": 18, "skipped_groups": 3, "learning_rate": 4.853786546042184e-06, "iter_time_s": 568.4005180909298, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.036, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.209, "extraction_success_rate": 1.0, "chain_scoring_active": 0, "n_self_play_groups": 1, "q_gen_attempts": 1, "q_gen_valid": 1, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.763, "q_quality_rate": 1.0, "q_topic_match": 0.575, "q_difficulty_fit": 0.89, "q_clarity": 1.0, "q_novelty": 0.4289, "q_solvability": 1.0}
16
+ {"iteration": 15, "loss": 2.3262581635208335e-05, "mean_reward": 0.927972135586315, "std_reward": 0.16718736928397065, "batch_accuracy": 0.937799043062201, "grounded_accuracy": 0.9312169312169312, "gt_match_rate": 0.836, "step_accuracy": 0.9242, "lccp": 0.8424, "n_groups": 12, "skipped_groups": 9, "learning_rate": 4.801774653482204e-06, "iter_time_s": 550.1437717408407, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.071, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.079, "extraction_success_rate": 0.98, "chain_scoring_active": 0, "n_self_play_groups": 1, "q_gen_attempts": 1, "q_gen_valid": 1, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.7218, "q_quality_rate": 1.0, "q_topic_match": 0.35, "q_difficulty_fit": 0.9511, "q_clarity": 1.0, "q_novelty": 0.458, "q_solvability": 1.0, "accuracy": 0.9262, "combined_score": 0.9262, "correct_rate": 0.8, "prm_mean": 0.9072, "prm_final": 0.9404, "sympy_mean": 0.0, "format_mean": 1.0, "n_scored": 150, "total": 150, "final_answer_correct": 120, "final_answer_accuracy": 0.8}
17
+ {"iteration": 16, "loss": 0.0003296181123005226, "mean_reward": 0.9146047620088099, "std_reward": 0.17273258044260062, "batch_accuracy": 0.9497716894977168, "grounded_accuracy": 0.9385474860335196, "gt_match_rate": 0.8324, "step_accuracy": 0.8955234709424654, "lccp": 0.8438994897095455, "n_groups": 15, "skipped_groups": 7, "learning_rate": 4.742276057719723e-06, "iter_time_s": 575.5289459908381, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.107, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.089, "extraction_success_rate": 0.94, "chain_scoring_active": 0, "n_self_play_groups": 2, "q_gen_attempts": 2, "q_gen_valid": 2, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.7878, "q_quality_rate": 1.0, "q_topic_match": 0.875, "q_difficulty_fit": 0.5838, "q_clarity": 1.0, "q_novelty": 0.4475, "q_solvability": 0.96}
18
+ {"iteration": 17, "loss": -0.00013719029248022708, "mean_reward": 0.8881227328092163, "std_reward": 0.1950058307020988, "batch_accuracy": 0.9383259911894273, "grounded_accuracy": 0.9161676646706587, "gt_match_rate": 0.7006, "step_accuracy": 0.8557955517536356, "lccp": 0.7682349586541203, "n_groups": 20, "skipped_groups": 3, "learning_rate": 4.675507862678258e-06, "iter_time_s": 616.0185732548125, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.143, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": -0.191, "extraction_success_rate": 1.0, "chain_scoring_active": 0, "n_self_play_groups": 3, "q_gen_attempts": 3, "q_gen_valid": 3, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.7982, "q_quality_rate": 1.0, "q_topic_match": 0.69, "q_difficulty_fit": 0.8892, "q_clarity": 1.0, "q_novelty": 0.4616, "q_solvability": 1.0}
19
+ {"iteration": 18, "loss": 7.917114673641903e-05, "mean_reward": 0.8664005137011263, "std_reward": 0.178010205898339, "batch_accuracy": 0.9539748953974896, "grounded_accuracy": 0.9433962264150944, "gt_match_rate": 0.5912, "step_accuracy": 0.830780173704702, "lccp": 0.6920110811620246, "n_groups": 19, "skipped_groups": 5, "learning_rate": 4.601713698260728e-06, "iter_time_s": 550.572628196096, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.179, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.021, "extraction_success_rate": 0.98, "chain_scoring_active": 0, "n_self_play_groups": 4, "q_gen_attempts": 4, "q_gen_valid": 4, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.7394, "q_quality_rate": 1.0, "q_topic_match": 0.6375, "q_difficulty_fit": 0.6293, "q_clarity": 1.0, "q_novelty": 0.452, "q_solvability": 0.9762}
20
+ {"iteration": 19, "loss": 0.00015087392284840462, "mean_reward": 0.8912812767256229, "std_reward": 0.1726645221785555, "batch_accuracy": 0.9535864978902954, "grounded_accuracy": 0.9490445859872612, "gt_match_rate": 0.7643, "step_accuracy": 0.8513975055376328, "lccp": 0.7568744772566428, "n_groups": 16, "skipped_groups": 8, "learning_rate": 4.521162831370364e-06, "iter_time_s": 522.4289600129705, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.214, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.075, "extraction_success_rate": 0.98, "chain_scoring_active": 0, "n_self_play_groups": 4, "q_gen_attempts": 4, "q_gen_valid": 4, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.7331, "q_quality_rate": 1.0, "q_topic_match": 0.4813, "q_difficulty_fit": 0.8466, "q_clarity": 1.0, "q_novelty": 0.4564, "q_solvability": 0.9725}
21
+ {"iteration": 20, "loss": 0.00024373266084391312, "mean_reward": 0.8962914079724992, "std_reward": 0.1778417367801085, "batch_accuracy": 0.927710843373494, "grounded_accuracy": 0.9060402684563759, "gt_match_rate": 0.7987, "step_accuracy": 0.9253, "lccp": 0.8428, "n_groups": 18, "skipped_groups": 7, "learning_rate": 4.434149183384978e-06, "iter_time_s": 619.8863487117924, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.25, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": -0.118, "extraction_success_rate": 0.96, "chain_scoring_active": 0, "n_self_play_groups": 5, "q_gen_attempts": 5, "q_gen_valid": 5, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.77, "q_quality_rate": 1.0, "q_topic_match": 0.723, "q_difficulty_fit": 0.703, "q_clarity": 1.0, "q_novelty": 0.4741, "q_solvability": 0.945, "accuracy": 0.9234, "combined_score": 0.9234, "correct_rate": 0.8, "prm_mean": 0.9056, "prm_final": 0.9353, "sympy_mean": 0.0, "format_mean": 1.0, "n_scored": 150, "total": 150, "final_answer_correct": 120, "final_answer_accuracy": 0.8}
22
+ {"iteration": 21, "loss": 0.0001916794737033862, "mean_reward": 0.8417323480901788, "std_reward": 0.1879809468583581, "batch_accuracy": 0.9230769230769231, "grounded_accuracy": 0.9142857142857143, "gt_match_rate": 0.7357, "step_accuracy": 0.8195039682539682, "lccp": 0.6930612244897959, "n_groups": 21, "skipped_groups": 5, "learning_rate": 4.340990257669732e-06, "iter_time_s": 490.36693838005885, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.286, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.209, "extraction_success_rate": 0.92, "chain_scoring_active": 0, "n_self_play_groups": 6, "q_gen_attempts": 6, "q_gen_valid": 6, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.6972, "q_quality_rate": 1.0, "q_topic_match": 0.5742, "q_difficulty_fit": 0.4754, "q_clarity": 1.0, "q_novelty": 0.4493, "q_solvability": 0.9625}
23
+ {"iteration": 22, "loss": 0.000578732604299148, "mean_reward": 0.9175190043251262, "std_reward": 0.12424225720214971, "batch_accuracy": 0.984313725490196, "grounded_accuracy": 0.9852941176470589, "gt_match_rate": 0.9044, "step_accuracy": 0.9647345301757068, "lccp": 0.9284886681945506, "n_groups": 20, "skipped_groups": 6, "learning_rate": 4.2420259810417895e-06, "iter_time_s": 611.8722857821267, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.321, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.145, "extraction_success_rate": 0.92, "chain_scoring_active": 0, "n_self_play_groups": 6, "q_gen_attempts": 6, "q_gen_valid": 6, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.6998, "q_quality_rate": 1.0, "q_topic_match": 0.6189, "q_difficulty_fit": 0.3856, "q_clarity": 1.0, "q_novelty": 0.4571, "q_solvability": 0.979}
24
+ {"iteration": 23, "loss": 0.0006137362383419208, "mean_reward": 0.9206978778568132, "std_reward": 0.14741914089456262, "batch_accuracy": 0.9770114942528736, "grounded_accuracy": 0.9508196721311475, "gt_match_rate": 0.8033, "step_accuracy": 0.9075003548364204, "lccp": 0.847631466893762, "n_groups": 18, "skipped_groups": 9, "learning_rate": 4.137617463414222e-06, "iter_time_s": 444.32088500098325, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.357, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.143, "extraction_success_rate": 1.0, "chain_scoring_active": 0, "n_self_play_groups": 7, "q_gen_attempts": 7, "q_gen_valid": 7, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.726, "q_quality_rate": 1.0, "q_topic_match": 0.5621, "q_difficulty_fit": 0.6634, "q_clarity": 1.0, "q_novelty": 0.4412, "q_solvability": 0.9885}
25
+ {"iteration": 24, "loss": -0.00021296025724950595, "mean_reward": 0.8795895609748888, "std_reward": 0.1733128827089799, "batch_accuracy": 0.9357142857142857, "grounded_accuracy": 0.9333333333333333, "gt_match_rate": 0.7917, "step_accuracy": 0.8988194444444446, "lccp": 0.8122916666666666, "n_groups": 20, "skipped_groups": 8, "learning_rate": 4.0281456801451e-06, "iter_time_s": 471.6989622868132, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.393, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.082, "extraction_success_rate": 0.98, "chain_scoring_active": 0, "n_self_play_groups": 8, "q_gen_attempts": 8, "q_gen_valid": 8, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.6621, "q_quality_rate": 1.0, "q_topic_match": 0.5344, "q_difficulty_fit": 0.3108, "q_clarity": 1.0, "q_novelty": 0.4408, "q_solvability": 0.9688}
26
+ {"iteration": 25, "loss": 0.0003441530472758002, "mean_reward": 0.8445275205076134, "std_reward": 0.20865777545087066, "batch_accuracy": 0.9273356401384083, "grounded_accuracy": 0.8532110091743119, "gt_match_rate": 0.6055, "step_accuracy": 0.9198, "lccp": 0.8468, "n_groups": 28, "skipped_groups": 1, "learning_rate": 3.9140100818997275e-06, "iter_time_s": 524.655717118876, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.429, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.127, "extraction_success_rate": 0.94, "chain_scoring_active": 0, "n_self_play_groups": 9, "q_gen_attempts": 9, "q_gen_valid": 9, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.6471, "q_quality_rate": 1.0, "q_topic_match": 0.505, "q_difficulty_fit": 0.2634, "q_clarity": 1.0, "q_novelty": 0.4394, "q_solvability": 0.9672, "accuracy": 0.9221, "combined_score": 0.9221, "correct_rate": 0.7933, "prm_mean": 0.9034, "prm_final": 0.9329, "sympy_mean": 0.0, "format_mean": 1.0, "n_scored": 150, "total": 150, "final_answer_correct": 119, "final_answer_accuracy": 0.7933333333333333}
27
+ {"iteration": 26, "loss": 0.0004209962865808428, "mean_reward": 0.8666489827432893, "std_reward": 0.1796360842988206, "batch_accuracy": 0.9204152249134948, "grounded_accuracy": 0.926605504587156, "gt_match_rate": 0.789, "step_accuracy": 0.8898463666812292, "lccp": 0.7943024610455803, "n_groups": 26, "skipped_groups": 3, "learning_rate": 3.795627137098479e-06, "iter_time_s": 509.6774504878558, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.464, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.065, "extraction_success_rate": 0.94, "chain_scoring_active": 0, "n_self_play_groups": 9, "q_gen_attempts": 9, "q_gen_valid": 9, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.6792, "q_quality_rate": 1.0, "q_topic_match": 0.6639, "q_difficulty_fit": 0.2476, "q_clarity": 1.0, "q_novelty": 0.4488, "q_solvability": 0.9317}
28
+ {"iteration": 27, "loss": -0.00022697661013808103, "mean_reward": 0.877933982604161, "std_reward": 0.1628662024521015, "batch_accuracy": 0.9563758389261745, "grounded_accuracy": 0.9393939393939394, "gt_match_rate": 0.6869, "step_accuracy": 0.8616281866281865, "lccp": 0.7406565656565657, "n_groups": 25, "skipped_groups": 5, "learning_rate": 3.673428812268702e-06, "iter_time_s": 597.5212381640449, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.5, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.067, "extraction_success_rate": 0.92, "chain_scoring_active": 0, "n_self_play_groups": 10, "q_gen_attempts": 10, "q_gen_valid": 10, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.6831, "q_quality_rate": 1.0, "q_topic_match": 0.5699, "q_difficulty_fit": 0.3583, "q_clarity": 1.0, "q_novelty": 0.4584, "q_solvability": 0.9759}
29
+ {"iteration": 28, "loss": 4.199455770111822e-05, "mean_reward": 0.8695997487614422, "std_reward": 0.15915376074701193, "batch_accuracy": 0.9419354838709677, "grounded_accuracy": 0.8777777777777778, "gt_match_rate": 0.6556, "step_accuracy": 0.8334434828062279, "lccp": 0.6186230200445887, "n_groups": 29, "skipped_groups": 2, "learning_rate": 3.5478609958457035e-06, "iter_time_s": 603.0997926741838, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.536, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.047, "extraction_success_rate": 0.8, "chain_scoring_active": 0, "n_self_play_groups": 11, "q_gen_attempts": 11, "q_gen_valid": 11, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.6693, "q_quality_rate": 1.0, "q_topic_match": 0.5931, "q_difficulty_fit": 0.23, "q_clarity": 1.0, "q_novelty": 0.4489, "q_solvability": 0.9836}
30
+ {"iteration": 29, "loss": 0.0003765096731578004, "mean_reward": 0.8674408392873937, "std_reward": 0.17082623284979875, "batch_accuracy": 0.9470198675496688, "grounded_accuracy": 0.8928571428571429, "gt_match_rate": 0.7262, "step_accuracy": 0.8674065194639727, "lccp": 0.7603936306964257, "n_groups": 28, "skipped_groups": 3, "learning_rate": 3.419381871174205e-06, "iter_time_s": 579.6904674370307, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.571, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.123, "extraction_success_rate": 0.84, "chain_scoring_active": 0, "n_self_play_groups": 11, "q_gen_attempts": 11, "q_gen_valid": 11, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.6496, "q_quality_rate": 1.0, "q_topic_match": 0.5636, "q_difficulty_fit": 0.1695, "q_clarity": 1.0, "q_novelty": 0.4425, "q_solvability": 0.9739}
31
+ {"iteration": 30, "loss": -0.00029927124827130075, "mean_reward": 0.8705812118012987, "std_reward": 0.16025951815561293, "batch_accuracy": 0.9655172413793104, "grounded_accuracy": 0.95, "gt_match_rate": 0.8, "step_accuracy": 0.9232, "lccp": 0.85, "n_groups": 27, "skipped_groups": 5, "learning_rate": 3.2884602446470037e-06, "iter_time_s": 503.08798154001124, "training_phase": "SELFPLAY_RAMP", "effective_sp_ratio": 0.607, "selfplay_suspended": 0, "chain_arith_score": null, "chain_dep_score": null, "chain_integrity_score": null, "sp_chain_integrity_score": null, "chain_prm_correlation": 0.099, "extraction_success_rate": 0.92, "chain_scoring_active": 0, "n_self_play_groups": 12, "q_gen_attempts": 12, "q_gen_valid": 12, "q_gen_valid_rate": 1.0, "mean_question_reward": 0.6764, "q_quality_rate": 1.0, "q_topic_match": 0.6752, "q_difficulty_fit": 0.1485, "q_clarity": 1.0, "q_novelty": 0.4566, "q_solvability": 0.9699, "accuracy": 0.9204, "combined_score": 0.9204, "correct_rate": 0.7933, "prm_mean": 0.9044, "prm_final": 0.9289, "sympy_mean": 0.0, "format_mean": 1.0, "n_scored": 150, "total": 150, "final_answer_correct": 119, "final_answer_accuracy": 0.7933333333333333}
models.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Data models for the AxiomForgeAI math RL environment.
9
+
10
+ The AxiomForgeAI environment presents math questions drawn from an adaptive
11
+ curriculum; external agents submit step-by-step solutions and receive scored
12
+ observations. The environment integrates with the GRPO training pipeline
13
+ defined in scripts/run_grpo_training.py.
14
+ """
15
+
16
+ from openenv.core.env_server.types import Action, Observation
17
+ from pydantic import Field
18
+
19
+
20
+ class AxiomforgeaiAction(Action):
21
+ """Action for the AxiomForgeAI math environment.
22
+
23
+ The agent submits a step-by-step solution to the current question.
24
+ Solutions should follow the format::
25
+
26
+ Step 1: <reasoning>
27
+ Step 2: <reasoning>
28
+ ...
29
+ Final Answer: <numeric value>
30
+ """
31
+
32
+ solution: str = Field(
33
+ default="",
34
+ description=(
35
+ "Step-by-step solution to the current math question. "
36
+ "Use 'Step N: ...' lines and end with 'Final Answer: <value>'."
37
+ ),
38
+ )
39
+
40
+
41
+ class AxiomforgeaiObservation(Observation):
42
+ """Observation from the AxiomForgeAI math environment.
43
+
44
+ On reset the question is populated and reward/feedback are empty.
45
+ After a step the reward and feedback reflect the quality of the submitted
46
+ solution; done=True signals the end of the single-step episode.
47
+ """
48
+
49
+ question: str = Field(
50
+ default="",
51
+ description="Math question the agent must solve.",
52
+ )
53
+ topic: str = Field(
54
+ default="",
55
+ description="Mathematical topic of the question (e.g. 'algebra', 'geometry').",
56
+ )
57
+ difficulty: float = Field(
58
+ default=0.5,
59
+ description="Estimated difficulty of the question in [0, 1].",
60
+ )
61
+ feedback: str = Field(
62
+ default="",
63
+ description=(
64
+ "Human-readable feedback on the submitted solution "
65
+ "(empty on reset, populated after step)."
66
+ ),
67
+ )
openenv.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ spec_version: 1
2
+ name: AxiomForgeAI
3
+ type: space
4
+ runtime: fastapi
5
+ app: server.app:app
6
+ port: 8000
7
+
pyproject.toml ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ [build-system]
8
+ requires = ["setuptools>=45", "wheel"]
9
+ build-backend = "setuptools.build_meta"
10
+
11
+ [project]
12
+ name = "openenv-AxiomForgeAI"
13
+ version = "0.1.0"
14
+ description = "Axiomforgeai environment for OpenEnv"
15
+ requires-python = ">=3.10"
16
+ dependencies = [
17
+ # Core OpenEnv runtime (provides FastAPI server + HTTP client types)
18
+ # install from github
19
+ # "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
20
+ "openenv-core[core]>=0.2.2",
21
+ # Environment-specific dependencies
22
+ # Add all dependencies needed for your environment here
23
+ # Examples:
24
+ # "numpy>=1.19.0",
25
+ # "torch>=2.0.0",
26
+ # "gymnasium>=0.29.0",
27
+ # "openspiel>=1.0.0",
28
+ # "smolagents>=1.22.0,<2",
29
+ ]
30
+
31
+ [project.optional-dependencies]
32
+ dev = [
33
+ "pytest>=8.0.0",
34
+ "pytest-cov>=4.0.0",
35
+ ]
36
+
37
+ [project.scripts]
38
+ # Server entry point - enables running via: uv run --project . server
39
+ # or: python -m AxiomForgeAI.server.app
40
+ server = "AxiomForgeAI.server.app:main"
41
+
42
+ [tool.setuptools]
43
+ include-package-data = true
44
+ packages = [
45
+ "AxiomForgeAI",
46
+ "AxiomForgeAI.server",
47
+ "src",
48
+ "src.config",
49
+ "src.rl",
50
+ "src.sft",
51
+ "src.utils",
52
+ "src.self_play",
53
+ "scripts",
54
+ ]
55
+ package-dir = { "AxiomForgeAI" = ".", "AxiomForgeAI.server" = "server" }
requirements.txt ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==1.2.1
2
+ aiohappyeyeballs==2.6.1
3
+ aiohttp==3.13.5
4
+ aiohttp-cors==0.8.1
5
+ aiosignal==1.4.0
6
+ airportsdata==20260315
7
+ annotated-doc==0.0.4
8
+ annotated-types==0.7.0
9
+ anyio==4.13.0
10
+ astor==0.8.1
11
+ attrs==26.1.0
12
+ bitsandbytes==0.44.1
13
+ blake3==1.0.8
14
+ certifi==2026.4.22
15
+ cffi==2.0.0
16
+ charset-normalizer==3.4.7
17
+ click==8.3.2
18
+ cloudpickle==3.1.2
19
+ colorful==0.5.8
20
+ compressed-tensors==0.9.0
21
+ cryptography==46.0.7
22
+ datasets==3.2.0
23
+ depyf==0.18.0
24
+ dill==0.3.8
25
+ diskcache==5.6.3
26
+ distlib==0.4.0
27
+ distro==1.9.0
28
+ einops==0.8.2
29
+ fastapi==0.136.0
30
+ filelock==3.29.0
31
+ frozenlist==1.8.0
32
+ fsspec==2024.9.0
33
+ gguf==0.10.0
34
+ google-api-core==2.30.3
35
+ google-auth==2.49.2
36
+ googleapis-common-protos==1.74.0
37
+ grpcio==1.80.0
38
+ h11==0.16.0
39
+ hf-xet==1.4.3
40
+ hjson==3.1.0
41
+ httpcore==1.0.9
42
+ httptools==0.7.1
43
+ httpx==0.28.1
44
+ huggingface-hub==0.36.2
45
+ idna==3.12
46
+ importlib-metadata==9.0.0
47
+ interegular==0.3.3
48
+ jinja2==3.1.6
49
+ jiter==0.14.0
50
+ jsonschema==4.26.0
51
+ jsonschema-specifications==2025.9.1
52
+ lark==1.2.2
53
+ linkify-it-py==2.1.0
54
+ lm-format-enforcer==0.10.12
55
+ markdown-it-py==4.0.0
56
+ markupsafe==3.0.3
57
+ mdit-py-plugins==0.5.0
58
+ mdurl==0.1.2
59
+ memray==1.19.3
60
+ mistral-common==1.11.0
61
+ mpmath==1.3.0
62
+ msgpack==1.1.2
63
+ msgspec==0.21.1
64
+ multidict==6.7.1
65
+ multiprocess==0.70.16
66
+ nest-asyncio==1.6.0
67
+ networkx==3.6.1
68
+ ninja==1.13.0
69
+ numpy==1.26.4
70
+ nvidia-cublas-cu12==12.4.5.8
71
+ nvidia-cuda-cupti-cu12==12.4.127
72
+ nvidia-cuda-nvrtc-cu12==12.4.127
73
+ nvidia-cuda-runtime-cu12==12.4.127
74
+ nvidia-cudnn-cu12==9.1.0.70
75
+ nvidia-cufft-cu12==11.2.1.3
76
+ nvidia-curand-cu12==10.3.5.147
77
+ nvidia-cusolver-cu12==11.6.1.9
78
+ nvidia-cusparse-cu12==12.3.1.170
79
+ nvidia-ml-py==13.595.45
80
+ nvidia-nccl-cu12==2.21.5
81
+ nvidia-nvjitlink-cu12==12.4.127
82
+ nvidia-nvtx-cu12==12.4.127
83
+ openai==2.32.0
84
+ opencensus==0.11.4
85
+ opencensus-context==0.1.3
86
+ opencv-python-headless==4.11.0.86
87
+ outlines==0.1.11
88
+ outlines-core==0.1.26
89
+ packaging==26.1
90
+ pandas==3.0.2
91
+ partial-json-parser==0.2.1.1.post7
92
+ peft==0.19.1
93
+ pillow==12.2.0
94
+ platformdirs==4.9.6
95
+ prometheus-client==0.25.0
96
+ prometheus-fastapi-instrumentator==7.1.0
97
+ propcache==0.4.1
98
+ proto-plus==1.27.2
99
+ protobuf==7.34.1
100
+ psutil==7.2.2
101
+ py-cpuinfo==9.0.0
102
+ py-spy==0.4.1
103
+ pyarrow==24.0.0
104
+ pyasn1==0.6.3
105
+ pyasn1-modules==0.4.2
106
+ pycountry==26.2.16
107
+ pycparser==3.0
108
+ pydantic==2.13.3
109
+ pydantic-core==2.46.3
110
+ pydantic-extra-types==2.11.1
111
+ pygments==2.20.0
112
+ python-dateutil==2.9.0.post0
113
+ python-discovery==1.2.2
114
+ python-dotenv==1.2.2
115
+ pyyaml==6.0.3
116
+ pyzmq==27.1.0
117
+ ray==2.39.0
118
+ referencing==0.37.0
119
+ regex==2026.4.4
120
+ requests==2.33.1
121
+ rich==15.0.0
122
+ rpds-py==0.30.0
123
+ safetensors==0.7.0
124
+ scipy>=1.14.0
125
+ sentencepiece==0.2.1
126
+ setuptools==82.0.1
127
+ six==1.17.0
128
+ smart-open==7.6.0
129
+ sniffio==1.3.1
130
+ starlette==0.52.1
131
+ sympy==1.13.1
132
+ textual==8.2.4
133
+ tiktoken==0.12.0
134
+ tokenizers==0.20.3
135
+ torch==2.5.1
136
+ torchaudio==2.5.1
137
+ torchvision==0.20.1
138
+ tqdm==4.67.3
139
+ transformers==4.46.3
140
+ triton==3.1.0
141
+ trl==0.12.1
142
+ typing-extensions==4.15.0
143
+ typing-inspection==0.4.2
144
+ uc-micro-py==2.0.0
145
+ urllib3==2.6.3
146
+ uvicorn==0.45.0
147
+ uvloop==0.22.1
148
+ virtualenv==21.2.4
149
+ vllm==0.7.0
150
+ watchfiles==1.1.1
151
+ websockets==16.0
152
+ wrapt==2.1.2
153
+ xformers==0.0.28.post3
154
+ xgrammar==0.1.33
155
+ xxhash==3.6.0
156
+ yarl==1.23.0
157
+ zipp==3.23.1
158
+ matplotlib==3.10.9
159
+ flash-attn==2.8.3
160
+ gradio>=4.44.0
scripts/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Training and evaluation scripts for math reasoning models."""
scripts/convert_gsm8k_to_sft.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Convert OpenAI GSM8K to SFT JSONL aligned with MathAgent solver format:
4
+
5
+ Step 1: ...
6
+ Step 2: ...
7
+ ...
8
+ Final Answer: <integer>
9
+
10
+ Each record uses a chat messages list for Qwen-style fine-tuning.
11
+
12
+ Usage
13
+ -----
14
+ # From Hugging Face (default; same data as in test.ipynb)
15
+ python scripts/convert_gsm8k_to_sft.py \\
16
+ --output data/sft/gsm8k_sft.jsonl \\
17
+ --splits train test
18
+
19
+ # From a saved JSONL with columns \"question\" and \"answer\" (GSM8K schema)
20
+ python scripts/convert_gsm8k_to_sft.py \\
21
+ --source jsonl \\
22
+ --input path/to/file.jsonl \\
23
+ --output data/sft/gsm8k_sft.jsonl
24
+
25
+ Requires: pip install datasets (and datasets will pull pyarrow as needed)
26
+ """
27
+
28
+ from __future__ import annotations
29
+
30
+ import argparse
31
+ import json
32
+ import re
33
+ from pathlib import Path
34
+ from typing import Any, Iterator
35
+
36
+ # Keep in sync with src.agent.math_agent.SOLVER_SYSTEM_PROMPT
37
+ SOLVER_SYSTEM_PROMPT = (
38
+ "You are a step-by-step math solver. "
39
+ "Solve the given problem one step at a time. "
40
+ "Each step must be on its own line, starting with 'Step N:'. "
41
+ "End with a line starting with 'Final Answer:'. "
42
+ "Write every mathematical expression in Python/SymPy syntax "
43
+ "so it can be verified programmatically."
44
+ )
45
+
46
+ USER_WRAPPER = (
47
+ "Solve the following problem. Show your reasoning as numbered steps, "
48
+ "then give the final numeric answer on the last line.\n\nProblem:\n{question}"
49
+ )
50
+
51
+
52
+ def parse_gsm8k_answer(raw_answer: str) -> tuple[str, str]:
53
+ """
54
+ Split GSM8K 'answer' field into reasoning text and final integer string.
55
+
56
+ GSM8K ends solutions with a line like: #### 42
57
+ """
58
+ text = raw_answer.strip()
59
+ parts = re.split(r"\s*####\s*", text, maxsplit=1)
60
+ reasoning = parts[0].strip()
61
+ final = parts[1].strip() if len(parts) > 1 else ""
62
+ # Normalize final (sometimes extra whitespace or commas)
63
+ final = re.sub(r"[,\s]+", "", final)
64
+ final_match = re.search(r"-?\d+", final)
65
+ final_clean = final_match.group(0) if final_match else final
66
+ return reasoning, final_clean
67
+
68
+
69
+ def reasoning_to_step_lines(reasoning: str) -> list[str]:
70
+ """Turn reasoning into non-empty lines; each line becomes one Step N:."""
71
+ lines: list[str] = []
72
+ for raw in reasoning.splitlines():
73
+ line = raw.strip()
74
+ if line:
75
+ lines.append(line)
76
+ if not lines:
77
+ # Rare: single blob without newlines β€” split on sentence boundaries lightly
78
+ blob = reasoning.strip()
79
+ if blob:
80
+ chunks = re.split(r"(?<=[.!?])\s+", blob)
81
+ lines = [c.strip() for c in chunks if c.strip()]
82
+ return lines
83
+
84
+
85
+ def build_assistant_content(reasoning: str, final_answer: str) -> str:
86
+ lines = reasoning_to_step_lines(reasoning)
87
+ out_parts: list[str] = []
88
+ for i, line in enumerate(lines, start=1):
89
+ # Prefer SymPy-friendly numerics: ** not ^, ascii-friendly
90
+ cleaned = line.replace("^", "**")
91
+ out_parts.append(f"Step {i}: {cleaned}")
92
+ body = "\n".join(out_parts)
93
+ if final_answer:
94
+ body = f"{body}\nFinal Answer: {final_answer}" if body else f"Final Answer: {final_answer}"
95
+ return body
96
+
97
+
98
+ def row_to_record(
99
+ question: str,
100
+ answer: str,
101
+ example_id: str,
102
+ split: str,
103
+ ) -> dict[str, Any] | None:
104
+ reasoning, final_answer = parse_gsm8k_answer(answer)
105
+ if not final_answer and "####" not in answer:
106
+ return None
107
+ assistant = build_assistant_content(reasoning, final_answer)
108
+ if not assistant.strip():
109
+ return None
110
+
111
+ user_content = USER_WRAPPER.format(question=question.strip())
112
+
113
+ return {
114
+ "id": f"gsm8k_{example_id}",
115
+ "skill_id": "gsm8k_grade_school",
116
+ "source": "openai/gsm8k",
117
+ "split": split,
118
+ "messages": [
119
+ {"role": "system", "content": SOLVER_SYSTEM_PROMPT},
120
+ {"role": "user", "content": user_content},
121
+ {"role": "assistant", "content": assistant},
122
+ ],
123
+ # Convenience for non-chat trainers
124
+ "text": f"<|system|>\n{SOLVER_SYSTEM_PROMPT}\n<|user|>\n{user_content}\n<|assistant|>\n{assistant}",
125
+ }
126
+
127
+
128
+ def iter_hf_rows(dataset_name: str, config: str, splits: list[str]) -> Iterator[tuple[str, str, dict]]:
129
+ from datasets import load_dataset
130
+
131
+ ds = load_dataset(dataset_name, config)
132
+ for split in splits:
133
+ if split not in ds:
134
+ raise KeyError(f"Split {split!r} not in dataset. Available: {list(ds.keys())}")
135
+ for i, row in enumerate(ds[split]):
136
+ yield f"{split}_{i}", split, row
137
+
138
+
139
+ def main() -> None:
140
+ p = argparse.ArgumentParser(description="Convert GSM8K to SFT JSONL (chat messages).")
141
+ p.add_argument(
142
+ "--source",
143
+ choices=("hf", "jsonl"),
144
+ default="hf",
145
+ help="Load from Hugging Face dataset or a local JSONL file.",
146
+ )
147
+ p.add_argument("--dataset", default="openai/gsm8k", help="HF dataset id when --source hf.")
148
+ p.add_argument("--config", default="main", help="HF config name when --source hf.")
149
+ p.add_argument("--splits", nargs="+", default=["train", "test"], help="HF splits to export.")
150
+ p.add_argument("--input", type=Path, help="Local JSONL path when --source jsonl.")
151
+ p.add_argument(
152
+ "--output",
153
+ type=Path,
154
+ default=Path("data/sft/gsm8k_sft.jsonl"),
155
+ help="Output JSONL path.",
156
+ )
157
+ args = p.parse_args()
158
+
159
+ if args.source == "jsonl" and not args.input:
160
+ raise SystemExit("--input is required when --source jsonl")
161
+
162
+ args.output.parent.mkdir(parents=True, exist_ok=True)
163
+
164
+ n_ok, n_skip = 0, 0
165
+
166
+ def process(example_id: str, split: str, row: dict) -> None:
167
+ nonlocal n_ok, n_skip
168
+ q = row.get("question", "")
169
+ a = row.get("answer", "")
170
+ rec = row_to_record(q, a, example_id, split)
171
+ if rec is None:
172
+ n_skip += 1
173
+ return
174
+ out_f.write(json.dumps(rec, ensure_ascii=False) + "\n")
175
+ n_ok += 1
176
+
177
+ with args.output.open("w", encoding="utf-8") as out_f:
178
+ if args.source == "hf":
179
+ for example_id, split, row in iter_hf_rows(args.dataset, args.config, args.splits):
180
+ process(example_id, split, row)
181
+ else:
182
+ for i, line in enumerate(args.input.open(encoding="utf-8")):
183
+ line = line.strip()
184
+ if not line:
185
+ continue
186
+ row = json.loads(line)
187
+ process(str(i), "jsonl", row)
188
+
189
+ print(f"Wrote {n_ok} examples to {args.output} ({n_skip} skipped).")
190
+
191
+
192
+ if __name__ == "__main__":
193
+ main()
scripts/create_dual_task_dataset.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Create dual-task training dataset by mixing question-generation and solution-generation examples.
4
+
5
+ This script:
6
+ 1. Loads existing solution data (GSM8K format)
7
+ 2. Loads question-generation data (synthetic)
8
+ 3. Adds task prefixes to distinguish tasks
9
+ 4. Mixes datasets according to specified ratio
10
+ 5. Shuffles and splits into train/validation
11
+
12
+ Usage:
13
+ python scripts/create_dual_task_dataset.py \
14
+ --solution-data data/sft/gsm8k_sft.jsonl \
15
+ --question-data data/sft/question_generation.jsonl \
16
+ --output-train data/sft/dual_task_train.jsonl \
17
+ --output-val data/sft/dual_task_val.jsonl \
18
+ --mix-ratio 0.8 \
19
+ --val-split 0.1
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import argparse
25
+ import json
26
+ import random
27
+ import sys
28
+ from pathlib import Path
29
+ from typing import Any
30
+
31
+ ROOT = Path(__file__).resolve().parents[1]
32
+ sys.path.insert(0, str(ROOT))
33
+
34
+ from src.config.prompts import SOLVE_TASK_PREFIX, GENERATE_TASK_PREFIX
35
+
36
+
37
+ def load_jsonl(path: Path) -> list[dict[str, Any]]:
38
+ """Load JSONL file into list of records."""
39
+ records = []
40
+ with path.open(encoding="utf-8") as f:
41
+ for line in f:
42
+ line = line.strip()
43
+ if line:
44
+ records.append(json.loads(line))
45
+ return records
46
+
47
+
48
+ def add_solve_prefix(record: dict[str, Any]) -> dict[str, Any]:
49
+ """
50
+ Add 'Solve Problem' task prefix to user message.
51
+
52
+ This signals the model to generate a step-by-step solution.
53
+ """
54
+ modified = record.copy()
55
+ modified["messages"] = []
56
+
57
+ for msg in record["messages"]:
58
+ new_msg = msg.copy()
59
+ if msg["role"] == "user":
60
+ # Add task prefix to user content
61
+ content = msg["content"]
62
+ if not content.startswith(SOLVE_TASK_PREFIX):
63
+ new_msg["content"] = SOLVE_TASK_PREFIX + content
64
+ modified["messages"].append(new_msg)
65
+
66
+ # Update text field if present
67
+ if "text" in modified:
68
+ # Find and update user section
69
+ text = modified["text"]
70
+ if "<|user|>" in text:
71
+ parts = text.split("<|user|>")
72
+ if len(parts) > 1:
73
+ user_part = parts[1]
74
+ if not user_part.strip().startswith(SOLVE_TASK_PREFIX):
75
+ parts[1] = f"\n{SOLVE_TASK_PREFIX}" + user_part
76
+ modified["text"] = "<|user|>".join(parts)
77
+
78
+ # Mark as solve task
79
+ modified["task_type"] = "solve"
80
+
81
+ return modified
82
+
83
+
84
+ def verify_question_prefix(record: dict[str, Any]) -> dict[str, Any]:
85
+ """
86
+ Verify question generation record has proper prefix.
87
+
88
+ Should already have it from generation script, but double-check.
89
+ """
90
+ modified = record.copy()
91
+ modified["messages"] = []
92
+
93
+ for msg in record["messages"]:
94
+ new_msg = msg.copy()
95
+ if msg["role"] == "user":
96
+ content = msg["content"]
97
+ if not content.startswith(GENERATE_TASK_PREFIX):
98
+ new_msg["content"] = GENERATE_TASK_PREFIX + content
99
+ modified["messages"].append(new_msg)
100
+
101
+ # Update text field if present
102
+ if "text" in modified:
103
+ text = modified["text"]
104
+ if "<|user|>" in text:
105
+ parts = text.split("<|user|>")
106
+ if len(parts) > 1:
107
+ user_part = parts[1]
108
+ if not user_part.strip().startswith(GENERATE_TASK_PREFIX):
109
+ parts[1] = f"\n{GENERATE_TASK_PREFIX}" + user_part
110
+ modified["text"] = "<|user|>".join(parts)
111
+
112
+ # Mark as question generation task
113
+ modified["task_type"] = "generate"
114
+
115
+ return modified
116
+
117
+
118
+ def sample_with_ratio(
119
+ solution_records: list[dict[str, Any]],
120
+ question_records: list[dict[str, Any]],
121
+ mix_ratio: float,
122
+ target_total: int | None = None,
123
+ ) -> list[dict[str, Any]]:
124
+ """
125
+ Sample and mix datasets according to specified ratio.
126
+
127
+ Args:
128
+ solution_records: Solution examples
129
+ question_records: Question generation examples
130
+ mix_ratio: Fraction of solutions in final dataset (0.8 = 80% solutions, 20% questions)
131
+ target_total: Target total examples (None = use all available data)
132
+
133
+ Returns:
134
+ Mixed dataset
135
+ """
136
+ n_solutions = len(solution_records)
137
+ n_questions = len(question_records)
138
+
139
+ if target_total is None:
140
+ # Use all available data
141
+ target_total = n_solutions + n_questions
142
+
143
+ # Calculate target counts
144
+ n_sol_target = int(target_total * mix_ratio)
145
+ n_q_target = target_total - n_sol_target
146
+
147
+ # Check availability
148
+ if n_sol_target > n_solutions:
149
+ print(f"Warning: Requested {n_sol_target} solutions but only {n_solutions} available.")
150
+ n_sol_target = n_solutions
151
+
152
+ if n_q_target > n_questions:
153
+ print(f"Warning: Requested {n_q_target} questions but only {n_questions} available.")
154
+ n_q_target = n_questions
155
+
156
+ # Sample
157
+ selected_solutions = random.sample(solution_records, n_sol_target)
158
+ selected_questions = random.sample(question_records, n_q_target)
159
+
160
+ print(f"Sampled {n_sol_target} solutions and {n_q_target} questions")
161
+ print(f"Actual ratio: {n_sol_target/(n_sol_target+n_q_target):.2%} solutions, "
162
+ f"{n_q_target/(n_sol_target+n_q_target):.2%} questions")
163
+
164
+ return selected_solutions + selected_questions
165
+
166
+
167
+ def write_jsonl(records: list[dict[str, Any]], path: Path) -> None:
168
+ """Write records to JSONL file."""
169
+ path.parent.mkdir(parents=True, exist_ok=True)
170
+ with path.open("w", encoding="utf-8") as f:
171
+ for record in records:
172
+ f.write(json.dumps(record, ensure_ascii=False) + "\n")
173
+
174
+
175
+ def main() -> None:
176
+ parser = argparse.ArgumentParser(
177
+ description="Create dual-task training dataset from solution and question-generation examples."
178
+ )
179
+ parser.add_argument(
180
+ "--solution-data",
181
+ type=Path,
182
+ required=True,
183
+ help="Path to solution training data (GSM8K format)",
184
+ )
185
+ parser.add_argument(
186
+ "--question-data",
187
+ type=Path,
188
+ required=True,
189
+ help="Path to question-generation training data",
190
+ )
191
+ parser.add_argument(
192
+ "--output-train",
193
+ type=Path,
194
+ required=True,
195
+ help="Output path for training split",
196
+ )
197
+ parser.add_argument(
198
+ "--output-val",
199
+ type=Path,
200
+ required=True,
201
+ help="Output path for validation split",
202
+ )
203
+ parser.add_argument(
204
+ "--mix-ratio",
205
+ type=float,
206
+ default=0.8,
207
+ help="Fraction of solutions in mixed dataset (default: 0.8 = 80%% solutions)",
208
+ )
209
+ parser.add_argument(
210
+ "--val-split",
211
+ type=float,
212
+ default=0.1,
213
+ help="Fraction of data to use for validation (default: 0.1 = 10%%)",
214
+ )
215
+ parser.add_argument(
216
+ "--seed",
217
+ type=int,
218
+ default=42,
219
+ help="Random seed for reproducibility",
220
+ )
221
+ parser.add_argument(
222
+ "--max-total",
223
+ type=int,
224
+ default=None,
225
+ help="Maximum total examples to include (None = use all available)",
226
+ )
227
+ args = parser.parse_args()
228
+
229
+ # Validate inputs
230
+ if not args.solution_data.exists():
231
+ raise SystemExit(f"Error: Solution data not found at {args.solution_data}")
232
+ if not args.question_data.exists():
233
+ raise SystemExit(f"Error: Question data not found at {args.question_data}")
234
+
235
+ if not (0 < args.mix_ratio < 1):
236
+ raise SystemExit("Error: --mix-ratio must be between 0 and 1")
237
+ if not (0 < args.val_split < 1):
238
+ raise SystemExit("Error: --val-split must be between 0 and 1")
239
+
240
+ # Set random seed
241
+ random.seed(args.seed)
242
+
243
+ print("=" * 60)
244
+ print("Dual-Task Dataset Creation")
245
+ print("=" * 60)
246
+
247
+ # Load data
248
+ print("\n1. Loading data...")
249
+ print(f" Solution data: {args.solution_data}")
250
+ solution_records = load_jsonl(args.solution_data)
251
+ print(f" Loaded {len(solution_records)} solution examples")
252
+
253
+ print(f" Question data: {args.question_data}")
254
+ question_records = load_jsonl(args.question_data)
255
+ print(f" Loaded {len(question_records)} question-generation examples")
256
+
257
+ # Add task prefixes
258
+ print("\n2. Adding task prefixes...")
259
+ print(" Adding 'Solve Problem' prefix to solution examples...")
260
+ solution_records = [add_solve_prefix(r) for r in solution_records]
261
+
262
+ print(" Verifying 'Generate Question' prefix on question examples...")
263
+ question_records = [verify_question_prefix(r) for r in question_records]
264
+
265
+ # Mix datasets
266
+ print(f"\n3. Mixing datasets (ratio: {args.mix_ratio:.0%} solutions, {1-args.mix_ratio:.0%} questions)...")
267
+ mixed_records = sample_with_ratio(
268
+ solution_records=solution_records,
269
+ question_records=question_records,
270
+ mix_ratio=args.mix_ratio,
271
+ target_total=args.max_total,
272
+ )
273
+
274
+ # Shuffle
275
+ print(f"\n4. Shuffling {len(mixed_records)} total examples...")
276
+ random.shuffle(mixed_records)
277
+
278
+ # Split train/val
279
+ n_val = int(len(mixed_records) * args.val_split)
280
+ n_train = len(mixed_records) - n_val
281
+
282
+ train_records = mixed_records[:n_train]
283
+ val_records = mixed_records[n_train:]
284
+
285
+ print(f"\n5. Splitting data:")
286
+ print(f" Training: {len(train_records)} examples ({len(train_records)/len(mixed_records):.1%})")
287
+ print(f" Validation: {len(val_records)} examples ({len(val_records)/len(mixed_records):.1%})")
288
+
289
+ # Verify split composition
290
+ train_solve = sum(1 for r in train_records if r.get("task_type") == "solve")
291
+ train_gen = sum(1 for r in train_records if r.get("task_type") == "generate")
292
+ val_solve = sum(1 for r in val_records if r.get("task_type") == "solve")
293
+ val_gen = sum(1 for r in val_records if r.get("task_type") == "generate")
294
+
295
+ print(f"\n Train composition:")
296
+ print(f" Solve: {train_solve} ({train_solve/len(train_records):.1%})")
297
+ print(f" Generate: {train_gen} ({train_gen/len(train_records):.1%})")
298
+
299
+ print(f" Val composition:")
300
+ print(f" Solve: {val_solve} ({val_solve/len(val_records):.1%})")
301
+ print(f" Generate: {val_gen} ({val_gen/len(val_records):.1%})")
302
+
303
+ # Write outputs
304
+ print(f"\n6. Writing output files...")
305
+ print(f" Training data: {args.output_train}")
306
+ write_jsonl(train_records, args.output_train)
307
+
308
+ print(f" Validation data: {args.output_val}")
309
+ write_jsonl(val_records, args.output_val)
310
+
311
+ print("\n" + "=" * 60)
312
+ print("Dual-task dataset creation complete!")
313
+ print("=" * 60)
314
+ print(f"\nOutput files:")
315
+ print(f" Train: {args.output_train} ({len(train_records)} examples)")
316
+ print(f" Val: {args.output_val} ({len(val_records)} examples)")
317
+ print(f"\nNext step: Train dual-task model using these files")
318
+
319
+
320
+ if __name__ == "__main__":
321
+ main()
scripts/demo_before_after.py ADDED
@@ -0,0 +1,591 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Before / after demo β€” baseline vs GRPO-trained policy.
2
+
3
+ Designed for hackathon judges: loads both models, runs greedy evaluation on
4
+ a fixed problem set, and prints a clean side-by-side comparison with full
5
+ solution text for the most interesting examples.
6
+
7
+ Features
8
+ --------
9
+ * Handles all checkpoint types: HF model IDs, GRPO full-weight saves,
10
+ PEFT/LoRA adapter directories.
11
+ * Automatically loads the chat template from the base model when the
12
+ checkpoint tokenizer doesn't have one (fixes the 0% accuracy bug that
13
+ silently swallows TemplateErrors).
14
+ * Reads ``metrics.jsonl`` (if present) and prints the full accuracy curve,
15
+ showing judges the training progression at a glance.
16
+ * Saves machine-readable JSON (for grading scripts) and prints a human-
17
+ readable Markdown table.
18
+ * Shows full solution text for the best wins and worst regressions.
19
+
20
+ Quick-start
21
+ -----------
22
+ After a GRPO run, point at ``best_policy/``::
23
+
24
+ python scripts/demo_before_after.py \\
25
+ --baseline-model checkpoints/dual_task_v1 \\
26
+ --trained-model checkpoints/grpo/<run>/best_policy \\
27
+ --problems data/sft/gsm8k_sft.jsonl \\
28
+ --max-samples 100
29
+
30
+ Include the training curve::
31
+
32
+ python scripts/demo_before_after.py \\
33
+ --baseline-model checkpoints/dual_task_v1 \\
34
+ --trained-model checkpoints/grpo/<run>/best_policy \\
35
+ --metrics-jsonl checkpoints/grpo/<run>/metrics.jsonl \\
36
+ --problems data/sft/gsm8k_sft.jsonl \\
37
+ --max-samples 100 \\
38
+ --records-out results/demo.json
39
+ """
40
+
41
+ from __future__ import annotations
42
+
43
+ import argparse
44
+ import json
45
+ import logging
46
+ import re
47
+ import sys
48
+ import time
49
+ import types
50
+ from dataclasses import dataclass, field
51
+ from pathlib import Path
52
+ from typing import Dict, List, Optional, Tuple
53
+
54
+ import torch
55
+ from peft import PeftModel
56
+ from tqdm.auto import tqdm
57
+ from transformers import AutoModelForCausalLM, AutoTokenizer
58
+
59
+ sys.path.insert(0, str(Path(__file__).parent.parent))
60
+ from src.sft.solution_format import extract_final_answer_numeric_str
61
+ from src.utils.attn_backend import select_attn_implementation
62
+
63
+ logging.basicConfig(
64
+ level=logging.INFO,
65
+ format="%(asctime)s %(levelname)-8s %(name)s - %(message)s",
66
+ )
67
+ logger = logging.getLogger(__name__)
68
+
69
+ _SEP = "=" * 78
70
+ _SEP2 = "-" * 78
71
+
72
+
73
+ # ---------------------------------------------------------------------------
74
+ # Data
75
+ # ---------------------------------------------------------------------------
76
+
77
+ @dataclass
78
+ class Problem:
79
+ question: str
80
+ gold_final: str
81
+
82
+
83
+ def _parse_gold(answer: str) -> str:
84
+ m = re.search(r"####\s*([-0-9.,/ ]+)", answer)
85
+ if m:
86
+ return m.group(1).strip().replace(",", "")
87
+ return answer.strip().splitlines()[-1].strip()
88
+
89
+
90
+ def _load_problems(path: Path, max_samples: int) -> List[Problem]:
91
+ """Accept GSM8K ``{question, answer}`` or SFT ``{messages}`` JSONL."""
92
+ out: List[Problem] = []
93
+ with path.open(encoding="utf-8") as fh:
94
+ for line in fh:
95
+ if max_samples > 0 and len(out) >= max_samples:
96
+ break
97
+ line = line.strip()
98
+ if not line:
99
+ continue
100
+ obj = json.loads(line)
101
+ if "question" in obj and "answer" in obj:
102
+ out.append(Problem(
103
+ question=obj["question"].strip(),
104
+ gold_final=_parse_gold(obj["answer"]),
105
+ ))
106
+ elif "messages" in obj:
107
+ user = next(
108
+ (m["content"] for m in obj["messages"] if m.get("role") == "user"), ""
109
+ ).strip()
110
+ asst = next(
111
+ (m["content"] for m in obj["messages"] if m.get("role") == "assistant"), ""
112
+ )
113
+ gold = extract_final_answer_numeric_str(asst) or ""
114
+ out.append(Problem(question=user, gold_final=gold.strip()))
115
+ return out
116
+
117
+
118
+ # ---------------------------------------------------------------------------
119
+ # Model loading β€” handles HF IDs, full-weight saves, and PEFT adapters
120
+ # ---------------------------------------------------------------------------
121
+
122
+ def _ensure_chat_template(
123
+ tokenizer: AutoTokenizer,
124
+ fallback_model: str = "Qwen/Qwen2.5-Math-1.5B-Instruct",
125
+ ) -> None:
126
+ """Load chat template from *fallback_model* when the checkpoint lacks one.
127
+
128
+ SFT adapter checkpoints often omit the chat_template from their tokenizer
129
+ config. Without it, ``apply_chat_template`` raises a TemplateError that
130
+ is silently swallowed inside ``evaluate_gsm8k``, returning 0% accuracy.
131
+ """
132
+ if tokenizer.chat_template is not None:
133
+ return
134
+ logger.info("Tokenizer missing chat_template β€” loading from %s", fallback_model)
135
+ try:
136
+ _base_tok = AutoTokenizer.from_pretrained(fallback_model, trust_remote_code=True)
137
+ if _base_tok.chat_template is not None:
138
+ tokenizer.chat_template = _base_tok.chat_template
139
+ logger.info("Chat template loaded.")
140
+ except Exception as exc:
141
+ logger.warning("Could not load chat template: %s", exc)
142
+
143
+
144
+ def _load_model(
145
+ checkpoint: str,
146
+ base_model_id: str,
147
+ device: torch.device,
148
+ dtype: torch.dtype,
149
+ attn_impl: str,
150
+ ) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
151
+ """Load model + tokenizer from any checkpoint style.
152
+
153
+ Handles:
154
+ * HuggingFace model ID (e.g. ``Qwen/Qwen2.5-Math-1.5B-Instruct``)
155
+ * GRPO full-weight save (directory with ``model.safetensors`` / pytorch_model*)
156
+ * PEFT/LoRA adapter dir (directory with ``adapter_config.json``)
157
+ """
158
+ # PEFT shim β€” prevents crash in merge_and_unload on some versions.
159
+ if "transformers.integrations.tensor_parallel" not in sys.modules:
160
+ sys.modules["transformers.integrations.tensor_parallel"] = types.ModuleType(
161
+ "tensor_parallel"
162
+ )
163
+
164
+ ckpt_path = Path(checkpoint)
165
+ is_adapter = ckpt_path.is_dir() and (ckpt_path / "adapter_config.json").exists()
166
+ is_local_full = ckpt_path.is_dir() and not is_adapter
167
+
168
+ # Tokenizer
169
+ tok_src = checkpoint if (ckpt_path.is_dir() and (ckpt_path / "tokenizer_config.json").exists()) else base_model_id
170
+ tokenizer = AutoTokenizer.from_pretrained(tok_src, trust_remote_code=True)
171
+ if tokenizer.pad_token is None:
172
+ tokenizer.pad_token = tokenizer.eos_token
173
+ tokenizer.padding_side = "left" # standard for generation
174
+ _ensure_chat_template(tokenizer, fallback_model=base_model_id)
175
+
176
+ load_kw = dict(
177
+ torch_dtype=dtype,
178
+ low_cpu_mem_usage=True,
179
+ device_map={"": device},
180
+ trust_remote_code=True,
181
+ attn_implementation=attn_impl,
182
+ )
183
+
184
+ if is_adapter:
185
+ # Read base model from pipeline_meta.json if present
186
+ meta_file = ckpt_path / "pipeline_meta.json"
187
+ _base = base_model_id
188
+ if meta_file.exists():
189
+ _base = json.loads(meta_file.read_text()).get("base_model", _base)
190
+ logger.info("PEFT adapter β€” loading base %s then merging %s", _base, checkpoint)
191
+ _base_mdl = AutoModelForCausalLM.from_pretrained(_base, **load_kw)
192
+ model = PeftModel.from_pretrained(_base_mdl, checkpoint).merge_and_unload()
193
+ model = model.to(device)
194
+ else:
195
+ # Full weights (GRPO save) or HF model ID
196
+ src = checkpoint if is_local_full else checkpoint
197
+ logger.info("Loading full-weight model from %s", src)
198
+ model = AutoModelForCausalLM.from_pretrained(src, **load_kw)
199
+
200
+ # Re-enable requires_grad isn't needed for eval, but ensure eval mode.
201
+ model.eval()
202
+ n = sum(p.numel() for p in model.parameters())
203
+ logger.info("Loaded: %s (%.2fB params, %.1f GB VRAM est.)",
204
+ checkpoint, n / 1e9, n * 2 / 1e9)
205
+ return model, tokenizer
206
+
207
+
208
+ # ---------------------------------------------------------------------------
209
+ # Generation
210
+ # ---------------------------------------------------------------------------
211
+
212
+ def _build_prompt(tokenizer: AutoTokenizer, question: str) -> str:
213
+ """Format question using the model's chat template (matches training format)."""
214
+ if tokenizer.chat_template is None:
215
+ return question
216
+ msgs = [
217
+ {"role": "system", "content": "You are a helpful math assistant. Solve the problem step-by-step and end with 'Final Answer: <number>'."},
218
+ {"role": "user", "content": question},
219
+ ]
220
+ try:
221
+ return tokenizer.apply_chat_template(
222
+ msgs, tokenize=False, add_generation_prompt=True
223
+ )
224
+ except Exception:
225
+ return question
226
+
227
+
228
+ def _stop_ids(tokenizer: AutoTokenizer) -> List[int]:
229
+ ids = []
230
+ if tokenizer.eos_token_id is not None:
231
+ ids.append(tokenizer.eos_token_id)
232
+ im_end = tokenizer.convert_tokens_to_ids("<|im_end|>")
233
+ if isinstance(im_end, int) and im_end not in ids:
234
+ ids.append(im_end)
235
+ return ids or None # type: ignore[return-value]
236
+
237
+
238
+ @torch.no_grad()
239
+ def _generate(
240
+ model: AutoModelForCausalLM,
241
+ tokenizer: AutoTokenizer,
242
+ question: str,
243
+ max_new_tokens: int,
244
+ device: torch.device,
245
+ ) -> str:
246
+ prompt = _build_prompt(tokenizer, question)
247
+ enc = tokenizer(
248
+ prompt,
249
+ return_tensors="pt",
250
+ truncation=True,
251
+ max_length=1024,
252
+ ).to(device)
253
+ prompt_len = enc["input_ids"].shape[1]
254
+
255
+ out = model.generate(
256
+ input_ids=enc["input_ids"],
257
+ attention_mask=enc["attention_mask"],
258
+ max_new_tokens=max_new_tokens,
259
+ do_sample=False, # greedy β€” deterministic for reproducibility
260
+ temperature=1.0,
261
+ pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
262
+ eos_token_id=_stop_ids(tokenizer),
263
+ use_cache=True,
264
+ )
265
+ return tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
266
+
267
+
268
+ # ---------------------------------------------------------------------------
269
+ # Scoring
270
+ # ---------------------------------------------------------------------------
271
+
272
+ def _normalize(x: str) -> str:
273
+ if not x:
274
+ return ""
275
+ s = x.strip().replace(",", "").replace("$", "").strip()
276
+ try:
277
+ f = float(s)
278
+ return f"{int(f)}" if f == int(f) else f"{f}"
279
+ except ValueError:
280
+ return s
281
+
282
+
283
+ @dataclass
284
+ class Record:
285
+ question: str
286
+ gold: str
287
+ pred: str
288
+ correct: bool
289
+ solution_text: str
290
+
291
+
292
+ def _score_model(
293
+ model: AutoModelForCausalLM,
294
+ tokenizer: AutoTokenizer,
295
+ problems: List[Problem],
296
+ max_new_tokens: int,
297
+ device: torch.device,
298
+ label: str,
299
+ ) -> Tuple[int, List[Record]]:
300
+ records: List[Record] = []
301
+ correct = 0
302
+ for prob in tqdm(problems, desc=f"Scoring {label}", unit="q", dynamic_ncols=True):
303
+ try:
304
+ text = _generate(model, tokenizer, prob.question, max_new_tokens, device)
305
+ except Exception as exc:
306
+ text = f"[generation error: {exc}]"
307
+ pred = extract_final_answer_numeric_str(text) or ""
308
+ ok = bool(pred) and _normalize(pred) == _normalize(prob.gold_final)
309
+ if ok:
310
+ correct += 1
311
+ records.append(Record(
312
+ question=prob.question,
313
+ gold=prob.gold_final,
314
+ pred=pred,
315
+ correct=ok,
316
+ solution_text=text,
317
+ ))
318
+ return correct, records
319
+
320
+
321
+ # ---------------------------------------------------------------------------
322
+ # Metrics curve
323
+ # ---------------------------------------------------------------------------
324
+
325
+ def _load_metrics_curve(path: Path) -> List[Dict]:
326
+ """Read metrics.jsonl and return rows that contain GSM8K accuracy."""
327
+ rows = []
328
+ if not path.exists():
329
+ return rows
330
+ with path.open(encoding="utf-8") as f:
331
+ for line in f:
332
+ line = line.strip()
333
+ if not line:
334
+ continue
335
+ try:
336
+ obj = json.loads(line)
337
+ if "accuracy" in obj or "iteration" in obj:
338
+ rows.append(obj)
339
+ except json.JSONDecodeError:
340
+ pass
341
+ return rows
342
+
343
+
344
+ def _print_curve(rows: List[Dict]) -> None:
345
+ if not rows:
346
+ return
347
+ print(f"\n{_SEP}")
348
+ print("TRAINING ACCURACY CURVE (from metrics.jsonl)")
349
+ print(_SEP)
350
+ print(f"{'Iter':>5} {'GSM8K%':>7} {'Reward':>7} {'Batch%':>7} {'LR':>10} {'Time(s)':>8}")
351
+ print(_SEP2)
352
+ for r in rows:
353
+ it = r.get("iteration", "")
354
+ acc = r.get("accuracy", None)
355
+ rwd = r.get("mean_reward", None)
356
+ bat = r.get("batch_accuracy", None)
357
+ lr = r.get("learning_rate", None)
358
+ ts = r.get("iter_time_s", None)
359
+ acc_s = f"{100*acc:.1f}%" if acc is not None else "β€”"
360
+ rwd_s = f"{rwd:.3f}" if rwd is not None else "β€”"
361
+ bat_s = f"{100*bat:.1f}%" if bat is not None else "β€”"
362
+ lr_s = f"{lr:.2e}" if lr is not None else "β€”"
363
+ ts_s = f"{ts:.1f}" if ts is not None else "β€”"
364
+ print(f"{it:>5} {acc_s:>7} {rwd_s:>7} {bat_s:>7} {lr_s:>10} {ts_s:>8}")
365
+ print()
366
+
367
+
368
+ # ---------------------------------------------------------------------------
369
+ # Output
370
+ # ---------------------------------------------------------------------------
371
+
372
+ def _print_summary(
373
+ base_correct: int,
374
+ tr_correct: int,
375
+ base_records: List[Record],
376
+ tr_records: List[Record],
377
+ baseline_name: str,
378
+ trained_name: str,
379
+ n_solutions: int = 3,
380
+ ) -> None:
381
+ n = len(base_records)
382
+ wins = [(p, b, t) for p, b, t in zip(base_records, base_records, tr_records) if not b.correct and t.correct]
383
+ losses = [(p, b, t) for p, b, t in zip(base_records, base_records, tr_records) if b.correct and not t.correct]
384
+ both_wrong = sum(1 for b, t in zip(base_records, tr_records) if not b.correct and not t.correct)
385
+ both_right = sum(1 for b, t in zip(base_records, tr_records) if b.correct and t.correct)
386
+
387
+ delta = tr_correct - base_correct
388
+ sign = "+" if delta >= 0 else ""
389
+
390
+ print(f"\n{_SEP}")
391
+ print("BEFORE vs AFTER β€” GSM8K accuracy (greedy decoding, fixed seed)")
392
+ print(_SEP)
393
+ print(f" Baseline : {baseline_name}")
394
+ print(f" Trained : {trained_name}")
395
+ print(_SEP2)
396
+ print(f" Baseline accuracy : {base_correct}/{n} ({100*base_correct/n:.1f}%)")
397
+ print(f" Trained accuracy : {tr_correct}/{n} ({100*tr_correct/n:.1f}%)")
398
+ print(f" Delta : {sign}{delta} problems ({sign}{100*delta/n:.1f} pp)")
399
+ print(_SEP2)
400
+ print(f" Newly correct (wins) : {len(wins)}")
401
+ print(f" Newly wrong (losses) : {len(losses)}")
402
+ print(f" Both correct : {both_right}")
403
+ print(f" Both wrong : {both_wrong}")
404
+ print(_SEP)
405
+
406
+ if wins:
407
+ print(f"\n{'='*78}")
408
+ print(f"WINS β€” problems the RL model now solves that the baseline could not")
409
+ print(f"{'='*78}")
410
+ for i, (_, base_r, tr_r) in enumerate(wins[:n_solutions]):
411
+ print(f"\n[Win {i+1}/{min(n_solutions, len(wins))}]")
412
+ _print_problem(base_r, tr_r)
413
+
414
+ if losses:
415
+ print(f"\n{'='*78}")
416
+ print(f"REGRESSIONS β€” problems the baseline solved but the RL model now misses")
417
+ print(f"{'='*78}")
418
+ for i, (_, base_r, tr_r) in enumerate(losses[:min(2, len(losses))]):
419
+ print(f"\n[Regression {i+1}/{min(2, len(losses))}]")
420
+ _print_problem(base_r, tr_r, is_regression=True)
421
+
422
+ print(f"\n{_SEP}")
423
+ pct_gain = 100 * delta / max(n - base_correct, 1)
424
+ print(f"SUMMARY: RL training fixed {len(wins)} problems, regressed {len(losses)}.")
425
+ print(f" Net: {sign}{delta} pts. Relative gain on previously-wrong: {pct_gain:+.1f}%")
426
+ print(_SEP)
427
+
428
+
429
+ def _print_problem(base_r: Record, tr_r: Record, is_regression: bool = False) -> None:
430
+ q = base_r.question
431
+ # Truncate long questions
432
+ if len(q) > 250:
433
+ q = q[:247] + "..."
434
+ print(f" Q : {q}")
435
+ print(f" Gold : {base_r.gold}")
436
+ if not is_regression:
437
+ print(f" Before : {base_r.pred!r:30s} βœ—")
438
+ print(f" After : {tr_r.pred!r:30s} βœ“")
439
+ # Show trained solution (truncated)
440
+ sol = tr_r.solution_text.strip()
441
+ if sol:
442
+ lines = sol.splitlines()
443
+ show = "\n ".join(lines[:12])
444
+ if len(lines) > 12:
445
+ show += f"\n ... ({len(lines)-12} more lines)"
446
+ print(f"\n Solution (trained model):\n {show}")
447
+ else:
448
+ print(f" Before : {base_r.pred!r:30s} βœ“")
449
+ print(f" After : {tr_r.pred!r:30s} βœ—")
450
+
451
+
452
+ # ---------------------------------------------------------------------------
453
+ # CLI
454
+ # ---------------------------------------------------------------------------
455
+
456
+ def main() -> int:
457
+ parser = argparse.ArgumentParser(
458
+ description=__doc__,
459
+ formatter_class=argparse.RawDescriptionHelpFormatter,
460
+ )
461
+ parser.add_argument(
462
+ "--baseline-model", default="checkpoints/dual_task_v1",
463
+ help="Pre-RL checkpoint. HF model ID, full-weight dir, or PEFT adapter dir.",
464
+ )
465
+ parser.add_argument(
466
+ "--trained-model", required=True,
467
+ help="Post-RL checkpoint (GRPO best_policy/ dir, or iteration checkpoint).",
468
+ )
469
+ parser.add_argument(
470
+ "--base-model-for-adapter", default="Qwen/Qwen2.5-Math-1.5B-Instruct",
471
+ help="Base model used when loading a PEFT adapter checkpoint.",
472
+ )
473
+ parser.add_argument(
474
+ "--problems", type=Path, default=Path("data/sft/gsm8k_sft.jsonl"),
475
+ help="JSONL eval set. Defaults to GSM8K training split (first --max-samples rows).",
476
+ )
477
+ parser.add_argument("--max-samples", type=int, default=100)
478
+ parser.add_argument("--max-new-tokens", type=int, default=512)
479
+ parser.add_argument(
480
+ "--metrics-jsonl", type=Path, default=None,
481
+ help="Path to metrics.jsonl from a GRPO run β€” prints the accuracy curve.",
482
+ )
483
+ parser.add_argument(
484
+ "--n-solutions", type=int, default=3,
485
+ help="Number of win/loss examples to print in full.",
486
+ )
487
+ parser.add_argument(
488
+ "--records-out", type=Path, default=None,
489
+ help="Save full per-problem JSON records here (for judge grading scripts).",
490
+ )
491
+ parser.add_argument(
492
+ "--device", default="cuda" if torch.cuda.is_available() else "cpu",
493
+ )
494
+ parser.add_argument(
495
+ "--dtype", default="bfloat16",
496
+ choices=["float32", "float16", "bfloat16"],
497
+ )
498
+ args = parser.parse_args()
499
+
500
+ if not args.problems.is_file():
501
+ logger.error("Problems file not found: %s", args.problems)
502
+ return 2
503
+
504
+ dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16}
505
+ dtype = dtype_map[args.dtype]
506
+ device = torch.device(args.device)
507
+ attn = select_attn_implementation()
508
+ logger.info("Device: %s | dtype: %s | attn: %s", device, args.dtype, attn)
509
+
510
+ # Print training curve if available
511
+ if args.metrics_jsonl:
512
+ curve = _load_metrics_curve(args.metrics_jsonl)
513
+ _print_curve(curve)
514
+
515
+ problems = _load_problems(args.problems, args.max_samples)
516
+ if not problems:
517
+ logger.error("No problems loaded from %s", args.problems)
518
+ return 2
519
+ logger.info("Evaluating on %d problems from %s", len(problems), args.problems)
520
+
521
+ # ── Baseline ──────────────────────────────────────────────────────────
522
+ logger.info("%s\nScoring BASELINE: %s\n%s", _SEP, args.baseline_model, _SEP)
523
+ t0 = time.perf_counter()
524
+ base_model, base_tok = _load_model(
525
+ args.baseline_model, args.base_model_for_adapter, device, dtype, attn
526
+ )
527
+ base_correct, base_records = _score_model(
528
+ base_model, base_tok, problems, args.max_new_tokens, device, "baseline"
529
+ )
530
+ del base_model
531
+ if torch.cuda.is_available():
532
+ torch.cuda.empty_cache()
533
+ logger.info("Baseline done in %.1fs β€” accuracy: %d/%d (%.1f%%)",
534
+ time.perf_counter() - t0,
535
+ base_correct, len(problems),
536
+ 100 * base_correct / len(problems))
537
+
538
+ # ── Trained ───────────────────────────────────────────────────────────
539
+ logger.info("%s\nScoring TRAINED: %s\n%s", _SEP, args.trained_model, _SEP)
540
+ t0 = time.perf_counter()
541
+ tr_model, tr_tok = _load_model(
542
+ args.trained_model, args.base_model_for_adapter, device, dtype, attn
543
+ )
544
+ tr_correct, tr_records = _score_model(
545
+ tr_model, tr_tok, problems, args.max_new_tokens, device, "trained"
546
+ )
547
+ del tr_model
548
+ if torch.cuda.is_available():
549
+ torch.cuda.empty_cache()
550
+ logger.info("Trained done in %.1fs β€” accuracy: %d/%d (%.1f%%)",
551
+ time.perf_counter() - t0,
552
+ tr_correct, len(problems),
553
+ 100 * tr_correct / len(problems))
554
+
555
+ # ── Summary ───────────────────────────────────────────────────────────
556
+ _print_summary(
557
+ base_correct, tr_correct,
558
+ base_records, tr_records,
559
+ baseline_name=args.baseline_model,
560
+ trained_name=args.trained_model,
561
+ n_solutions=args.n_solutions,
562
+ )
563
+
564
+ # ── Save records ──────────────────────────────────────────────────────
565
+ if args.records_out:
566
+ args.records_out.parent.mkdir(parents=True, exist_ok=True)
567
+ payload = {
568
+ "baseline_model": args.baseline_model,
569
+ "trained_model": args.trained_model,
570
+ "n_problems": len(problems),
571
+ "baseline": {
572
+ "correct": base_correct,
573
+ "accuracy": base_correct / len(problems),
574
+ "records": [vars(r) for r in base_records],
575
+ },
576
+ "trained": {
577
+ "correct": tr_correct,
578
+ "accuracy": tr_correct / len(problems),
579
+ "records": [vars(r) for r in tr_records],
580
+ },
581
+ }
582
+ args.records_out.write_text(
583
+ json.dumps(payload, indent=2, ensure_ascii=False), encoding="utf-8"
584
+ )
585
+ logger.info("Per-problem records saved to %s", args.records_out)
586
+
587
+ return 0
588
+
589
+
590
+ if __name__ == "__main__":
591
+ sys.exit(main())
scripts/dual_task_sft_pipeline.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Dual-task SFT pipeline: train model on both question generation and solution tasks.
3
+
4
+ This pipeline trains a single model that can:
5
+ 1. Generate math questions when prompted with "### Task: Generate Question"
6
+ 2. Solve math problems when prompted with "### Task: Solve Problem"
7
+
8
+ Examples
9
+ --------
10
+ # Train dual-task model
11
+ python scripts/dual_task_sft_pipeline.py train \\
12
+ --data data/sft/dual_task_train.jsonl \\
13
+ --output-dir checkpoints/dual_task_v1 \\
14
+ --epochs 2
15
+
16
+ # Infer - Question Generation
17
+ python scripts/dual_task_sft_pipeline.py infer \\
18
+ --adapter checkpoints/dual_task_v1 \\
19
+ --task generate \\
20
+ --prompt "Create a word problem about fractions and money requiring 3 steps."
21
+
22
+ # Infer - Solution Generation
23
+ python scripts/dual_task_sft_pipeline.py infer \\
24
+ --adapter checkpoints/dual_task_v1 \\
25
+ --task solve \\
26
+ --problem "Janet has 16 eggs. She eats 3. How many are left?"
27
+
28
+ Dependencies: torch, transformers, peft, datasets, accelerate, bitsandbytes, trl
29
+ """
30
+
31
+ from __future__ import annotations
32
+
33
+ import os
34
+
35
+ if "HF_HUB_DISABLE_XET" not in os.environ:
36
+ os.environ["HF_HUB_DISABLE_XET"] = "1"
37
+
38
+ import argparse
39
+ import json
40
+ import math
41
+ import sys
42
+ from pathlib import Path
43
+
44
+ ROOT = Path(__file__).resolve().parents[1]
45
+ sys.path.insert(0, str(ROOT))
46
+
47
+ from src.config.prompts import (
48
+ SOLVE_TASK_PREFIX,
49
+ GENERATE_TASK_PREFIX,
50
+ SOLVER_SYSTEM_PROMPT,
51
+ GENERATOR_SYSTEM_PROMPT,
52
+ )
53
+
54
+
55
+ def _warmup_steps_from_ratio(
56
+ num_examples: int,
57
+ per_device_train_batch_size: int,
58
+ gradient_accumulation_steps: int,
59
+ num_train_epochs: float,
60
+ warmup_ratio: float,
61
+ ) -> int:
62
+ """Calculate warmup steps from ratio."""
63
+ if warmup_ratio <= 0:
64
+ return 0
65
+ num_batches = max(
66
+ 1,
67
+ (num_examples + per_device_train_batch_size - 1) // per_device_train_batch_size,
68
+ )
69
+ num_update_steps_per_epoch = max(1, num_batches // gradient_accumulation_steps)
70
+ total_optimizer_steps = max(1, math.ceil(num_train_epochs * num_update_steps_per_epoch))
71
+ return min(total_optimizer_steps, int(total_optimizer_steps * warmup_ratio))
72
+
73
+
74
+ def cmd_train(args: argparse.Namespace) -> None:
75
+ try:
76
+ import torch
77
+ from datasets import load_dataset
78
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
79
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
80
+ from trl import SFTConfig, SFTTrainer
81
+ except ImportError as e:
82
+ raise SystemExit(
83
+ "Missing dependency for training. Install:\n"
84
+ " pip install torch transformers peft datasets accelerate bitsandbytes trl\n"
85
+ f"Original error: {e}"
86
+ ) from e
87
+
88
+ data_path = Path(args.data)
89
+ if not data_path.is_file():
90
+ raise SystemExit(f"Data file not found: {data_path}")
91
+
92
+ out_dir = Path(args.output_dir)
93
+ out_dir.mkdir(parents=True, exist_ok=True)
94
+
95
+ compute_dtype = getattr(torch, args.bnb_compute_dtype)
96
+ bnb_config = BitsAndBytesConfig(
97
+ load_in_4bit=True,
98
+ bnb_4bit_compute_dtype=compute_dtype,
99
+ bnb_4bit_quant_type="nf4",
100
+ bnb_4bit_use_double_quant=True,
101
+ )
102
+
103
+ tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
104
+ if tokenizer.pad_token is None:
105
+ tokenizer.pad_token = tokenizer.eos_token
106
+ tokenizer.padding_side = "right"
107
+
108
+ print(f"Loading model {args.model} …")
109
+ model = AutoModelForCausalLM.from_pretrained(
110
+ args.model,
111
+ quantization_config=bnb_config,
112
+ device_map="auto",
113
+ trust_remote_code=True,
114
+ dtype=compute_dtype,
115
+ )
116
+ model = prepare_model_for_kbit_training(model)
117
+
118
+ peft = LoraConfig(
119
+ r=args.lora_rank,
120
+ lora_alpha=args.lora_alpha,
121
+ lora_dropout=args.lora_dropout,
122
+ bias="none",
123
+ task_type="CAUSAL_LM",
124
+ target_modules=list(args.target_modules.split(",")),
125
+ )
126
+ model = get_peft_model(model, peft)
127
+ model.config.use_cache = False
128
+ model.print_trainable_parameters()
129
+
130
+ print(f"Loading dual-task dataset from {data_path} …")
131
+ ds = load_dataset("json", data_files=str(data_path), split="train")
132
+ if args.max_samples and args.max_samples > 0:
133
+ ds = ds.select(range(min(args.max_samples, len(ds))))
134
+
135
+ task_counts = {"solve": 0, "generate": 0, "unknown": 0}
136
+ for example in ds:
137
+ task_type = example.get("task_type", "unknown")
138
+ task_counts[task_type] = task_counts.get(task_type, 0) + 1
139
+
140
+ print(f"Dataset composition:")
141
+ print(f" Total examples: {len(ds)}")
142
+ print(f" Solve tasks: {task_counts['solve']} ({task_counts['solve']/len(ds):.1%})")
143
+ print(f" Generate tasks: {task_counts['generate']} ({task_counts['generate']/len(ds):.1%})")
144
+ if task_counts['unknown'] > 0:
145
+ print(f" Unknown tasks: {task_counts['unknown']}")
146
+
147
+ def formatting_func(example):
148
+ return tokenizer.apply_chat_template(
149
+ example["messages"],
150
+ tokenize=False,
151
+ add_generation_prompt=False,
152
+ )
153
+
154
+ if args.warmup_steps is not None:
155
+ warmup_steps = max(0, args.warmup_steps)
156
+ else:
157
+ warmup_steps = _warmup_steps_from_ratio(
158
+ len(ds),
159
+ args.batch_size,
160
+ args.grad_accum,
161
+ args.epochs,
162
+ args.warmup_ratio,
163
+ )
164
+
165
+ sft_args = SFTConfig(
166
+ output_dir=str(out_dir),
167
+ num_train_epochs=args.epochs,
168
+ per_device_train_batch_size=args.batch_size,
169
+ gradient_accumulation_steps=args.grad_accum,
170
+ learning_rate=args.learning_rate,
171
+ logging_steps=args.logging_steps,
172
+ save_steps=args.save_steps,
173
+ save_total_limit=3,
174
+ bf16=args.bf16 and torch.cuda.is_available(),
175
+ fp16=args.fp16 and torch.cuda.is_available() and not args.bf16,
176
+ max_length=args.max_seq_length,
177
+ warmup_steps=warmup_steps,
178
+ lr_scheduler_type="cosine",
179
+ report_to="none",
180
+ gradient_checkpointing=True,
181
+ )
182
+
183
+ print("\nStarting dual-task training...")
184
+ trainer = SFTTrainer(
185
+ model=model,
186
+ args=sft_args,
187
+ train_dataset=ds,
188
+ processing_class=tokenizer,
189
+ formatting_func=formatting_func,
190
+ )
191
+
192
+ trainer.train()
193
+ trainer.save_model(str(out_dir))
194
+ tokenizer.save_pretrained(str(out_dir))
195
+
196
+ with (out_dir / "pipeline_meta.json").open("w", encoding="utf-8") as f:
197
+ json.dump(
198
+ {
199
+ "pipeline_type": "dual_task",
200
+ "base_model": args.model,
201
+ "data": str(data_path),
202
+ "lora_rank": args.lora_rank,
203
+ "epochs": args.epochs,
204
+ "task_distribution": task_counts,
205
+ },
206
+ f,
207
+ indent=2,
208
+ )
209
+ print(f"\nSaved dual-task adapter and tokenizer to {out_dir}")
210
+
211
+
212
+ def cmd_infer(args: argparse.Namespace) -> None:
213
+ import torch
214
+ from peft import PeftModel
215
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
216
+
217
+ adapter = Path(args.adapter)
218
+ meta_path = adapter / "pipeline_meta.json"
219
+ base_model = args.base_model
220
+
221
+ if meta_path.is_file():
222
+ meta = json.loads(meta_path.read_text(encoding="utf-8"))
223
+ base_model = meta.get("base_model", base_model)
224
+ pipeline_type = meta.get("pipeline_type", "unknown")
225
+ if pipeline_type != "dual_task":
226
+ print(f"Warning: Adapter trained with pipeline_type='{pipeline_type}', expected 'dual_task'")
227
+
228
+ compute_dtype = getattr(torch, args.bnb_compute_dtype)
229
+ bnb_config = BitsAndBytesConfig(
230
+ load_in_4bit=True,
231
+ bnb_4bit_compute_dtype=compute_dtype,
232
+ bnb_4bit_quant_type="nf4",
233
+ bnb_4bit_use_double_quant=True,
234
+ )
235
+
236
+ tokenizer = AutoTokenizer.from_pretrained(adapter, trust_remote_code=True)
237
+ if tokenizer.pad_token is None:
238
+ tokenizer.pad_token = tokenizer.eos_token
239
+
240
+ print(f"Loading base {base_model} + adapter {adapter} …")
241
+ base = AutoModelForCausalLM.from_pretrained(
242
+ base_model,
243
+ quantization_config=bnb_config,
244
+ device_map="auto",
245
+ trust_remote_code=True,
246
+ )
247
+ model = PeftModel.from_pretrained(base, str(adapter))
248
+ model.eval()
249
+
250
+ if args.task == "solve":
251
+ system_prompt = SOLVER_SYSTEM_PROMPT
252
+ user_content = (
253
+ f"{SOLVE_TASK_PREFIX}"
254
+ "Solve the following problem. Show your reasoning as numbered steps, "
255
+ "then give the final numeric answer on the last line.\n\n"
256
+ f"Problem:\n{args.problem.strip()}"
257
+ )
258
+ elif args.task == "generate":
259
+ system_prompt = GENERATOR_SYSTEM_PROMPT
260
+ user_content = f"{GENERATE_TASK_PREFIX}{args.prompt.strip()}"
261
+ else:
262
+ raise ValueError(f"Unknown task: {args.task}. Must be 'solve' or 'generate'")
263
+
264
+ messages = [
265
+ {"role": "system", "content": system_prompt},
266
+ {"role": "user", "content": user_content},
267
+ ]
268
+
269
+ prompt = tokenizer.apply_chat_template(
270
+ messages, tokenize=False, add_generation_prompt=True
271
+ )
272
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
273
+
274
+ print(f"\nTask: {args.task}")
275
+ print(f"Prompt length: {inputs['input_ids'].shape[1]} tokens")
276
+ print("\nGenerating...")
277
+
278
+ with torch.no_grad():
279
+ out = model.generate(
280
+ **inputs,
281
+ max_new_tokens=args.max_new_tokens,
282
+ temperature=args.temperature,
283
+ top_p=args.top_p,
284
+ do_sample=not args.greedy,
285
+ pad_token_id=tokenizer.pad_token_id,
286
+ )
287
+
288
+ gen_ids = out[0, inputs["input_ids"].shape[1] :]
289
+ text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
290
+
291
+ print("\n" + "=" * 60)
292
+ print("Generated Output")
293
+ print("=" * 60)
294
+ print(text)
295
+ print("=" * 60)
296
+
297
+ if args.task == "solve":
298
+ print("\n--- Format Validation ---")
299
+ from src.sft.solution_format import validate_sympy_solution_format
300
+ r = validate_sympy_solution_format(text)
301
+ print(json.dumps(r.__dict__, indent=2))
302
+
303
+
304
+ def build_parser() -> argparse.ArgumentParser:
305
+ p = argparse.ArgumentParser(description="Dual-task SFT pipeline (train / infer)")
306
+ sub = p.add_subparsers(dest="command", required=True)
307
+
308
+ tr = sub.add_parser("train", help="Train dual-task model on mixed dataset")
309
+ tr.add_argument("--data", type=str, required=True, help="Dual-task training JSONL")
310
+ tr.add_argument("--output-dir", type=str, required=True, help="Output directory for adapter")
311
+ tr.add_argument("--model", type=str, default="Qwen/Qwen2.5-Math-1.5B-Instruct", help="Base model")
312
+ tr.add_argument("--epochs", type=float, default=2.0, help="Training epochs (default: 2.0 for dual-task)")
313
+ tr.add_argument("--batch-size", type=int, default=1)
314
+ tr.add_argument("--grad-accum", type=int, default=8)
315
+ tr.add_argument("--learning-rate", type=float, default=2e-4)
316
+ tr.add_argument("--max-samples", type=int, default=0, help="0 = use full dataset")
317
+ tr.add_argument("--lora-rank", type=int, default=16)
318
+ tr.add_argument("--lora-alpha", type=int, default=32)
319
+ tr.add_argument("--lora-dropout", type=float, default=0.05)
320
+ tr.add_argument(
321
+ "--target-modules",
322
+ type=str,
323
+ default="q_proj,v_proj,o_proj,gate_proj",
324
+ )
325
+ tr.add_argument("--max-seq-length", type=int, default=2048)
326
+ tr.add_argument("--save-steps", type=int, default=200)
327
+ tr.add_argument("--logging-steps", type=int, default=10)
328
+ tr.add_argument("--warmup-ratio", type=float, default=0.03)
329
+ tr.add_argument("--warmup-steps", type=int, default=None)
330
+ tr.add_argument("--bf16", action="store_true", default=True)
331
+ tr.add_argument("--no-bf16", dest="bf16", action="store_false")
332
+ tr.add_argument("--fp16", action="store_true")
333
+ tr.add_argument("--bnb-compute-dtype", type=str, default="bfloat16")
334
+ tr.set_defaults(func=cmd_train)
335
+
336
+ inf = sub.add_parser("infer", help="Generate with dual-task model")
337
+ inf.add_argument("--adapter", type=str, required=True, help="Adapter directory")
338
+ inf.add_argument(
339
+ "--base-model",
340
+ type=str,
341
+ default="Qwen/Qwen2.5-Math-1.5B-Instruct",
342
+ help="Base model (auto-detected from pipeline_meta.json if present)",
343
+ )
344
+ inf.add_argument(
345
+ "--task",
346
+ type=str,
347
+ required=True,
348
+ choices=["solve", "generate"],
349
+ help="Task type: 'solve' for problem solving, 'generate' for question generation",
350
+ )
351
+ inf.add_argument(
352
+ "--problem",
353
+ type=str,
354
+ default="",
355
+ help="Math problem to solve (required if --task solve)",
356
+ )
357
+ inf.add_argument(
358
+ "--prompt",
359
+ type=str,
360
+ default="",
361
+ help="Question generation prompt (required if --task generate)",
362
+ )
363
+ inf.add_argument("--max-new-tokens", type=int, default=1024)
364
+ inf.add_argument("--temperature", type=float, default=0.7)
365
+ inf.add_argument("--top-p", type=float, default=0.95)
366
+ inf.add_argument("--greedy", action="store_true", help="Use greedy decoding")
367
+ inf.add_argument("--bnb-compute-dtype", type=str, default="bfloat16")
368
+ inf.set_defaults(func=cmd_infer)
369
+
370
+ return p
371
+
372
+
373
+ def main() -> None:
374
+ parser = build_parser()
375
+ args = parser.parse_args()
376
+
377
+ if args.command == "infer":
378
+ if args.task == "solve" and not args.problem:
379
+ raise SystemExit("Error: --problem is required when --task solve")
380
+ if args.task == "generate" and not args.prompt:
381
+ raise SystemExit("Error: --prompt is required when --task generate")
382
+
383
+ if str(ROOT) not in sys.path:
384
+ sys.path.insert(0, str(ROOT))
385
+
386
+ args.func(args)
387
+
388
+
389
+ if __name__ == "__main__":
390
+ main()
scripts/eval_sft_inference.py ADDED
@@ -0,0 +1,565 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Run batch inference for a trained QLoRA adapter and report quality metrics.
4
+
5
+ This helps decide whether another SFT epoch is needed before RL.
6
+
7
+ Examples
8
+ --------
9
+ # Evaluate on GSM8K test split (first 100 samples)
10
+ python scripts/eval_sft_inference.py \
11
+ --adapter checkpoints/gsm8k_sft \
12
+ --max-samples 100
13
+
14
+ # Evaluate on local JSONL with {question, answer} rows
15
+ python scripts/eval_sft_inference.py \
16
+ --adapter checkpoints/gsm8k_sft \
17
+ --source jsonl \
18
+ --input data/raw/gsm8k_test.jsonl \
19
+ --max-samples 50 \
20
+ --output-json reports/sft_eval.json
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ import argparse
26
+ import json
27
+ import os
28
+ import re
29
+ import sys
30
+ from dataclasses import asdict, dataclass
31
+ from pathlib import Path
32
+ from typing import Any, Optional
33
+
34
+ # Prefer classic HTTP Hub downloads by default.
35
+ if "HF_HUB_DISABLE_XET" not in os.environ:
36
+ os.environ["HF_HUB_DISABLE_XET"] = "1"
37
+
38
+ # Ensure project-root imports work when invoked as `python scripts/...`.
39
+ ROOT = Path(__file__).resolve().parents[1]
40
+ if str(ROOT) not in sys.path:
41
+ sys.path.insert(0, str(ROOT))
42
+
43
+ import torch
44
+ from datasets import load_dataset
45
+ from peft import PeftModel
46
+ from sympy import simplify
47
+ from sympy.parsing.sympy_parser import parse_expr
48
+ from tqdm.auto import tqdm
49
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
50
+
51
+ from scripts.convert_gsm8k_to_sft import parse_gsm8k_answer
52
+ from src.config.prompts import create_solver_messages
53
+ from src.sft.solution_format import extract_final_answer_numeric_str, validate_sympy_solution_format
54
+ from src.sft.sympy_normalize import normalize_for_parse_expr
55
+
56
+
57
+ @dataclass
58
+ class EvalRow:
59
+ index: int
60
+ question: str
61
+ gold_final: str
62
+ pred_final: str
63
+ exact_match: Optional[bool]
64
+ format_ok: bool
65
+ step_count: int
66
+ scratchpad_leak: bool
67
+ output_text: str
68
+
69
+
70
+ def _norm_expr(s: str) -> str:
71
+ s = s.strip()
72
+ s = s.replace("^", "**")
73
+ s = re.sub(r"[,$€£\s]+", "", s)
74
+ return s
75
+
76
+
77
+ def _equiv_expr(a: str, b: str) -> Optional[bool]:
78
+ """Check if two answer strings are mathematically equivalent.
79
+
80
+ Uses the same normalization as CurriculumMathEnvironment._answers_equivalent
81
+ so eval and training agree on what counts as "correct".
82
+ """
83
+ if not a or not b:
84
+ return None
85
+ a_n = normalize_for_parse_expr(_norm_expr(a))
86
+ b_n = normalize_for_parse_expr(_norm_expr(b))
87
+ try:
88
+ return bool(simplify(parse_expr(a_n) - parse_expr(b_n)) == 0)
89
+ except Exception:
90
+ return a_n == b_n
91
+
92
+
93
+ def _iter_examples(args: argparse.Namespace) -> list[dict[str, str]]:
94
+ rows: list[dict[str, str]] = []
95
+ if args.source == "hf":
96
+ ds = load_dataset(args.dataset, args.config, split=args.split)
97
+ if args.max_samples > 0:
98
+ ds = ds.select(range(min(args.max_samples, len(ds))))
99
+ for row in ds:
100
+ _, final = parse_gsm8k_answer(row["answer"])
101
+ rows.append({"question": row["question"].strip(), "gold_final": final})
102
+ return rows
103
+
104
+ in_path = Path(args.input)
105
+ if not in_path.is_file():
106
+ raise SystemExit(f"Input JSONL not found: {in_path}")
107
+ with in_path.open(encoding="utf-8") as f:
108
+ for line in f:
109
+ if args.max_samples > 0 and len(rows) >= args.max_samples:
110
+ break
111
+ line = line.strip()
112
+ if not line:
113
+ continue
114
+ o = json.loads(line)
115
+ if "question" in o and "answer" in o:
116
+ _, final = parse_gsm8k_answer(o["answer"])
117
+ rows.append({"question": o["question"].strip(), "gold_final": final})
118
+ continue
119
+ if "messages" in o:
120
+ user = next((m["content"] for m in o["messages"] if m.get("role") == "user"), "").strip()
121
+ asst = next((m["content"] for m in o["messages"] if m.get("role") == "assistant"), "")
122
+ gold = extract_final_answer_numeric_str(asst) or ""
123
+ user = re.sub(r"^Solve the following problem\..*?Problem:\n", "", user, flags=re.S)
124
+ rows.append({"question": user.strip(), "gold_final": gold.strip()})
125
+ continue
126
+ raise SystemExit("JSONL rows must contain either {question, answer} or {messages}.")
127
+ return rows
128
+
129
+
130
+ def _generate(
131
+ model: Any,
132
+ tokenizer: Any,
133
+ problem: str,
134
+ max_new_tokens: int,
135
+ temperature: float,
136
+ top_p: float,
137
+ greedy: bool,
138
+ ) -> str:
139
+ # Use the canonical solver prompt (same system + user format as GRPO training)
140
+ # so eval measures the model under the exact distribution it was trained on.
141
+ messages = create_solver_messages(problem.strip())
142
+ prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
143
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
144
+ # HuggingFace warns once-per-call when `temperature`/`top_p` are passed
145
+ # alongside `do_sample=False`. Skip those kwargs entirely in greedy mode
146
+ # so long eval loops don't spam the log.
147
+ gen_kwargs = {
148
+ "max_new_tokens": max_new_tokens,
149
+ "do_sample": not greedy,
150
+ "pad_token_id": tokenizer.pad_token_id,
151
+ }
152
+ if not greedy:
153
+ gen_kwargs["temperature"] = temperature
154
+ gen_kwargs["top_p"] = top_p
155
+ with torch.no_grad():
156
+ out = model.generate(**inputs, **gen_kwargs)
157
+ gen_ids = out[0, inputs["input_ids"].shape[1] :]
158
+ return tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
159
+
160
+
161
+ def main() -> None:
162
+ p = argparse.ArgumentParser(description="Batch eval for SFT adapter inference.")
163
+ p.add_argument("--adapter", type=Path, required=True, help="Adapter directory from training step.")
164
+ p.add_argument("--base-model", type=str, default="Qwen/Qwen2.5-Math-1.5B-Instruct")
165
+ p.add_argument("--source", choices=("hf", "jsonl"), default="hf")
166
+ p.add_argument("--dataset", type=str, default="openai/gsm8k")
167
+ p.add_argument("--config", type=str, default="main")
168
+ p.add_argument("--split", type=str, default="test")
169
+ p.add_argument("--input", type=Path, help="JSONL path for --source jsonl")
170
+ p.add_argument("--max-samples", type=int, default=100)
171
+ p.add_argument("--max-new-tokens", type=int, default=512)
172
+ p.add_argument("--temperature", type=float, default=0.0)
173
+ p.add_argument("--top-p", type=float, default=1.0)
174
+ p.add_argument("--greedy", action="store_true", default=True)
175
+ p.add_argument("--no-greedy", dest="greedy", action="store_false")
176
+ p.add_argument("--bnb-compute-dtype", type=str, default="bfloat16")
177
+ p.add_argument("--show-samples", type=int, default=3)
178
+ p.add_argument("--output-json", type=Path, default=None)
179
+ args = p.parse_args()
180
+
181
+ if args.source == "jsonl" and not args.input:
182
+ raise SystemExit("--input is required when --source jsonl")
183
+
184
+ meta_path = args.adapter / "pipeline_meta.json"
185
+ base_model = args.base_model
186
+ if meta_path.is_file():
187
+ meta = json.loads(meta_path.read_text(encoding="utf-8"))
188
+ base_model = meta.get("base_model", base_model)
189
+
190
+ rows = _iter_examples(args)
191
+ if not rows:
192
+ raise SystemExit("No evaluation examples loaded.")
193
+ print(f"Loaded {len(rows)} evaluation examples.")
194
+
195
+ compute_dtype = getattr(torch, args.bnb_compute_dtype)
196
+ bnb_config = BitsAndBytesConfig(
197
+ load_in_4bit=True,
198
+ bnb_4bit_compute_dtype=compute_dtype,
199
+ bnb_4bit_quant_type="nf4",
200
+ bnb_4bit_use_double_quant=True,
201
+ )
202
+
203
+ print(f"Loading base {base_model} + adapter {args.adapter} …")
204
+ tokenizer = AutoTokenizer.from_pretrained(args.adapter, trust_remote_code=True)
205
+ if tokenizer.pad_token is None:
206
+ tokenizer.pad_token = tokenizer.eos_token
207
+ base = AutoModelForCausalLM.from_pretrained(
208
+ base_model,
209
+ quantization_config=bnb_config,
210
+ device_map="auto",
211
+ trust_remote_code=True,
212
+ )
213
+ model = PeftModel.from_pretrained(base, str(args.adapter))
214
+ model.eval()
215
+
216
+ results: list[EvalRow] = []
217
+ for i, row in enumerate(rows):
218
+ text = _generate(
219
+ model=model,
220
+ tokenizer=tokenizer,
221
+ problem=row["question"],
222
+ max_new_tokens=args.max_new_tokens,
223
+ temperature=args.temperature,
224
+ top_p=args.top_p,
225
+ greedy=args.greedy,
226
+ )
227
+ fmt = validate_sympy_solution_format(text)
228
+ pred_final = extract_final_answer_numeric_str(text) or ""
229
+ exact = _equiv_expr(pred_final, row["gold_final"])
230
+ results.append(
231
+ EvalRow(
232
+ index=i,
233
+ question=row["question"],
234
+ gold_final=row["gold_final"],
235
+ pred_final=pred_final,
236
+ exact_match=exact,
237
+ format_ok=fmt.ok,
238
+ step_count=fmt.step_count,
239
+ scratchpad_leak=("<<" in text and ">>" in text),
240
+ output_text=text,
241
+ )
242
+ )
243
+ if i < args.show_samples:
244
+ print(f"\n=== Sample {i} ===")
245
+ print("Q:", row["question"])
246
+ print("Gold:", row["gold_final"])
247
+ print("Pred:", pred_final)
248
+ print("Format OK:", fmt.ok, "| Steps:", fmt.step_count)
249
+ print(text)
250
+
251
+ n = len(results)
252
+ n_format_ok = sum(1 for r in results if r.format_ok)
253
+ n_scratch = sum(1 for r in results if r.scratchpad_leak)
254
+ em_scored = [r for r in results if r.exact_match is not None]
255
+ n_em = sum(1 for r in em_scored if r.exact_match)
256
+
257
+ print("\n=== Summary ===")
258
+ print(f"Samples: {n}")
259
+ print(f"Format OK: {n_format_ok}/{n} ({100.0 * n_format_ok / n:.2f}%)")
260
+ print(f"Scratchpad leakage (<< >>): {n_scratch}/{n} ({100.0 * n_scratch / n:.2f}%)")
261
+ if em_scored:
262
+ print(f"Exact match (final answer): {n_em}/{len(em_scored)} ({100.0 * n_em / len(em_scored):.2f}%)")
263
+ else:
264
+ print("Exact match (final answer): N/A (missing gold labels)")
265
+
266
+ if args.output_json is not None:
267
+ args.output_json.parent.mkdir(parents=True, exist_ok=True)
268
+ payload = {
269
+ "summary": {
270
+ "samples": n,
271
+ "format_ok": n_format_ok,
272
+ "format_ok_rate": n_format_ok / n,
273
+ "scratchpad_leakage": n_scratch,
274
+ "scratchpad_leakage_rate": n_scratch / n,
275
+ "exact_match_scored": len(em_scored),
276
+ "exact_match": n_em,
277
+ "exact_match_rate": (n_em / len(em_scored)) if em_scored else None,
278
+ },
279
+ "results": [asdict(r) for r in results],
280
+ }
281
+ args.output_json.write_text(json.dumps(payload, indent=2, ensure_ascii=False), encoding="utf-8")
282
+ print(f"Wrote detailed report to {args.output_json}")
283
+
284
+
285
+ def _infer_dataset_name(data_path: str) -> str:
286
+ """Derive a short human-readable dataset label from the file path."""
287
+ stem = Path(data_path).stem.lower() # e.g. "aqua_validation", "gsm8k_test"
288
+ if "aqua" in stem:
289
+ return "AQuA-RAT"
290
+ if "math" in stem:
291
+ return "MATH"
292
+ if "gsm" in stem:
293
+ return "GSM8K"
294
+ return Path(data_path).stem # fallback: raw filename stem
295
+
296
+
297
+ def evaluate_gsm8k(
298
+ model: Any,
299
+ tokenizer: Any,
300
+ data_path: str = "data/sft/gsm8k_test.jsonl",
301
+ max_samples: int = 500,
302
+ max_new_tokens: int = 512,
303
+ temperature: float = 0.0,
304
+ top_p: float = 1.0,
305
+ reward_fn: Any = None,
306
+ pass_at_k: int = 0,
307
+ dataset_name: str = "",
308
+ pass_at_k_temperature: float = 0.8,
309
+ ) -> dict:
310
+ """
311
+ Evaluate *model* on a math JSONL file using the SAME scoring
312
+ function used during GRPO training.
313
+
314
+ Args:
315
+ model : AutoModelForCausalLM (already on correct device).
316
+ tokenizer : Matching AutoTokenizer.
317
+ data_path : Path to JSONL with {question, answer} rows.
318
+ max_samples : Evaluation cap.
319
+ max_new_tokens / temperature / top_p : generation hyper-params.
320
+ reward_fn : callable(question: str, solution: str, gold: str) -> dict
321
+ Must return at minimum {"combined_score": float} and
322
+ optionally {"gt_match": bool, "prm_mean_score": float,
323
+ "sympy_score": float, "format_score": float}.
324
+ When supplied the primary accuracy metric becomes the
325
+ mean combined_score β€” identical to the GRPO training
326
+ objective β€” so every component (correctness, PRM step
327
+ quality, SymPy verification, format) contributes and
328
+ improvements in any of them show up immediately.
329
+ When None the function falls back to final-answer
330
+ exact-match accuracy (coarse binary).
331
+
332
+ Returns dict keys:
333
+ accuracy – mean combined_score per solution (or exact-match if no reward_fn)
334
+ combined_score – same as accuracy (alias)
335
+ correct_rate – fraction of solutions with gt_match == True
336
+ prm_mean – mean PRM step-quality score per solution
337
+ sympy_mean – mean SymPy verification score
338
+ format_mean – mean format compliance score
339
+ n_scored – solutions successfully scored by reward_fn
340
+ total – total solutions evaluated
341
+ # fallback (no reward_fn):
342
+ exact_match_rate – fraction of final answers matching gold
343
+ """
344
+ import logging as _logging
345
+ _logger = _logging.getLogger(__name__)
346
+
347
+ greedy = temperature < 1e-6
348
+ rows: list[dict] = []
349
+
350
+ p = Path(data_path)
351
+ if p.is_file():
352
+ with p.open(encoding="utf-8") as fh:
353
+ for line in fh:
354
+ if max_samples > 0 and len(rows) >= max_samples:
355
+ break
356
+ line = line.strip()
357
+ if not line:
358
+ continue
359
+ obj = json.loads(line)
360
+ if "question" in obj and "gold_final" in obj and obj["gold_final"]:
361
+ # Pre-extracted format (our gsm8k_test.jsonl)
362
+ rows.append({"question": obj["question"].strip(), "gold_final": obj["gold_final"].strip()})
363
+ elif "question" in obj and "answer" in obj:
364
+ _, final = parse_gsm8k_answer(obj["answer"])
365
+ if final:
366
+ rows.append({"question": obj["question"].strip(), "gold_final": final})
367
+ elif "messages" in obj:
368
+ task_type = obj.get("task_type", "solve")
369
+ if task_type != "solve":
370
+ continue # skip question-generation entries
371
+ user = next(
372
+ (m["content"] for m in obj["messages"] if m.get("role") == "user"), ""
373
+ ).strip()
374
+ asst = next(
375
+ (m["content"] for m in obj["messages"] if m.get("role") == "assistant"), ""
376
+ )
377
+ gold = extract_final_answer_numeric_str(asst) or ""
378
+ if not gold:
379
+ continue # skip entries with no parseable gold answer
380
+ user = re.sub(r"^Solve the following problem\..*?Problem:\n", "", user, flags=re.S)
381
+ rows.append({"question": user.strip(), "gold_final": gold.strip()})
382
+ else:
383
+ _logger.warning(
384
+ f"evaluate_gsm8k: {data_path} not found; loading openai/gsm8k from Hub."
385
+ )
386
+ try:
387
+ ds = load_dataset("openai/gsm8k", "main", split="test")
388
+ if max_samples > 0:
389
+ ds = ds.select(range(min(max_samples, len(ds))))
390
+ for row in ds:
391
+ _, final = parse_gsm8k_answer(row["answer"])
392
+ rows.append({"question": row["question"].strip(), "gold_final": final})
393
+ except Exception as exc:
394
+ _logger.error(f"Could not load GSM8K: {exc}")
395
+ return {"accuracy": 0.0, "correct": 0, "total": 0, "exact_match_rate": 0.0}
396
+
397
+ if not rows:
398
+ return {"accuracy": 0.0, "correct": 0, "total": 0, "exact_match_rate": 0.0}
399
+
400
+ correct = 0
401
+ total = len(rows)
402
+ _n_errors = 0
403
+ _MAX_ERROR_WARNINGS = 3
404
+
405
+ # Per-solution reward accumulators (populated when reward_fn is supplied).
406
+ _combined: list[float] = []
407
+ _gt_match: list[float] = []
408
+ _prm_comp: list[float] = []
409
+ _prm_final: list[float] = []
410
+ _step_acc: list[float] = [] # fraction of steps rated correct by PRM (>0.5)
411
+ _lccp: list[float] = [] # longest correct consecutive prefix ratio
412
+ _sympy_comp:list[float] = []
413
+ _fmt_comp: list[float] = []
414
+
415
+ # Pass@K accumulators: for each problem, did ANY of K samples get it right?
416
+ _pak_any_correct: list[int] = [] # 1 if any of K samples correct, else 0
417
+
418
+ _eval_label = dataset_name or _infer_dataset_name(data_path)
419
+ pbar = tqdm(
420
+ rows, total=total, desc=f"{_eval_label} eval",
421
+ unit="q", dynamic_ncols=True, leave=True,
422
+ )
423
+ for i, row in enumerate(pbar):
424
+ pred_text = ""
425
+ try:
426
+ pred_text = _generate(
427
+ model=model, tokenizer=tokenizer,
428
+ problem=row["question"],
429
+ max_new_tokens=max_new_tokens,
430
+ temperature=temperature, top_p=top_p, greedy=greedy,
431
+ )
432
+ pred_final = extract_final_answer_numeric_str(pred_text) or ""
433
+ if _equiv_expr(pred_final, row["gold_final"]):
434
+ correct += 1
435
+ except Exception as exc:
436
+ _n_errors += 1
437
+ if _n_errors <= _MAX_ERROR_WARNINGS:
438
+ _logger.warning(
439
+ "evaluate_gsm8k: sample %d raised %s: %s. "
440
+ "If all fail check that tokenizer has a chat_template.",
441
+ i, type(exc).__name__, exc,
442
+ )
443
+ elif _n_errors == _MAX_ERROR_WARNINGS + 1:
444
+ _logger.warning(
445
+ "evaluate_gsm8k: suppressing further errors (%d so far).",
446
+ _n_errors,
447
+ )
448
+ _logger.debug("Sample %d error: %s", i, exc, exc_info=True)
449
+
450
+ # ── Pass@K: sample K solutions at T=0.8 and check if any is correct ─
451
+ # This is the fair comparison to batch_acc during training (also K samples
452
+ # at T=0.8). Greedy (pass@1) is pessimistic; pass@k shows the upper bound
453
+ # the model can achieve with sampling, matching the training regime.
454
+ if pass_at_k > 1 and row.get("gold_final"):
455
+ _any = 0
456
+ for _ in range(pass_at_k):
457
+ try:
458
+ s = _generate(
459
+ model=model, tokenizer=tokenizer,
460
+ problem=row["question"],
461
+ max_new_tokens=max_new_tokens,
462
+ temperature=pass_at_k_temperature,
463
+ top_p=top_p, greedy=False,
464
+ )
465
+ pf = extract_final_answer_numeric_str(s) or ""
466
+ if _equiv_expr(pf, row["gold_final"]):
467
+ _any = 1
468
+ break
469
+ except Exception:
470
+ pass
471
+ _pak_any_correct.append(_any)
472
+
473
+ # ── Apply the SAME reward function used during GRPO training ──────────
474
+ if reward_fn is not None and pred_text:
475
+ try:
476
+ r = reward_fn(row["question"], pred_text, row["gold_final"])
477
+ _combined.append(float(r.get("combined_score", 0.0)))
478
+ _gt_match.append(1.0 if r.get("gt_match", False) else 0.0)
479
+ _prm_comp.append(float(r.get("prm_mean_score", 0.0)))
480
+ _prm_final.append(float(r.get("prm_final_score", 0.0)))
481
+ _step_acc.append(float(r.get("step_accuracy", 0.0)))
482
+ _lccp.append(float(r.get("lccp", 0.0)))
483
+ _sympy_comp.append(float(r.get("sympy_score", 0.0)))
484
+ _fmt_comp.append(float(r.get("format_score", 0.0)))
485
+ except Exception as rfn_exc:
486
+ _logger.debug("reward_fn failed for sample %d: %s", i, rfn_exc)
487
+
488
+ done = i + 1
489
+ # Periodically flush the CUDA allocator's free-block pool so that
490
+ # fragmentation from large KV-cache + PRM tensors doesn't accumulate
491
+ # and cause per-sample allocation time to grow throughout the run.
492
+ if done % 20 == 0:
493
+ import gc; gc.collect()
494
+ if torch.cuda.is_available():
495
+ torch.cuda.empty_cache()
496
+
497
+ # Live bar: show training-objective score when available, else acc.
498
+ if _combined:
499
+ _pf: dict = dict(
500
+ score=f"{sum(_combined) / len(_combined):.3f}",
501
+ correct=f"{sum(_gt_match):.0f}/{len(_combined)}",
502
+ step_acc=f"{sum(_step_acc)/len(_step_acc):.1%}" if _step_acc else "β€”",
503
+ lccp=f"{sum(_lccp)/len(_lccp):.1%}" if _lccp else "β€”",
504
+ )
505
+ else:
506
+ _pf = dict(acc=f"{correct / done:.1%}", correct=f"{correct}/{done}")
507
+ pbar.set_postfix(**_pf, refresh=False)
508
+
509
+ # ── Aggregate ──────────────────────────────────────────────────────────
510
+ n_scored = len(_combined)
511
+ _avg = lambda lst: round(sum(lst) / len(lst), 4) if lst else 0.0
512
+
513
+ # Pass@K: fraction of problems where any of K sampled solutions was correct.
514
+ pass_at_k_score = _avg(_pak_any_correct) if _pak_any_correct else None
515
+
516
+ if reward_fn is not None:
517
+ combined_score = _avg(_combined)
518
+ result: dict = {
519
+ # PRIMARY: mean training-objective score.
520
+ # Formula: 0.50Γ—correct + 0.40Γ—process(prm_final, prm_mean) + 0.10Γ—format
521
+ "accuracy": combined_score,
522
+ "combined_score": combined_score,
523
+ # PROCESS metrics β€” improve before correct_rate does
524
+ "step_accuracy": _avg(_step_acc),
525
+ "lccp": _avg(_lccp), # chain integrity: how far into solution stays correct
526
+ # Answer correctness
527
+ "correct_rate": _avg(_gt_match),
528
+ # PRM components
529
+ "prm_mean": _avg(_prm_comp),
530
+ "prm_final": _avg(_prm_final),
531
+ # Format / SymPy (informational)
532
+ "sympy_mean": _avg(_sympy_comp),
533
+ "format_mean": _avg(_fmt_comp),
534
+ "n_scored": n_scored,
535
+ "total": total,
536
+ "final_answer_correct": correct,
537
+ "final_answer_accuracy": correct / total if total else 0.0,
538
+ }
539
+ else:
540
+ _logger.warning(
541
+ "evaluate_gsm8k: no reward_fn provided β€” using final-answer accuracy. "
542
+ "Pass reward_fn=math_env.compute_grounded_reward for full training-objective eval."
543
+ )
544
+ fa_acc = correct / total if total else 0.0
545
+ result = {
546
+ "accuracy": fa_acc,
547
+ "combined_score": fa_acc,
548
+ "correct_rate": fa_acc,
549
+ "prm_mean": 0.0,
550
+ "sympy_mean": 0.0,
551
+ "format_mean": 0.0,
552
+ "n_scored": 0,
553
+ "total": total,
554
+ "final_answer_correct": correct,
555
+ "final_answer_accuracy": fa_acc,
556
+ }
557
+ # Attach pass@k if it was computed
558
+ if pass_at_k_score is not None:
559
+ result["pass_at_k"] = pass_at_k_score
560
+ result["pass_at_k_k"] = pass_at_k
561
+ return result
562
+
563
+
564
+ if __name__ == "__main__":
565
+ main()
scripts/gsm8k_sft_pipeline.py ADDED
@@ -0,0 +1,475 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ End-to-end GSM8K pipeline: prepare JSONL β†’ QLoRA SFT β†’ save adapter β†’ inference.
4
+
5
+ The trained model follows ``Step N:`` / ``Final Answer:`` formatting with SymPy-friendly
6
+ expressions (see ``src.agent.math_agent.SOLVER_SYSTEM_PROMPT``).
7
+
8
+ Examples
9
+ --------
10
+ # 1) Only build training JSONL from Hugging Face GSM8K
11
+ python scripts/gsm8k_sft_pipeline.py prepare --output data/sft/gsm8k_sft.jsonl
12
+
13
+ # 2) Fine-tune (requires GPU recommended)
14
+ python scripts/gsm8k_sft_pipeline.py train \\
15
+ --data data/sft/gsm8k_sft.jsonl \\
16
+ --output-dir checkpoints/gsm8k_sft
17
+
18
+ # 3) Run inference with saved adapter
19
+ python scripts/gsm8k_sft_pipeline.py infer \\
20
+ --adapter checkpoints/gsm8k_sft \\
21
+ --problem \"Janet has 16 eggs. She eats 3. How many are left?\"
22
+
23
+ # Full chain
24
+ python scripts/gsm8k_sft_pipeline.py all --output-dir checkpoints/gsm8k_sft
25
+
26
+ Dependencies: torch, transformers, peft, datasets, accelerate, bitsandbytes, trl, sympy
27
+
28
+ Tip: if downloads fail with XET / "Background writer channel closed", export ``HF_HUB_DISABLE_XET=1``
29
+ before running (this script sets it by default unless already set).
30
+ """
31
+
32
+ from __future__ import annotations
33
+
34
+ import os
35
+
36
+ # hf-xet can error or segfault on interrupted/large shards; classic HTTP download is more robust.
37
+ if "HF_HUB_DISABLE_XET" not in os.environ:
38
+ os.environ["HF_HUB_DISABLE_XET"] = "1"
39
+
40
+ import argparse
41
+ import json
42
+ import math
43
+ import subprocess
44
+ import sys
45
+ from pathlib import Path
46
+
47
+ # Project root (…/Maths_LLM)
48
+ ROOT = Path(__file__).resolve().parents[1]
49
+
50
+
51
+ def cmd_prepare(args: argparse.Namespace) -> None:
52
+ cmd = [
53
+ sys.executable,
54
+ str(ROOT / "scripts" / "convert_gsm8k_to_sft.py"),
55
+ "--output",
56
+ str(Path(args.output)),
57
+ "--splits",
58
+ *args.splits,
59
+ ]
60
+ if args.source == "jsonl":
61
+ cmd.extend(["--source", "jsonl", "--input", str(args.input)])
62
+ print("Running:", " ".join(cmd))
63
+ subprocess.check_call(cmd, cwd=str(ROOT))
64
+ if args.strip_scratchpads:
65
+ _rewrite_jsonl_strip_scratchpads(Path(args.output))
66
+
67
+
68
+ def _rewrite_jsonl_strip_scratchpads(jsonl_path: Path) -> None:
69
+ from src.sft.solution_format import strip_gsm8k_scratchpads
70
+
71
+ tmp = jsonl_path.with_suffix(".jsonl.tmp")
72
+ n = 0
73
+ with jsonl_path.open(encoding="utf-8") as fin, tmp.open("w", encoding="utf-8") as fout:
74
+ for line in fin:
75
+ o = json.loads(line)
76
+ for m in o.get("messages", []):
77
+ if m.get("role") == "assistant":
78
+ m["content"] = strip_gsm8k_scratchpads(m["content"])
79
+ if "text" in o:
80
+ sys_p = next(x["content"] for x in o["messages"] if x["role"] == "system")
81
+ usr = next(x["content"] for x in o["messages"] if x["role"] == "user")
82
+ asst = next(x["content"] for x in o["messages"] if x["role"] == "assistant")
83
+ o["text"] = (
84
+ f"<|system|>\n{sys_p}\n<|user|>\n{usr}\n<|assistant|>\n{asst}"
85
+ )
86
+ fout.write(json.dumps(o, ensure_ascii=False) + "\n")
87
+ n += 1
88
+ tmp.replace(jsonl_path)
89
+ print(f"Stripped <<>> scratchpads in {n} records β†’ {jsonl_path}")
90
+
91
+
92
+ def _warmup_steps_from_ratio(
93
+ num_examples: int,
94
+ per_device_train_batch_size: int,
95
+ gradient_accumulation_steps: int,
96
+ num_train_epochs: float,
97
+ warmup_ratio: float,
98
+ ) -> int:
99
+ """Approximate HF Trainer optimizer steps; used to map legacy warmup_ratio β†’ warmup_steps."""
100
+ if warmup_ratio <= 0:
101
+ return 0
102
+ num_batches = max(
103
+ 1,
104
+ (num_examples + per_device_train_batch_size - 1) // per_device_train_batch_size,
105
+ )
106
+ num_update_steps_per_epoch = max(1, num_batches // gradient_accumulation_steps)
107
+ total_optimizer_steps = max(1, math.ceil(num_train_epochs * num_update_steps_per_epoch))
108
+ return min(total_optimizer_steps, int(total_optimizer_steps * warmup_ratio))
109
+
110
+
111
+ def cmd_train(args: argparse.Namespace) -> None:
112
+ try:
113
+ import torch
114
+ from datasets import load_dataset
115
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
116
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
117
+ from trl import SFTConfig, SFTTrainer
118
+ except ImportError as e:
119
+ raise SystemExit(
120
+ "Missing dependency for training. Install:\n"
121
+ " pip install torch transformers peft datasets accelerate bitsandbytes trl sympy\n"
122
+ f"Original error: {e}"
123
+ ) from e
124
+
125
+ data_path = Path(args.data)
126
+ if not data_path.is_file():
127
+ raise SystemExit(f"Data file not found: {data_path}")
128
+
129
+ out_dir = Path(args.output_dir)
130
+ out_dir.mkdir(parents=True, exist_ok=True)
131
+
132
+ compute_dtype = getattr(torch, args.bnb_compute_dtype)
133
+ bnb_config = BitsAndBytesConfig(
134
+ load_in_4bit=True,
135
+ bnb_4bit_compute_dtype=compute_dtype,
136
+ bnb_4bit_quant_type="nf4",
137
+ bnb_4bit_use_double_quant=True,
138
+ )
139
+
140
+ tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
141
+ if tokenizer.pad_token is None:
142
+ tokenizer.pad_token = tokenizer.eos_token
143
+ tokenizer.padding_side = "right"
144
+
145
+ print(f"Loading model {args.model} …")
146
+ model = AutoModelForCausalLM.from_pretrained(
147
+ args.model,
148
+ quantization_config=bnb_config,
149
+ device_map="auto",
150
+ trust_remote_code=True,
151
+ dtype=compute_dtype,
152
+ )
153
+ model = prepare_model_for_kbit_training(model)
154
+ peft = LoraConfig(
155
+ r=args.lora_rank,
156
+ lora_alpha=args.lora_alpha,
157
+ lora_dropout=args.lora_dropout,
158
+ bias="none",
159
+ task_type="CAUSAL_LM",
160
+ target_modules=list(args.target_modules.split(",")),
161
+ )
162
+ model = get_peft_model(model, peft)
163
+ model.config.use_cache = False
164
+ model.print_trainable_parameters()
165
+
166
+ ds = load_dataset("json", data_files=str(data_path), split="train")
167
+ if args.max_samples and args.max_samples > 0:
168
+ ds = ds.select(range(min(args.max_samples, len(ds))))
169
+
170
+ def formatting_func(example):
171
+ return tokenizer.apply_chat_template(
172
+ example["messages"],
173
+ tokenize=False,
174
+ add_generation_prompt=False,
175
+ )
176
+
177
+ if args.warmup_steps is not None:
178
+ warmup_steps = max(0, args.warmup_steps)
179
+ else:
180
+ warmup_steps = _warmup_steps_from_ratio(
181
+ len(ds),
182
+ args.batch_size,
183
+ args.grad_accum,
184
+ args.epochs,
185
+ args.warmup_ratio,
186
+ )
187
+
188
+ sft_args = SFTConfig(
189
+ output_dir=str(out_dir),
190
+ num_train_epochs=args.epochs,
191
+ per_device_train_batch_size=args.batch_size,
192
+ gradient_accumulation_steps=args.grad_accum,
193
+ learning_rate=args.learning_rate,
194
+ logging_steps=args.logging_steps,
195
+ save_steps=args.save_steps,
196
+ save_total_limit=3,
197
+ bf16=args.bf16 and torch.cuda.is_available(),
198
+ fp16=args.fp16 and torch.cuda.is_available() and not args.bf16,
199
+ max_length=args.max_seq_length,
200
+ warmup_steps=warmup_steps,
201
+ lr_scheduler_type="cosine",
202
+ report_to="none",
203
+ gradient_checkpointing=True,
204
+ )
205
+
206
+ trainer = SFTTrainer(
207
+ model=model,
208
+ args=sft_args,
209
+ train_dataset=ds,
210
+ processing_class=tokenizer,
211
+ formatting_func=formatting_func,
212
+ )
213
+
214
+ trainer.train()
215
+ trainer.save_model(str(out_dir))
216
+ tokenizer.save_pretrained(str(out_dir))
217
+
218
+ with (out_dir / "pipeline_meta.json").open("w", encoding="utf-8") as f:
219
+ json.dump(
220
+ {
221
+ "base_model": args.model,
222
+ "data": str(data_path),
223
+ "lora_rank": args.lora_rank,
224
+ "epochs": args.epochs,
225
+ },
226
+ f,
227
+ indent=2,
228
+ )
229
+ print(f"Saved adapter and tokenizer to {out_dir}")
230
+
231
+
232
+ def cmd_infer(args: argparse.Namespace) -> None:
233
+ import torch
234
+ from peft import PeftModel
235
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
236
+
237
+ from src.agent.math_agent import SOLVER_SYSTEM_PROMPT
238
+
239
+ adapter = Path(args.adapter)
240
+ meta_path = adapter / "pipeline_meta.json"
241
+ base_model = args.base_model
242
+ if meta_path.is_file():
243
+ meta = json.loads(meta_path.read_text(encoding="utf-8"))
244
+ base_model = meta.get("base_model", base_model)
245
+
246
+ compute_dtype = getattr(torch, args.bnb_compute_dtype)
247
+ bnb_config = BitsAndBytesConfig(
248
+ load_in_4bit=True,
249
+ bnb_4bit_compute_dtype=compute_dtype,
250
+ bnb_4bit_quant_type="nf4",
251
+ bnb_4bit_use_double_quant=True,
252
+ )
253
+
254
+ tokenizer = AutoTokenizer.from_pretrained(adapter, trust_remote_code=True)
255
+ if tokenizer.pad_token is None:
256
+ tokenizer.pad_token = tokenizer.eos_token
257
+
258
+ print(f"Loading base {base_model} + adapter {adapter} …")
259
+ base = AutoModelForCausalLM.from_pretrained(
260
+ base_model,
261
+ quantization_config=bnb_config,
262
+ device_map="auto",
263
+ trust_remote_code=True,
264
+ )
265
+ model = PeftModel.from_pretrained(base, str(adapter))
266
+ model.eval()
267
+
268
+ user_content = (
269
+ "Solve the following problem. Show your reasoning as numbered steps, "
270
+ "then give the final numeric answer on the last line.\n\n"
271
+ f"Problem:\n{args.problem.strip()}"
272
+ )
273
+ messages = [
274
+ {"role": "system", "content": SOLVER_SYSTEM_PROMPT},
275
+ {"role": "user", "content": user_content},
276
+ ]
277
+ prompt = tokenizer.apply_chat_template(
278
+ messages, tokenize=False, add_generation_prompt=True
279
+ )
280
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
281
+
282
+ with torch.no_grad():
283
+ out = model.generate(
284
+ **inputs,
285
+ max_new_tokens=args.max_new_tokens,
286
+ temperature=args.temperature,
287
+ top_p=args.top_p,
288
+ do_sample=not args.greedy,
289
+ pad_token_id=tokenizer.pad_token_id,
290
+ )
291
+
292
+ gen_ids = out[0, inputs["input_ids"].shape[1] :]
293
+ text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
294
+ print("\n--- Generated ---\n")
295
+ print(text)
296
+ print("\n--- Format check ---")
297
+ from src.sft.solution_format import validate_sympy_solution_format
298
+
299
+ r = validate_sympy_solution_format(text)
300
+ print(json.dumps(r.__dict__, indent=2))
301
+
302
+
303
+ def cmd_all(args: argparse.Namespace) -> None:
304
+ out_jsonl = Path(args.data) if args.data else ROOT / "data" / "sft" / "gsm8k_sft.jsonl"
305
+ ns = argparse.Namespace(
306
+ output=out_jsonl,
307
+ source=args.prepare_source,
308
+ input=args.input,
309
+ splits=args.splits,
310
+ strip_scratchpads=args.strip_scratchpads,
311
+ )
312
+ cmd_prepare(ns)
313
+ train_ns = argparse.Namespace(
314
+ data=str(out_jsonl),
315
+ output_dir=args.output_dir,
316
+ model=args.model,
317
+ epochs=args.epochs,
318
+ batch_size=args.batch_size,
319
+ grad_accum=args.grad_accum,
320
+ learning_rate=args.learning_rate,
321
+ max_samples=args.max_samples,
322
+ lora_rank=args.lora_rank,
323
+ lora_alpha=args.lora_alpha,
324
+ lora_dropout=args.lora_dropout,
325
+ target_modules=args.target_modules,
326
+ max_seq_length=args.max_seq_length,
327
+ save_steps=args.save_steps,
328
+ logging_steps=args.logging_steps,
329
+ warmup_ratio=args.warmup_ratio,
330
+ warmup_steps=args.warmup_steps,
331
+ bf16=args.bf16,
332
+ fp16=args.fp16,
333
+ bnb_compute_dtype=args.bnb_compute_dtype,
334
+ )
335
+ cmd_train(train_ns)
336
+ if args.problem:
337
+ infer_ns = argparse.Namespace(
338
+ adapter=Path(args.output_dir),
339
+ base_model=args.model,
340
+ problem=args.problem,
341
+ max_new_tokens=args.max_new_tokens,
342
+ temperature=args.temperature,
343
+ top_p=args.top_p,
344
+ greedy=args.greedy,
345
+ bnb_compute_dtype=args.bnb_compute_dtype,
346
+ )
347
+ cmd_infer(infer_ns)
348
+
349
+
350
+ def build_parser() -> argparse.ArgumentParser:
351
+ p = argparse.ArgumentParser(description="GSM8K SFT pipeline (prepare / train / infer / all)")
352
+ sub = p.add_subparsers(dest="command", required=True)
353
+
354
+ pr = sub.add_parser("prepare", help="Run convert_gsm8k_to_sft.py")
355
+ pr.add_argument("--output", type=str, default=str(ROOT / "data" / "sft" / "gsm8k_sft.jsonl"))
356
+ pr.add_argument("--source", choices=("hf", "jsonl"), default="hf")
357
+ pr.add_argument("--input", type=str, help="JSONL path for --source jsonl")
358
+ pr.add_argument("--splits", nargs="+", default=["train", "test"])
359
+ pr.add_argument(
360
+ "--strip-scratchpads",
361
+ action="store_true",
362
+ help="Remove GSM8K <<...>> traces from assistant text after conversion.",
363
+ )
364
+ pr.set_defaults(func=cmd_prepare)
365
+
366
+ tr = sub.add_parser("train", help="QLoRA SFT on JSONL with messages field")
367
+ tr.add_argument("--data", type=str, required=True, help="JSONL from prepare step")
368
+ tr.add_argument("--output-dir", type=str, required=True)
369
+ tr.add_argument("--model", type=str, default="Qwen/Qwen2.5-Math-1.5B-Instruct")
370
+ tr.add_argument("--epochs", type=float, default=1.0)
371
+ tr.add_argument("--batch-size", type=int, default=1)
372
+ tr.add_argument("--grad-accum", type=int, default=8)
373
+ tr.add_argument("--learning-rate", type=float, default=2e-4)
374
+ tr.add_argument("--max-samples", type=int, default=0, help="0 = use full dataset")
375
+ tr.add_argument("--lora-rank", type=int, default=16)
376
+ tr.add_argument("--lora-alpha", type=int, default=32)
377
+ tr.add_argument("--lora-dropout", type=float, default=0.05)
378
+ tr.add_argument(
379
+ "--target-modules",
380
+ type=str,
381
+ default="q_proj,v_proj,o_proj,gate_proj",
382
+ )
383
+ tr.add_argument("--max-seq-length", type=int, default=2048)
384
+ tr.add_argument("--save-steps", type=int, default=200)
385
+ tr.add_argument("--logging-steps", type=int, default=10)
386
+ tr.add_argument(
387
+ "--warmup-ratio",
388
+ type=float,
389
+ default=0.03,
390
+ help="Used only if --warmup-steps is not set; converted to warmup_steps.",
391
+ )
392
+ tr.add_argument(
393
+ "--warmup-steps",
394
+ type=int,
395
+ default=None,
396
+ help="LR warmup steps; if set, overrides --warmup-ratio.",
397
+ )
398
+ tr.add_argument("--bf16", action="store_true", default=True)
399
+ tr.add_argument("--no-bf16", dest="bf16", action="store_false")
400
+ tr.add_argument("--fp16", action="store_true")
401
+ tr.add_argument("--bnb-compute-dtype", type=str, default="bfloat16")
402
+ tr.set_defaults(func=cmd_train)
403
+
404
+ inf = sub.add_parser("infer", help="Generate with saved adapter")
405
+ inf.add_argument("--adapter", type=str, required=True, help="Directory from train step")
406
+ inf.add_argument(
407
+ "--base-model",
408
+ type=str,
409
+ default="Qwen/Qwen2.5-Math-1.5B-Instruct",
410
+ help="Must match base used in training if no pipeline_meta.json",
411
+ )
412
+ inf.add_argument("--problem", type=str, required=True)
413
+ inf.add_argument("--max-new-tokens", type=int, default=1024)
414
+ inf.add_argument("--temperature", type=float, default=0.7)
415
+ inf.add_argument("--top-p", type=float, default=0.95)
416
+ inf.add_argument("--greedy", action="store_true")
417
+ inf.add_argument("--bnb-compute-dtype", type=str, default="bfloat16")
418
+ inf.set_defaults(func=cmd_infer)
419
+
420
+ al = sub.add_parser("all", help="prepare + train [+ infer if --problem]")
421
+ al.add_argument("--data", type=str, default=None, help="Output JSONL path (default data/sft/gsm8k_sft.jsonl)")
422
+ al.add_argument("--prepare-source", choices=("hf", "jsonl"), default="hf")
423
+ al.add_argument("--input", type=str, help="For jsonl prepare")
424
+ al.add_argument("--splits", nargs="+", default=["train", "test"])
425
+ al.add_argument("--strip-scratchpads", action="store_true")
426
+ al.add_argument("--output-dir", type=str, required=True)
427
+ al.add_argument("--model", type=str, default="Qwen/Qwen2.5-Math-1.5B-Instruct")
428
+ al.add_argument("--epochs", type=float, default=1.0)
429
+ al.add_argument("--batch-size", type=int, default=1)
430
+ al.add_argument("--grad-accum", type=int, default=8)
431
+ al.add_argument("--learning-rate", type=float, default=2e-4)
432
+ al.add_argument("--max-samples", type=int, default=0)
433
+ al.add_argument("--lora-rank", type=int, default=16)
434
+ al.add_argument("--lora-alpha", type=int, default=32)
435
+ al.add_argument("--lora-dropout", type=float, default=0.05)
436
+ al.add_argument("--target-modules", type=str, default="q_proj,v_proj,o_proj,gate_proj")
437
+ al.add_argument("--max-seq-length", type=int, default=2048)
438
+ al.add_argument("--save-steps", type=int, default=200)
439
+ al.add_argument("--logging-steps", type=int, default=10)
440
+ al.add_argument(
441
+ "--warmup-ratio",
442
+ type=float,
443
+ default=0.03,
444
+ help="Used only if --warmup-steps is not set; converted to warmup_steps.",
445
+ )
446
+ al.add_argument(
447
+ "--warmup-steps",
448
+ type=int,
449
+ default=None,
450
+ help="LR warmup steps; if set, overrides --warmup-ratio.",
451
+ )
452
+ al.add_argument("--bf16", action="store_true", default=True)
453
+ al.add_argument("--no-bf16", dest="bf16", action="store_false")
454
+ al.add_argument("--fp16", action="store_true")
455
+ al.add_argument("--bnb-compute-dtype", type=str, default="bfloat16")
456
+ al.add_argument("--problem", type=str, default="", help="If set, run infer after train")
457
+ al.add_argument("--max-new-tokens", type=int, default=1024)
458
+ al.add_argument("--temperature", type=float, default=0.7)
459
+ al.add_argument("--top-p", type=float, default=0.95)
460
+ al.add_argument("--greedy", action="store_true")
461
+ al.set_defaults(func=cmd_all)
462
+
463
+ return p
464
+
465
+
466
+ def main() -> None:
467
+ parser = build_parser()
468
+ args = parser.parse_args()
469
+ if str(ROOT) not in sys.path:
470
+ sys.path.insert(0, str(ROOT))
471
+ args.func(args)
472
+
473
+
474
+ if __name__ == "__main__":
475
+ main()
scripts/launch_grpo.sh ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ set -euo pipefail
2
+
3
+ # ── Flash-Attention 2 install (if missing) ────────────────────────────────────
4
+ # flash-attn requires (torch version, CUDA version, Python version) alignment.
5
+ # MAX_JOBS caps parallel compilation; prebuilt wheel installs in <30 s.
6
+ # In the prior run (grpo_20260425_151304), flash-attn was absent β†’ SDPA fallback
7
+ # β†’ iter times of 262-330 s once question-gen started (vs ~150 s with Flash).
8
+ if ! python -c "import flash_attn; assert int(flash_attn.__version__.split('.')[0]) >= 2" 2>/dev/null; then
9
+ echo "[launch] flash-attn not found or < v2 β€” installing now …"
10
+ MAX_JOBS=4 pip install flash-attn --no-build-isolation -q
11
+ echo "[launch] flash-attn installed."
12
+ else
13
+ FLASH_VER=$(python -c "import flash_attn; print(flash_attn.__version__)" 2>/dev/null)
14
+ echo "[launch] flash-attn ${FLASH_VER} already installed β€” skipping install."
15
+ fi
16
+
17
+ # ── GPU / allocator ───────────────────────────────────────────────────────────
18
+ export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0}
19
+ # expandable_segments: recovers 2-4 GB fragmented VRAM during long Flash+HF runs
20
+ export PYTORCH_CUDA_ALLOC_CONF=${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}
21
+
22
+ # ── CPU / threading ───────────────────────────────────────────────────────────
23
+ export OMP_NUM_THREADS=${OMP_NUM_THREADS:-8}
24
+ export MKL_NUM_THREADS=${MKL_NUM_THREADS:-8}
25
+ export TOKENIZERS_PARALLELISM=${TOKENIZERS_PARALLELISM:-false}
26
+
27
+ # ── Triton / Flash-Attn compilation cache ─────────────────────────────────────
28
+ # Persists JIT kernels across runs β€” avoids ~30 s recompile each launch.
29
+ export TRITON_CACHE_DIR=${TRITON_CACHE_DIR:-/tmp/triton_cache}
30
+ export FLASH_ATTENTION_SKIP_CUDA_BUILD=${FLASH_ATTENTION_SKIP_CUDA_BUILD:-FALSE}
31
+
32
+ # ── HuggingFace hub robustness ────────────────────────────────────────────────
33
+ export HF_HUB_DISABLE_XET=${HF_HUB_DISABLE_XET:-1}
34
+ export HF_HUB_ENABLE_HF_TRANSFER=${HF_HUB_ENABLE_HF_TRANSFER:-0}
35
+ export TRANSFORMERS_VERBOSITY=${TRANSFORMERS_VERBOSITY:-warning}
36
+
37
+ # ── Python path ───────────────────────────────────────────────────────────────
38
+ export PYTHONPATH="${PYTHONPATH:-}:$(pwd)"
39
+
40
+ # ── Pre-flight: GPU info ───────────────────────────────────────────────────────
41
+ if command -v nvidia-smi >/dev/null 2>&1; then
42
+ echo "─── nvidia-smi ───────────────────────────────────────────────────"
43
+ nvidia-smi --query-gpu=name,memory.total,memory.free,driver_version \
44
+ --format=csv,noheader || true
45
+ echo "──────────────────────────────────────────────────────────────────"
46
+ fi
47
+
48
+ # ── Confirm attention backend ─────────────────────────────────────────────────
49
+ python - <<'PYEOF'
50
+ import sys; sys.path.insert(0, '.')
51
+ from src.utils.attn_backend import select_attn_implementation
52
+ impl = select_attn_implementation()
53
+ tag = {
54
+ "flash_attention_2": "FAST β€” Flash-Attn 2 active (O(T) memory, ~1.5-2Γ— faster)",
55
+ "sdpa": "OK β€” SDPA active (install flash-attn for ~2Γ— speedup)",
56
+ "eager": "SLOW β€” Eager fallback (install flash-attn for best speed)",
57
+ }.get(impl, impl)
58
+ print(f"[launch] attn_backend = {tag}")
59
+ PYEOF
60
+
61
+ # ── Log tee ───────────────────────────────────────────────────────────────────
62
+ RUN_NAME="grpo_$(date +%Y%m%d_%H%M%S)"
63
+ LOG_DIR="logs/grpo"
64
+ mkdir -p "$LOG_DIR"
65
+ LOG_FILE="$LOG_DIR/${RUN_NAME}.log"
66
+
67
+ echo "[launch] run_name = $RUN_NAME"
68
+ echo "[launch] base_model = checkpoints/dual_task_v1"
69
+ echo "[launch] train_data = data/sft/gsm8k_sft.jsonl + data/math/math_numeric.jsonl"
70
+ echo "[launch] eval_data = data/sft/gsm8k_test.jsonl"
71
+ echo "[launch] log_file = $LOG_FILE"
72
+ echo "[launch] architecture = Two-phase self-play (K_q=2, K=10, N=20)"
73
+ echo "[launch] fixes_applied = min-warmup↑12, selfplay-gt-thresh↑0.65, kl-coef↑0.06,"
74
+ echo "[launch] math-ramp-start↑18, group-size↑10, num-iters↑60"
75
+ echo "[launch] wall-time β‰ˆ 3.3 h (Flash active) / 4.5 h (SDPA fallback)"
76
+
77
+ # ── Train ─────────────────────────────────────────────────────────────────────
78
+ python -u scripts/run_grpo_training.py \
79
+ --base-model checkpoints/dual_task_v1 \
80
+ --output-dir checkpoints/grpo \
81
+ --gsm8k-data data/sft/gsm8k_sft.jsonl \
82
+ --eval-data-path data/sft/gsm8k_test.jsonl \
83
+ \
84
+ --num-iterations 60 \
85
+ --group-size 10 \
86
+ --q-group-size 2 \
87
+ --questions-per-iter 20 \
88
+ \
89
+ --learning-rate 5e-6 \
90
+ --max-new-tokens 1000 \
91
+ --temperature 0.8 \
92
+ --max-grad-norm 0.5 \
93
+ --clip-eps 0.2 \
94
+ --kl-coef 0.06 \
95
+ --warmup-iters 8 \
96
+ --min-lr-ratio 0.1 \
97
+ \
98
+ --difficulty-alpha 3.5 \
99
+ --self-play-ratio 0.70 \
100
+ \
101
+ --math-mix-ratio 0.30 \
102
+ --math-mix-ratio-late 0.50 \
103
+ --math-ramp-start 18 \
104
+ --math-max-difficulty 3 \
105
+ \
106
+ --overlong-filter \
107
+ --min-warmup 12 \
108
+ --selfplay-gt-thresh 0.65 \
109
+ --selfplay-grounded-thresh 0.65 \
110
+ --selfplay-step-thresh 0.68 \
111
+ --selfplay-ramp-iters 28 \
112
+ --grounded-floor 0.55 \
113
+ \
114
+ --extractor-model Qwen/Qwen2.5-0.5B-Instruct \
115
+ --extraction-cache data/extraction_cache.json \
116
+ \
117
+ --eval-every 5 \
118
+ --eval-max-samples 150 \
119
+ --eval-max-new-tokens 1000 \
120
+ --eval-pass-at-k 0 \
121
+ --save-every 5 \
122
+ --keep-last 4 \
123
+ \
124
+ --use-prm \
125
+ --prm-model Qwen/Qwen2.5-Math-PRM-7B \
126
+ --run-name "$RUN_NAME" \
127
+ "$@" 2>&1 | tee "$LOG_FILE"
scripts/plot_grpo_run.py ADDED
@@ -0,0 +1,425 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Generate demo-quality plots from a completed (or in-progress) GRPO run.
4
+
5
+ Usage
6
+ -----
7
+ # from the run output directory
8
+ python scripts/plot_grpo_run.py checkpoints/grpo/<run_name>/metrics.jsonl
9
+
10
+ # auto-discover the latest run
11
+ python scripts/plot_grpo_run.py --latest
12
+
13
+ # custom output directory
14
+ python scripts/plot_grpo_run.py metrics.jsonl --out-dir plots/my_run
15
+
16
+ Output
17
+ ------
18
+ Six PNG files saved next to the JSONL (or --out-dir if given):
19
+
20
+ 01_training_objective.png – combined_score vs iteration (PRIMARY demo plot)
21
+ 02_reward_components.png – 4-panel breakdown: correct / PRM / SymPy / format
22
+ 03_training_dynamics.png – GRPO loss + batch reward + batch accuracy
23
+ 04_reward_vs_eval.png – training reward vs eval score on same axis
24
+ 05_component_area.png – stacked-area chart of the 4 weighted components
25
+ 06_summary_card.png – single-panel card: all key metrics in one view
26
+
27
+ All figures use a clean dark-on-white academic style. They are saved at
28
+ 300 dpi so they look sharp in slides and posters.
29
+ """
30
+
31
+ from __future__ import annotations
32
+
33
+ import argparse
34
+ import json
35
+ import sys
36
+ from pathlib import Path
37
+ from typing import Any, Dict, List, Optional, Tuple
38
+
39
+ import matplotlib
40
+ matplotlib.use("Agg") # headless β€” no display needed on training servers
41
+ import matplotlib.pyplot as plt
42
+ import matplotlib.ticker as mtick
43
+ import numpy as np
44
+
45
+
46
+ # ── Style ────────────────────────────────────────────────────────────────────
47
+
48
+ PALETTE = {
49
+ "combined": "#2563EB", # blue β€” training objective
50
+ "correct": "#16A34A", # green β€” correctness
51
+ "prm": "#DC2626", # red β€” PRM step quality
52
+ "sympy": "#D97706", # amber β€” SymPy verification
53
+ "fmt": "#7C3AED", # violet β€” format
54
+ "reward": "#0891B2", # cyan β€” mean batch reward
55
+ "loss": "#64748B", # slate β€” loss
56
+ "batch_acc": "#059669", # emerald β€” batch accuracy
57
+ }
58
+
59
+ plt.rcParams.update({
60
+ "figure.dpi": 150,
61
+ "savefig.dpi": 300,
62
+ "font.family": "DejaVu Sans",
63
+ "axes.spines.top": False,
64
+ "axes.spines.right": False,
65
+ "axes.grid": True,
66
+ "grid.alpha": 0.3,
67
+ "grid.linestyle": "--",
68
+ "axes.labelsize": 11,
69
+ "axes.titlesize": 13,
70
+ "legend.fontsize": 9,
71
+ "xtick.labelsize": 9,
72
+ "ytick.labelsize": 9,
73
+ })
74
+
75
+
76
+ # ── Data loading ─────────────────────────────────────────────────────────────
77
+
78
+ def _load(path: Path) -> List[Dict[str, Any]]:
79
+ rows = []
80
+ with path.open(encoding="utf-8") as fh:
81
+ for line in fh:
82
+ line = line.strip()
83
+ if line:
84
+ rows.append(json.loads(line))
85
+ return rows
86
+
87
+
88
+ def _field(rows: List[Dict], key: str) -> Tuple[List[int], List[float]]:
89
+ """Return (iterations, values) for rows that have a non-empty key."""
90
+ iters, vals = [], []
91
+ for r in rows:
92
+ v = r.get(key)
93
+ if v is not None and v != "" and not (isinstance(v, float) and np.isnan(v)):
94
+ try:
95
+ iters.append(int(r["iteration"]))
96
+ vals.append(float(v))
97
+ except (TypeError, ValueError):
98
+ pass
99
+ return iters, vals
100
+
101
+
102
+ # ── Individual plots ─────────────────────────────────────────────────────────
103
+
104
+ def plot_training_objective(rows: List[Dict], out: Path) -> None:
105
+ """Plot 01: combined_score β€” the single most important demo plot."""
106
+ xi, xv = _field(rows, "combined_score")
107
+ if not xi:
108
+ return
109
+
110
+ fig, ax = plt.subplots(figsize=(9, 5))
111
+ ax.plot(xi, xv, color=PALETTE["combined"], linewidth=2.5,
112
+ marker="o", markersize=5, label="Training-objective score")
113
+ ax.fill_between(xi, xv, alpha=0.12, color=PALETTE["combined"])
114
+
115
+ # annotate first and last eval points
116
+ ax.annotate(f"{xv[0]:.3f}", (xi[0], xv[0]), textcoords="offset points",
117
+ xytext=(8, 6), fontsize=8, color=PALETTE["combined"])
118
+ ax.annotate(f"{xv[-1]:.3f}", (xi[-1], xv[-1]), textcoords="offset points",
119
+ xytext=(8, 6), fontsize=8, color=PALETTE["combined"])
120
+
121
+ ax.set_xlabel("Iteration")
122
+ ax.set_ylabel("Score (0 – 1)")
123
+ ax.set_title(
124
+ "GRPO Training β€” Combined Reward Score\n"
125
+ "0.60 Γ— correct + 0.15 Γ— PRM + 0.15 Γ— SymPy + 0.10 Γ— format",
126
+ fontsize=12,
127
+ )
128
+ ax.set_ylim(0, 1.05)
129
+ ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))
130
+ ax.legend(loc="lower right")
131
+ fig.tight_layout()
132
+ fig.savefig(out)
133
+ plt.close(fig)
134
+ print(f" saved {out.name}")
135
+
136
+
137
+ def plot_reward_components(rows: List[Dict], out: Path) -> None:
138
+ """Plot 02: four-panel breakdown of each reward component."""
139
+ specs = [
140
+ ("correct_rate", "correct", "Correctness (gt_match)", "60 %"),
141
+ ("prm_mean", "prm", "PRM Step Quality", "15 %"),
142
+ ("sympy_mean", "sympy", "SymPy Verification", "15 %"),
143
+ ("format_mean", "fmt", "Format Compliance", "10 %"),
144
+ ]
145
+
146
+ fig, axes = plt.subplots(2, 2, figsize=(12, 7), sharex=False)
147
+ axes = axes.flatten()
148
+
149
+ for ax, (key, pal, title, weight) in zip(axes, specs):
150
+ xi, xv = _field(rows, key)
151
+ if not xi:
152
+ ax.set_visible(False)
153
+ continue
154
+ ax.plot(xi, xv, color=PALETTE[pal], linewidth=2,
155
+ marker="o", markersize=4)
156
+ ax.fill_between(xi, xv, alpha=0.12, color=PALETTE[pal])
157
+ ax.set_title(f"{title} (weight {weight})", fontsize=11)
158
+ ax.set_xlabel("Iteration")
159
+ ax.set_ylabel("Score")
160
+ ax.set_ylim(0, 1.05)
161
+ ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))
162
+
163
+ if xv:
164
+ delta = xv[-1] - xv[0]
165
+ sign = "+" if delta >= 0 else ""
166
+ ax.set_title(
167
+ f"{title} (weight {weight}) Ξ”={sign}{delta:+.1%}",
168
+ fontsize=10,
169
+ )
170
+
171
+ fig.suptitle("Reward Component Breakdown over Training", fontsize=13, y=1.01)
172
+ fig.tight_layout()
173
+ fig.savefig(out, bbox_inches="tight")
174
+ plt.close(fig)
175
+ print(f" saved {out.name}")
176
+
177
+
178
+ def plot_training_dynamics(rows: List[Dict], out: Path) -> None:
179
+ """Plot 03: loss, mean_reward, batch_accuracy over all iterations."""
180
+ li, lv = _field(rows, "loss")
181
+ ri, rv = _field(rows, "mean_reward")
182
+ bi, bv = _field(rows, "batch_accuracy")
183
+
184
+ fig, axes = plt.subplots(3, 1, figsize=(10, 8), sharex=True)
185
+
186
+ if lv:
187
+ axes[0].plot(li, lv, color=PALETTE["loss"], linewidth=1.8)
188
+ axes[0].fill_between(li, lv, alpha=0.1, color=PALETTE["loss"])
189
+ axes[0].set_ylabel("GRPO Loss")
190
+ axes[0].set_title("Training Loss", fontsize=11)
191
+ axes[0].axhline(0, color="black", linewidth=0.8, linestyle="--", alpha=0.4)
192
+
193
+ if rv:
194
+ axes[1].plot(ri, rv, color=PALETTE["reward"], linewidth=1.8)
195
+ axes[1].fill_between(ri, rv, alpha=0.1, color=PALETTE["reward"])
196
+ axes[1].set_ylabel("Reward")
197
+ axes[1].set_ylim(0, 1.05)
198
+ axes[1].set_title("Mean Batch Reward", fontsize=11)
199
+ axes[1].yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))
200
+
201
+ if bv:
202
+ axes[2].plot(bi, bv, color=PALETTE["batch_acc"], linewidth=1.8)
203
+ axes[2].fill_between(bi, bv, alpha=0.1, color=PALETTE["batch_acc"])
204
+ axes[2].set_ylabel("Accuracy")
205
+ axes[2].set_ylim(0, 1.05)
206
+ axes[2].set_title("Batch Accuracy (training rollouts)", fontsize=11)
207
+ axes[2].yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))
208
+
209
+ for ax in axes:
210
+ ax.set_xlabel("Iteration")
211
+
212
+ fig.suptitle("GRPO Training Dynamics", fontsize=13)
213
+ fig.tight_layout()
214
+ fig.savefig(out)
215
+ plt.close(fig)
216
+ print(f" saved {out.name}")
217
+
218
+
219
+ def plot_reward_vs_eval(rows: List[Dict], out: Path) -> None:
220
+ """Plot 04: mean_reward (all iters) + combined_score (eval iters) overlaid."""
221
+ ri, rv = _field(rows, "mean_reward")
222
+ ei, ev = _field(rows, "combined_score")
223
+
224
+ fig, ax = plt.subplots(figsize=(10, 5))
225
+
226
+ if rv:
227
+ ax.plot(ri, rv, color=PALETTE["reward"], linewidth=1.4, alpha=0.7,
228
+ label="Batch reward (training)")
229
+ ax.fill_between(ri, rv, alpha=0.06, color=PALETTE["reward"])
230
+
231
+ if ev:
232
+ ax.plot(ei, ev, color=PALETTE["combined"], linewidth=2.5,
233
+ marker="D", markersize=6, label="Eval score (held-out GSM8K)")
234
+ for x, y in zip(ei, ev):
235
+ ax.annotate(f"{y:.3f}", (x, y), textcoords="offset points",
236
+ xytext=(0, 8), ha="center", fontsize=7,
237
+ color=PALETTE["combined"])
238
+
239
+ ax.set_xlabel("Iteration")
240
+ ax.set_ylabel("Score (0 – 1)")
241
+ ax.set_ylim(0, 1.05)
242
+ ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))
243
+ ax.set_title("Training Reward vs Held-Out Eval Score", fontsize=12)
244
+ ax.legend()
245
+ fig.tight_layout()
246
+ fig.savefig(out)
247
+ plt.close(fig)
248
+ print(f" saved {out.name}")
249
+
250
+
251
+ def plot_component_area(rows: List[Dict], out: Path) -> None:
252
+ """Plot 05: stacked-area of the four WEIGHTED components summing to combined_score."""
253
+ ei, ev_combined = _field(rows, "combined_score")
254
+ if not ei:
255
+ return
256
+
257
+ # Build per-component weighted series aligned to eval iterations
258
+ iter_set = set(ei)
259
+ aligned: Dict[str, List[float]] = {k: [] for k in ("correct", "prm", "sympy", "fmt")}
260
+ weights = {"correct": 0.60, "prm": 0.15, "sympy": 0.15, "fmt": 0.10}
261
+ keys = {"correct": "correct_rate", "prm": "prm_mean",
262
+ "sympy": "sympy_mean", "fmt": "format_mean"}
263
+
264
+ # Build lookup per iteration
265
+ it_map: Dict[int, Dict] = {r["iteration"]: r for r in rows if r["iteration"] in iter_set}
266
+ iters_sorted = sorted(iter_set)
267
+
268
+ for it in iters_sorted:
269
+ row = it_map.get(it, {})
270
+ for comp, field in keys.items():
271
+ v = row.get(field)
272
+ if v is not None and v != "":
273
+ aligned[comp].append(float(v) * weights[comp])
274
+ else:
275
+ aligned[comp].append(0.0)
276
+
277
+ x = np.array(iters_sorted)
278
+ arr = np.array([aligned["correct"], aligned["prm"],
279
+ aligned["sympy"], aligned["fmt"]])
280
+
281
+ fig, ax = plt.subplots(figsize=(10, 5))
282
+ labels = ["Correct (Γ—0.60)", "PRM (Γ—0.15)", "SymPy (Γ—0.15)", "Format (Γ—0.10)"]
283
+ colors = [PALETTE[k] for k in ("correct", "prm", "sympy", "fmt")]
284
+ ax.stackplot(x, arr, labels=labels, colors=colors, alpha=0.75)
285
+
286
+ ax.plot(x, ev_combined, color="black", linewidth=1.5,
287
+ linestyle="--", label="Combined score", zorder=5)
288
+
289
+ ax.set_xlabel("Iteration")
290
+ ax.set_ylabel("Weighted contribution to score")
291
+ ax.set_ylim(0, 1.0)
292
+ ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))
293
+ ax.set_title("Contribution of Each Reward Component (Stacked)", fontsize=12)
294
+ ax.legend(loc="lower right", ncol=2)
295
+ fig.tight_layout()
296
+ fig.savefig(out)
297
+ plt.close(fig)
298
+ print(f" saved {out.name}")
299
+
300
+
301
+ def plot_summary_card(rows: List[Dict], run_name: str, out: Path) -> None:
302
+ """Plot 06: all key metrics on a single clean card β€” ideal for poster / slide."""
303
+ ei, ev = _field(rows, "combined_score")
304
+ _, crv = _field(rows, "correct_rate")
305
+ _, prmv = _field(rows, "prm_mean")
306
+ _, syv = _field(rows, "sympy_mean")
307
+ _, fmv = _field(rows, "format_mean")
308
+ _, lv = _field(rows, "loss")
309
+ _, rv = _field(rows, "mean_reward")
310
+ li = _field(rows, "loss")[0]
311
+ ri = _field(rows, "mean_reward")[0]
312
+
313
+ fig, axes = plt.subplots(2, 3, figsize=(15, 8))
314
+ axes = axes.flatten()
315
+
316
+ def _panel(ax, iters, vals, color, title, pct=True):
317
+ if not iters:
318
+ ax.set_visible(False)
319
+ return
320
+ ax.plot(iters, vals, color=color, linewidth=2, marker="o", markersize=4)
321
+ ax.fill_between(iters, vals, alpha=0.12, color=color)
322
+ ax.set_title(title, fontsize=11, fontweight="bold")
323
+ ax.set_xlabel("Iteration", fontsize=9)
324
+ if pct:
325
+ ax.set_ylim(0, 1.05)
326
+ ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))
327
+ if vals:
328
+ ax.annotate(f"{vals[-1]:.3f}", (iters[-1], vals[-1]),
329
+ textcoords="offset points", xytext=(6, 4),
330
+ fontsize=8, color=color)
331
+
332
+ _panel(axes[0], ei, ev, PALETTE["combined"], "Training-Objective Score")
333
+ _panel(axes[1], ei, crv, PALETTE["correct"], "Correctness Rate")
334
+ _panel(axes[2], ei, prmv, PALETTE["prm"], "PRM Step Quality")
335
+ _panel(axes[3], ei, syv, PALETTE["sympy"], "SymPy Verification")
336
+ _panel(axes[4], ei, fmv, PALETTE["fmt"], "Format Compliance")
337
+ _panel(axes[5], li, lv, PALETTE["loss"], "GRPO Loss", pct=False)
338
+
339
+ fig.suptitle(f"GRPO Training Summary β€” {run_name}", fontsize=14, fontweight="bold")
340
+ fig.tight_layout()
341
+ fig.savefig(out, bbox_inches="tight")
342
+ plt.close(fig)
343
+ print(f" saved {out.name}")
344
+
345
+
346
+ # ── CLI ──────────────────────────────────────────────────────────────────────
347
+
348
+ def find_latest_metrics() -> Optional[Path]:
349
+ """Find the most recently modified metrics.jsonl under checkpoints/grpo/."""
350
+ ckpt = Path("checkpoints/grpo")
351
+ if not ckpt.exists():
352
+ return None
353
+ candidates = sorted(
354
+ ckpt.rglob("metrics.jsonl"),
355
+ key=lambda p: p.stat().st_mtime,
356
+ )
357
+ return candidates[-1] if candidates else None
358
+
359
+
360
+ def generate_plots(metrics_path: Path, out_dir: Optional[Path] = None) -> Path:
361
+ """Generate all six plots and return the output directory."""
362
+ rows = _load(metrics_path)
363
+ if not rows:
364
+ print(f"[plot] No data in {metrics_path}", file=sys.stderr)
365
+ return metrics_path.parent
366
+
367
+ out_dir = out_dir or metrics_path.parent / "plots"
368
+ out_dir.mkdir(parents=True, exist_ok=True)
369
+
370
+ # Derive run name from the directory name two levels up
371
+ run_name = metrics_path.parent.name
372
+
373
+ print(f"[plot] Generating plots for run '{run_name}' ({len(rows)} iterations)")
374
+ print(f"[plot] Output β†’ {out_dir}")
375
+
376
+ plot_training_objective(rows, out_dir / "01_training_objective.png")
377
+ plot_reward_components(rows, out_dir / "02_reward_components.png")
378
+ plot_training_dynamics(rows, out_dir / "03_training_dynamics.png")
379
+ plot_reward_vs_eval(rows, out_dir / "04_reward_vs_eval.png")
380
+ plot_component_area(rows, out_dir / "05_component_area.png")
381
+ plot_summary_card(rows, run_name, out_dir / "06_summary_card.png")
382
+
383
+ print(f"[plot] Done β€” {len(list(out_dir.glob('*.png')))} PNGs in {out_dir}")
384
+ return out_dir
385
+
386
+
387
+ def main() -> None:
388
+ parser = argparse.ArgumentParser(
389
+ description="Generate demo plots from a GRPO metrics.jsonl file."
390
+ )
391
+ parser.add_argument(
392
+ "metrics_jsonl", nargs="?", type=Path, default=None,
393
+ help="Path to metrics.jsonl produced by run_grpo_training.py",
394
+ )
395
+ parser.add_argument(
396
+ "--latest", action="store_true",
397
+ help="Auto-discover the most recent metrics.jsonl under checkpoints/grpo/",
398
+ )
399
+ parser.add_argument(
400
+ "--out-dir", type=Path, default=None,
401
+ help="Directory to write PNG files (default: <metrics_dir>/plots/)",
402
+ )
403
+ args = parser.parse_args()
404
+
405
+ if args.latest:
406
+ path = find_latest_metrics()
407
+ if path is None:
408
+ print("No metrics.jsonl found under checkpoints/grpo/", file=sys.stderr)
409
+ sys.exit(1)
410
+ print(f"[plot] Auto-selected {path}")
411
+ elif args.metrics_jsonl:
412
+ path = args.metrics_jsonl
413
+ else:
414
+ parser.print_help()
415
+ sys.exit(1)
416
+
417
+ if not path.exists():
418
+ print(f"File not found: {path}", file=sys.stderr)
419
+ sys.exit(1)
420
+
421
+ generate_plots(path, args.out_dir)
422
+
423
+
424
+ if __name__ == "__main__":
425
+ main()
scripts/plot_training_results.py ADDED
@@ -0,0 +1,521 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ AxiomForgeAI β€” Training Results Plots
4
+ ======================================
5
+ Reads the metrics CSV from a GRPO training run and generates five focused plots
6
+ that tell the story of what improved, how self-play was earned, and why step-level
7
+ reasoning quality matters as much as final-answer accuracy.
8
+
9
+ All plots are saved to images/ as high-resolution PNGs.
10
+
11
+ Usage
12
+ -----
13
+ python scripts/plot_training_results.py
14
+ python scripts/plot_training_results.py --metrics logs/grpo/grpo_20260426_032827/metrics.csv
15
+ python scripts/plot_training_results.py --out images/
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import argparse
21
+ import csv
22
+ from pathlib import Path
23
+ from typing import Dict, List
24
+
25
+ import matplotlib
26
+ matplotlib.use("Agg")
27
+ import matplotlib.pyplot as plt
28
+ import matplotlib.patches as mpatches
29
+ import numpy as np
30
+
31
+ # ── Style ──────────────────────────────────────────────────────────────────────
32
+ PALETTE = {
33
+ "indigo": "#6366f1",
34
+ "pink": "#ec4899",
35
+ "cyan": "#06b6d4",
36
+ "amber": "#f59e0b",
37
+ "emerald": "#10b981",
38
+ "slate": "#94a3b8",
39
+ "red": "#ef4444",
40
+ "violet": "#8b5cf6",
41
+ "white": "#f8fafc",
42
+ "bg": "#0f172a",
43
+ "bg2": "#1e293b",
44
+ "gridline": "#1e293b",
45
+ }
46
+
47
+ plt.rcParams.update({
48
+ "figure.facecolor": PALETTE["bg"],
49
+ "axes.facecolor": PALETTE["bg"],
50
+ "axes.edgecolor": PALETTE["slate"],
51
+ "axes.labelcolor": PALETTE["white"],
52
+ "axes.titlecolor": PALETTE["white"],
53
+ "axes.titlesize": 13,
54
+ "axes.labelsize": 11,
55
+ "axes.grid": True,
56
+ "grid.color": "#1e293b",
57
+ "grid.linewidth": 0.8,
58
+ "xtick.color": PALETTE["slate"],
59
+ "ytick.color": PALETTE["slate"],
60
+ "xtick.labelsize": 9,
61
+ "ytick.labelsize": 9,
62
+ "legend.facecolor": "#1e293b",
63
+ "legend.edgecolor": PALETTE["slate"],
64
+ "legend.labelcolor": PALETTE["white"],
65
+ "legend.fontsize": 9,
66
+ "text.color": PALETTE["white"],
67
+ "font.family": "sans-serif",
68
+ "lines.linewidth": 2.0,
69
+ })
70
+
71
+ PHASE_COLORS = {
72
+ "GROUNDED_ONLY": ("#6366f120", "#6366f1"),
73
+ "SELFPLAY_RAMP": ("#10b98120", "#10b981"),
74
+ }
75
+
76
+ DPI = 160
77
+ IMAGES_DIR = Path("images")
78
+
79
+ DEFAULT_METRICS = (
80
+ "logs/grpo/grpo_20260426_032827/metrics.csv"
81
+ )
82
+
83
+
84
+ # ── Helpers ────────────────────────────────────────────────────────────────────
85
+
86
+ def load_csv(path: str) -> List[Dict]:
87
+ rows = []
88
+ with open(path, encoding="utf-8") as f:
89
+ for r in csv.DictReader(f):
90
+ rows.append({k: v for k, v in r.items()})
91
+ return rows
92
+
93
+
94
+ def f(row: Dict, key: str, default: float = float("nan")) -> float:
95
+ v = row.get(key, "")
96
+ try:
97
+ return float(v) if v != "" else default
98
+ except (ValueError, TypeError):
99
+ return default
100
+
101
+
102
+ def moving_avg(values: List[float], w: int = 3) -> List[float]:
103
+ result = []
104
+ for i in range(len(values)):
105
+ lo = max(0, i - w + 1)
106
+ chunk = [v for v in values[lo : i + 1] if not np.isnan(v)]
107
+ result.append(float(np.mean(chunk)) if chunk else float("nan"))
108
+ return result
109
+
110
+
111
+ def shade_phases(ax, iters, phases):
112
+ """Draw translucent background rectangles for each training phase."""
113
+ prev_phase, start = None, iters[0]
114
+ for it, ph in zip(iters, phases):
115
+ if ph != prev_phase:
116
+ if prev_phase is not None:
117
+ bg, _ = PHASE_COLORS.get(prev_phase, ("#ffffff10", "#ffffff"))
118
+ ax.axvspan(start - 0.5, it - 0.5, facecolor=bg, linewidth=0, zorder=0)
119
+ prev_phase, start = ph, it
120
+ if prev_phase is not None:
121
+ bg, _ = PHASE_COLORS.get(prev_phase, ("#ffffff10", "#ffffff"))
122
+ ax.axvspan(start - 0.5, iters[-1] + 0.5, facecolor=bg, linewidth=0, zorder=0)
123
+
124
+
125
+ def phase_legend_patches(phases):
126
+ seen = []
127
+ patches = []
128
+ for ph in phases:
129
+ if ph not in seen:
130
+ seen.append(ph)
131
+ _, edge = PHASE_COLORS.get(ph, ("#ffffff10", "#ffffff"))
132
+ label = ph.replace("_", " ").title()
133
+ patches.append(mpatches.Patch(facecolor=edge + "40", edgecolor=edge,
134
+ linewidth=1.2, label=label))
135
+ return patches
136
+
137
+
138
+ def annotate_transition(ax, x_iter, label, ypos=0.97, color="#94a3b8"):
139
+ ax.axvline(x=x_iter - 0.5, color=color, linewidth=1, linestyle="--", alpha=0.7)
140
+ ax.text(x_iter, ypos, label, transform=ax.get_xaxis_transform(),
141
+ fontsize=7.5, color=color, ha="left", va="top",
142
+ bbox=dict(facecolor=PALETTE["bg2"], edgecolor="none", pad=2))
143
+
144
+
145
+ def save(fig: plt.Figure, name: str, out: Path):
146
+ out.mkdir(parents=True, exist_ok=True)
147
+ path = out / name
148
+ fig.savefig(path, dpi=DPI, bbox_inches="tight", facecolor=fig.get_facecolor())
149
+ print(f" βœ“ {path}")
150
+ plt.close(fig)
151
+
152
+
153
+ # ══════════════════════════════════════════════════════════════════════════════
154
+ # PLOT 1 β€” Hero: Reasoning quality at evaluation checkpoints
155
+ # Shows four signals together: GSM8K accuracy, combined score, step accuracy,
156
+ # and LCCP. The message: the model doesn't just get more answers right β€”
157
+ # every step of the reasoning chain gets better.
158
+ # ══════════════════════════════════════════════════════════════════════════════
159
+
160
+ def plot_eval_quality(rows: List[Dict], out: Path):
161
+ eval_rows = [r for r in rows if r.get("eval_combined", "") != ""]
162
+ iters = [int(r["iteration"]) for r in eval_rows]
163
+
164
+ gsm8k_acc = [f(r, "eval_correct_rt") * 100 for r in eval_rows]
165
+ combined = [f(r, "eval_combined") * 100 for r in eval_rows]
166
+ step_acc = [f(r, "eval_step_acc") * 100 for r in eval_rows]
167
+ lccp = [f(r, "eval_lccp") * 100 for r in eval_rows]
168
+ prm = [f(r, "eval_prm") * 100 for r in eval_rows]
169
+
170
+ fig, ax = plt.subplots(figsize=(9, 5))
171
+ fig.suptitle("Evaluation Quality Over Training β€” AxiomForgeAI",
172
+ fontsize=14, fontweight="bold", color=PALETTE["white"], y=1.01)
173
+
174
+ # --- lines
175
+ ax.plot(iters, gsm8k_acc, "o-", color=PALETTE["pink"], label="GSM8K Accuracy (final answer)", ms=7, zorder=5)
176
+ ax.plot(iters, combined, "s-", color=PALETTE["indigo"], label="Combined Score", ms=6, zorder=5)
177
+ ax.plot(iters, step_acc, "^-", color=PALETTE["cyan"], label="Step Accuracy (reasoning chain)", ms=6, zorder=5)
178
+ ax.plot(iters, lccp, "D-", color=PALETTE["emerald"], label="LCCP (chain integrity)", ms=6, zorder=5)
179
+ ax.plot(iters, prm, "v--", color=PALETTE["amber"], label="PRM Mean Score", ms=5, alpha=0.8, zorder=4)
180
+
181
+ # annotate best GSM8K
182
+ best_gsm = max(gsm8k_acc)
183
+ bi = gsm8k_acc.index(best_gsm)
184
+ ax.annotate(f" {best_gsm:.1f}%",
185
+ xy=(iters[bi], best_gsm), fontsize=9, color=PALETTE["pink"],
186
+ va="bottom", ha="left")
187
+
188
+ # annotate best combined
189
+ best_c = max(combined)
190
+ bci = combined.index(best_c)
191
+ ax.annotate(f" {best_c:.1f}",
192
+ xy=(iters[bci], best_c), fontsize=9, color=PALETTE["indigo"],
193
+ va="top", ha="left")
194
+
195
+ ax.set_xlabel("Training Iteration")
196
+ ax.set_ylabel("Score (%)")
197
+ ax.set_xticks(iters)
198
+ ax.set_ylim(78, 96)
199
+ ax.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter("%.0f%%"))
200
+ ax.legend(loc="lower right", framealpha=0.8)
201
+ ax.set_title(
202
+ "Four angles on quality β€” answer correctness, holistic score, per-step reasoning, and chain integrity",
203
+ fontsize=9, color=PALETTE["slate"], pad=6,
204
+ )
205
+
206
+ fig.tight_layout()
207
+ save(fig, "plot1_eval_quality.png", out)
208
+
209
+
210
+ # ══════════════════════════════════════════════════════════════════════════════
211
+ # PLOT 2 β€” Training Journey: full 30-iteration timeline with phase shading
212
+ # Shows mean reward, GT match rate, and step accuracy over every iteration.
213
+ # Phase backgrounds show when self-play unlocked and the curriculum ramped.
214
+ # ══════════════════════════════════════════════════════════════════════════════
215
+
216
+ def plot_training_journey(rows: List[Dict], out: Path):
217
+ iters = [int(r["iteration"]) for r in rows]
218
+ phases = [r["training_phase"] for r in rows]
219
+ mean_r = [f(r, "mean_reward") * 100 for r in rows]
220
+ gt_match = [f(r, "gt_match_rate") * 100 for r in rows]
221
+ step_acc = [f(r, "step_accuracy") * 100 for r in rows]
222
+ batch_acc = [f(r, "batch_accuracy") * 100 for r in rows]
223
+
224
+ ma_reward = moving_avg(mean_r, w=4)
225
+ ma_gt = moving_avg(gt_match, w=4)
226
+ ma_step = moving_avg(step_acc, w=4)
227
+
228
+ fig, ax = plt.subplots(figsize=(11, 5))
229
+ shade_phases(ax, iters, phases)
230
+
231
+ # raw (faint)
232
+ ax.plot(iters, mean_r, alpha=0.25, color=PALETTE["indigo"], linewidth=1)
233
+ ax.plot(iters, gt_match, alpha=0.25, color=PALETTE["pink"], linewidth=1)
234
+ ax.plot(iters, step_acc, alpha=0.25, color=PALETTE["cyan"], linewidth=1)
235
+
236
+ # smoothed (bold)
237
+ ax.plot(iters, ma_reward, color=PALETTE["indigo"], linewidth=2.5, label="Mean Reward (smooth)")
238
+ ax.plot(iters, ma_gt, color=PALETTE["pink"], linewidth=2.5, label="GT Match Rate (smooth)")
239
+ ax.plot(iters, ma_step, color=PALETTE["cyan"], linewidth=2.5, label="Step Accuracy (smooth)")
240
+
241
+ # self-play transition annotation
242
+ sp_start = next(i for i, p in enumerate(phases) if p == "SELFPLAY_RAMP")
243
+ annotate_transition(ax, iters[sp_start], "Self-play\nunlocked", ypos=0.98,
244
+ color=PALETTE["emerald"])
245
+
246
+ ax.set_xlabel("Training Iteration")
247
+ ax.set_ylabel("Score (%)")
248
+ ax.set_ylim(55, 105)
249
+ ax.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter("%.0f%%"))
250
+ ax.set_xticks(range(1, max(iters) + 1, 2))
251
+ ax.set_title("30-Iteration GRPO Training Timeline | Faint = raw Β· Bold = 4-iter moving average",
252
+ fontsize=9, color=PALETTE["slate"], pad=6)
253
+ fig.suptitle("Training Journey β€” Reward, GT Match & Step Accuracy",
254
+ fontsize=14, fontweight="bold", color=PALETTE["white"], y=1.01)
255
+
256
+ legend_patches = phase_legend_patches(phases)
257
+ h, l = ax.get_legend_handles_labels()
258
+ ax.legend(handles=h + legend_patches, loc="lower right", framealpha=0.8, ncol=2)
259
+
260
+ fig.tight_layout()
261
+ save(fig, "plot2_training_journey.png", out)
262
+
263
+
264
+ # ══════════════════════════════════════════════════════════════════════════════
265
+ # PLOT 3 β€” Self-Play Success: the curriculum earning its right to generate
266
+ # Shows the self-play ratio ramping up while question quality stays high.
267
+ # The headline: by iteration 30 more than 60% of training is model-generated,
268
+ # and those questions are 95-100% solvable and genuinely novel.
269
+ # ══════════════════════════════════════════════════════════════════════════════
270
+
271
+ def plot_selfplay_success(rows: List[Dict], out: Path):
272
+ sp_rows = [r for r in rows if f(r, "q_reward") > 0]
273
+ iters = [int(r["iteration"]) for r in sp_rows]
274
+ sp_rat = [f(r, "sp_ratio") * 100 for r in sp_rows]
275
+ q_sol = [f(r, "q_solvability") * 100 for r in sp_rows]
276
+ q_nov = [f(r, "q_novelty") * 100 for r in sp_rows]
277
+ q_rew = [f(r, "q_reward") * 100 for r in sp_rows]
278
+
279
+ fig, ax1 = plt.subplots(figsize=(10, 5))
280
+ ax2 = ax1.twinx()
281
+ ax2.tick_params(axis="y", labelcolor=PALETTE["slate"])
282
+ ax2.spines["right"].set_color(PALETTE["slate"])
283
+
284
+ # self-play ramp (left axis)
285
+ ax1.fill_between(iters, sp_rat, alpha=0.18, color=PALETTE["emerald"])
286
+ ax1.plot(iters, sp_rat, "o-", color=PALETTE["emerald"], ms=6,
287
+ label="Self-play ratio", linewidth=2.5)
288
+ ax1.set_ylabel("Self-play share of training (%)", color=PALETTE["emerald"])
289
+ ax1.tick_params(axis="y", labelcolor=PALETTE["emerald"])
290
+ ax1.set_ylim(0, 80)
291
+
292
+ # question quality (right axis)
293
+ ax2.plot(iters, q_sol, "s--", color=PALETTE["cyan"], ms=5, label="Solvability", linewidth=1.8)
294
+ ax2.plot(iters, q_nov, "^--", color=PALETTE["amber"], ms=5, label="Novelty", linewidth=1.8)
295
+ ax2.plot(iters, q_rew, "D--", color=PALETTE["pink"], ms=5, label="Q-Reward", linewidth=1.8)
296
+ ax2.set_ylabel("Question quality score (%)", color=PALETTE["slate"])
297
+ ax2.set_ylim(0, 115)
298
+
299
+ # merge legends
300
+ h1, l1 = ax1.get_legend_handles_labels()
301
+ h2, l2 = ax2.get_legend_handles_labels()
302
+ ax1.legend(h1 + h2, l1 + l2, loc="upper left", framealpha=0.8)
303
+
304
+ ax1.set_xlabel("Training Iteration")
305
+ ax1.set_xticks(iters)
306
+ ax1.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter("%.0f%%"))
307
+ ax2.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter("%.0f%%"))
308
+
309
+ # annotate final sp ratio
310
+ ax1.annotate(f" {sp_rat[-1]:.0f}% self-play\n by iter {iters[-1]}",
311
+ xy=(iters[-1], sp_rat[-1]), fontsize=9, color=PALETTE["emerald"],
312
+ va="center", ha="left")
313
+
314
+ fig.suptitle("Self-Play Curriculum β€” The Model Earns Its Own Training Data",
315
+ fontsize=14, fontweight="bold", color=PALETTE["white"], y=1.01)
316
+ ax1.set_title(
317
+ "Self-play ratio ramps from 0 β†’ 61% Β· Generated questions stay 93-100% solvable throughout",
318
+ fontsize=9, color=PALETTE["slate"], pad=6,
319
+ )
320
+ fig.tight_layout()
321
+ save(fig, "plot3_selfplay_success.png", out)
322
+
323
+
324
+ # ══════════════════════════════════════════════════════════════════════════════
325
+ # PLOT 4 β€” Reward Signal Tightening: mean Β± std over 30 iterations
326
+ # As the policy learns what "good" looks like, the spread between the best
327
+ # and worst solutions in a group narrows. Lower variance = more consistent
328
+ # reasoning, not lucky guessing.
329
+ # ══════════════════════════════════════════════���═══════════════════════════════
330
+
331
+ def plot_reward_confidence(rows: List[Dict], out: Path):
332
+ iters = [int(r["iteration"]) for r in rows]
333
+ phases = [r["training_phase"] for r in rows]
334
+ mean_r = np.array([f(r, "mean_reward") for r in rows])
335
+ std_r = np.array([f(r, "std_reward") for r in rows])
336
+ skipped = np.array([f(r, "skipped_groups", 0) for r in rows])
337
+ n_grps = np.array([f(r, "n_groups", 1) for r in rows])
338
+ skip_rt = skipped / np.maximum(n_grps, 1) * 100
339
+
340
+ fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(11, 7), sharex=True,
341
+ gridspec_kw={"height_ratios": [3, 1.2]})
342
+ fig.suptitle("Reward Confidence β€” Mean Β± Std & Skipped Groups Over 30 Iterations",
343
+ fontsize=14, fontweight="bold", color=PALETTE["white"], y=1.01)
344
+
345
+ shade_phases(ax1, iters, phases)
346
+
347
+ ax1.fill_between(iters, (mean_r - std_r) * 100, (mean_r + std_r) * 100,
348
+ alpha=0.20, color=PALETTE["indigo"])
349
+ ax1.plot(iters, mean_r * 100, color=PALETTE["indigo"], linewidth=2.5, label="Mean reward")
350
+ ax1.plot(iters, (mean_r - std_r) * 100, "--", color=PALETTE["slate"], linewidth=1,
351
+ alpha=0.6, label="Β±1 std")
352
+ ax1.plot(iters, (mean_r + std_r) * 100, "--", color=PALETTE["slate"], linewidth=1,
353
+ alpha=0.6)
354
+
355
+ # highlight the two tight-cluster peaks
356
+ for special_iter, label in [(11, "iter 11\nstd=0.098"), (22, "iter 22\nstd=0.124")]:
357
+ si = iters.index(special_iter)
358
+ ax1.annotate(label,
359
+ xy=(special_iter, (mean_r[si] + std_r[si]) * 100),
360
+ xytext=(special_iter + 1, (mean_r[si] + std_r[si]) * 100 + 2),
361
+ fontsize=8, color=PALETTE["amber"],
362
+ arrowprops=dict(arrowstyle="->", color=PALETTE["amber"], lw=1.2))
363
+
364
+ ax1.set_ylabel("Reward (%)")
365
+ ax1.set_ylim(55, 115)
366
+ ax1.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter("%.0f%%"))
367
+ h1, l1 = ax1.get_legend_handles_labels()
368
+ ax1.legend(handles=h1 + phase_legend_patches(phases), framealpha=0.8, ncol=3)
369
+
370
+ # skip-rate bar chart (bottom panel)
371
+ shade_phases(ax2, iters, phases)
372
+ ax2.bar(iters, skip_rt, color=PALETTE["red"], alpha=0.7, width=0.7, label="Skipped groups %")
373
+ ax2.set_ylabel("Skipped\ngroups (%)")
374
+ ax2.set_xlabel("Training Iteration")
375
+ ax2.set_ylim(0, 75)
376
+ ax2.set_xticks(range(1, max(iters) + 1, 2))
377
+ ax2.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter("%.0f%%"))
378
+ ax2.legend(loc="upper right", framealpha=0.8)
379
+
380
+ fig.tight_layout()
381
+ save(fig, "plot4_reward_confidence.png", out)
382
+
383
+
384
+ # ══════════════════════════════════════════════════════════════════════════════
385
+ # PLOT 5 β€” Step-Level Reasoning Quality: train vs eval
386
+ # Breaks down the two signals that measure HOW the model thinks (not just
387
+ # whether it gets the final answer right): step accuracy and LCCP.
388
+ # Train lines are noisy; eval lines show clean upward trends.
389
+ # ══════════════════════════════════════════════════════════════════════════════
390
+
391
+ def plot_reasoning_quality(rows: List[Dict], out: Path):
392
+ iters = [int(r["iteration"]) for r in rows]
393
+ phases = [r["training_phase"] for r in rows]
394
+
395
+ # training
396
+ t_step = [f(r, "step_accuracy") * 100 for r in rows]
397
+ t_lccp = [f(r, "lccp") * 100 for r in rows]
398
+ t_gt = [f(r, "gt_match_rate") * 100 for r in rows]
399
+
400
+ # eval (only at checkpoint iters)
401
+ eval_rows = [r for r in rows if r.get("eval_combined", "") != ""]
402
+ e_iters = [int(r["iteration"]) for r in eval_rows]
403
+ e_step = [f(r, "eval_step_acc") * 100 for r in eval_rows]
404
+ e_lccp = [f(r, "eval_lccp") * 100 for r in eval_rows]
405
+
406
+ # moving averages
407
+ ma_step = moving_avg(t_step, w=4)
408
+ ma_lccp = moving_avg(t_lccp, w=4)
409
+ ma_gt = moving_avg(t_gt, w=4)
410
+
411
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 5.5))
412
+ fig.suptitle("Step-Level Reasoning Quality β€” Training vs Held-Out Evaluation",
413
+ fontsize=14, fontweight="bold", color=PALETTE["white"], y=1.01)
414
+
415
+ # ── LEFT: step accuracy ──
416
+ shade_phases(ax1, iters, phases)
417
+ ax1.plot(iters, t_step, alpha=0.2, color=PALETTE["cyan"], linewidth=1)
418
+ ax1.plot(iters, ma_step, color=PALETTE["cyan"], linewidth=2.5, label="Train step acc (smooth)")
419
+ ax1.plot(iters, t_gt, alpha=0.15, color=PALETTE["pink"], linewidth=1)
420
+ ax1.plot(iters, ma_gt, color=PALETTE["pink"], linewidth=2.5, label="Train GT match (smooth)")
421
+ ax1.plot(e_iters, e_step, "o-", color=PALETTE["white"], ms=8, linewidth=2,
422
+ label="Eval step accuracy", zorder=6)
423
+
424
+ # annotate eval start/end
425
+ ax1.annotate(f"{e_step[0]:.1f}%", xy=(e_iters[0], e_step[0]),
426
+ xytext=(e_iters[0] - 0.3, e_step[0] - 1.2), fontsize=8.5,
427
+ color=PALETTE["white"], ha="right")
428
+ ax1.annotate(f"{e_step[-1]:.1f}%", xy=(e_iters[-1], e_step[-1]),
429
+ xytext=(e_iters[-1] + 0.3, e_step[-1] + 0.5), fontsize=8.5,
430
+ color=PALETTE["white"])
431
+ ax1.annotate("", xy=(e_iters[-1], e_step[-1]),
432
+ xytext=(e_iters[0], e_step[0]),
433
+ arrowprops=dict(arrowstyle="->", color=PALETTE["cyan"], lw=1.5,
434
+ connectionstyle="arc3,rad=-0.3"))
435
+
436
+ ax1.set_title("Step Accuracy β€” Did each reasoning step hold up?",
437
+ fontsize=9.5, color=PALETTE["slate"], pad=5)
438
+ ax1.set_xlabel("Training Iteration")
439
+ ax1.set_ylabel("Score (%)")
440
+ ax1.set_ylim(55, 105)
441
+ ax1.set_xticks(range(1, max(iters) + 1, 3))
442
+ ax1.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter("%.0f%%"))
443
+ ax1.legend(handles=ax1.get_legend_handles_labels()[0] + phase_legend_patches(phases),
444
+ framealpha=0.8, ncol=1, loc="lower right")
445
+
446
+ # ── RIGHT: LCCP ──
447
+ shade_phases(ax2, iters, phases)
448
+ ax2.plot(iters, t_lccp, alpha=0.2, color=PALETTE["emerald"], linewidth=1)
449
+ ax2.plot(iters, ma_lccp, color=PALETTE["emerald"], linewidth=2.5, label="Train LCCP (smooth)")
450
+ ax2.plot(e_iters, e_lccp, "o-", color=PALETTE["white"], ms=8, linewidth=2,
451
+ label="Eval LCCP", zorder=6)
452
+
453
+ ax2.annotate(f"{e_lccp[0]:.1f}%", xy=(e_iters[0], e_lccp[0]),
454
+ xytext=(e_iters[0] - 0.3, e_lccp[0] - 1.5), fontsize=8.5,
455
+ color=PALETTE["white"], ha="right")
456
+ ax2.annotate(f"{e_lccp[-1]:.1f}%", xy=(e_iters[-1], e_lccp[-1]),
457
+ xytext=(e_iters[-1] + 0.3, e_lccp[-1] + 0.5), fontsize=8.5,
458
+ color=PALETTE["white"])
459
+
460
+ # show LCCP delta
461
+ delta = e_lccp[-1] - e_lccp[0]
462
+ ax2.text(0.97, 0.06,
463
+ f"Eval LCCP Ξ” = +{delta:.2f}pp\n(iter {e_iters[0]} β†’ {e_iters[-1]})",
464
+ transform=ax2.transAxes, ha="right", va="bottom",
465
+ fontsize=8.5, color=PALETTE["emerald"],
466
+ bbox=dict(facecolor=PALETTE["bg2"], edgecolor=PALETTE["emerald"],
467
+ linewidth=0.8, pad=5))
468
+
469
+ ax2.set_title("LCCP β€” Did the chain of reasoning stay correct until the first error?",
470
+ fontsize=9.5, color=PALETTE["slate"], pad=5)
471
+ ax2.set_xlabel("Training Iteration")
472
+ ax2.set_ylabel("LCCP (%)")
473
+ ax2.set_ylim(55, 100)
474
+ ax2.set_xticks(range(1, max(iters) + 1, 3))
475
+ ax2.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter("%.0f%%"))
476
+ ax2.legend(handles=ax2.get_legend_handles_labels()[0] + phase_legend_patches(phases),
477
+ framealpha=0.8, ncol=1, loc="lower right")
478
+
479
+ fig.tight_layout()
480
+ save(fig, "plot5_reasoning_quality.png", out)
481
+
482
+
483
+ # ══════════════════════════════════════════════════════════════════════════════
484
+ # Main
485
+ # ══════════════════════════════════════════════════════════════════════════════
486
+
487
+ def parse_args():
488
+ p = argparse.ArgumentParser(description="Generate AxiomForgeAI training plots")
489
+ p.add_argument("--metrics", default=DEFAULT_METRICS,
490
+ help=f"Path to metrics.csv (default: {DEFAULT_METRICS})")
491
+ p.add_argument("--out", default="images",
492
+ help="Output directory for PNGs (default: images/)")
493
+ return p.parse_args()
494
+
495
+
496
+ def main():
497
+ args = parse_args()
498
+ out = Path(args.out)
499
+
500
+ print(f"Loading metrics from : {args.metrics}")
501
+ print(f"Saving plots to : {out}/")
502
+ print()
503
+
504
+ rows = load_csv(args.metrics)
505
+ print(f"Loaded {len(rows)} iterations.\n")
506
+
507
+ print("Generating plots …")
508
+ plot_eval_quality(rows, out)
509
+ plot_training_journey(rows, out)
510
+ plot_selfplay_success(rows, out)
511
+ plot_reward_confidence(rows, out)
512
+ plot_reasoning_quality(rows, out)
513
+
514
+ print(f"\nβœ… All 5 plots saved to {out}/")
515
+ print("\nFiles:")
516
+ for p in sorted(out.glob("plot*.png")):
517
+ print(f" {p} ({p.stat().st_size // 1024} KB)")
518
+
519
+
520
+ if __name__ == "__main__":
521
+ main()
scripts/precompute_extraction_cache.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Offline step-chain extraction cache builder.
3
+
4
+ Run this once before training to pre-extract structured step chains from all
5
+ grounded training data (GSM8K + MATH). The resulting cache file is passed to
6
+ run_grpo_training.py via --extraction-cache so the extractor LLM is never
7
+ called for fixed training examples β€” only novel self-play solutions require
8
+ live extraction during training.
9
+
10
+ Usage
11
+ -----
12
+ python scripts/precompute_extraction_cache.py \\
13
+ --gsm8k-data data/sft/gsm8k_sft.jsonl \\
14
+ --math-data data/sft/math_sft.jsonl \\
15
+ --output-cache data/extraction_cache.json \\
16
+ --extractor-model Qwen/Qwen2.5-0.5B-Instruct \\
17
+ --device cuda
18
+
19
+ Cache key: md5(question + "\\n" + solution) β€” keying on both prevents
20
+ collisions when two MATH problems share identical solution text.
21
+ Entries for solutions the extractor cannot parse are stored with
22
+ success=False so training never re-attempts and correctly penalises them.
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ import argparse
28
+ import json
29
+ import logging
30
+ import pathlib
31
+ import sys
32
+ from typing import List, Tuple
33
+
34
+ logging.basicConfig(
35
+ level=logging.INFO,
36
+ format="%(asctime)s %(levelname)-8s %(message)s",
37
+ handlers=[logging.StreamHandler(sys.stdout)],
38
+ )
39
+ logger = logging.getLogger(__name__)
40
+
41
+
42
+ def load_jsonl(path: str) -> list[dict]:
43
+ records: list[dict] = []
44
+ with open(path, encoding="utf-8") as f:
45
+ for line in f:
46
+ line = line.strip()
47
+ if line:
48
+ try:
49
+ records.append(json.loads(line))
50
+ except json.JSONDecodeError:
51
+ pass
52
+ return records
53
+
54
+
55
+ def collect_qa_pairs(records: list[dict]) -> List[Tuple[str, str]]:
56
+ """
57
+ Extract (question, solution) pairs from dataset records.
58
+
59
+ Returns pairs where both fields are non-empty. Falls back to empty
60
+ string for the question when only the solution field is present.
61
+ """
62
+ pairs: List[Tuple[str, str]] = []
63
+ for rec in records:
64
+ sol = (
65
+ rec.get("solution")
66
+ or rec.get("output")
67
+ or rec.get("response")
68
+ or ""
69
+ )
70
+ q = (
71
+ rec.get("question")
72
+ or rec.get("problem")
73
+ or rec.get("input")
74
+ or ""
75
+ )
76
+ if sol.strip():
77
+ pairs.append((q.strip(), sol.strip()))
78
+ return pairs
79
+
80
+
81
+ def main() -> None:
82
+ parser = argparse.ArgumentParser(
83
+ description="Pre-extract step chains for grounded training data."
84
+ )
85
+ parser.add_argument(
86
+ "--gsm8k-data", required=True,
87
+ help="Path to GSM8K training JSONL (e.g. data/sft/gsm8k_sft.jsonl).",
88
+ )
89
+ parser.add_argument(
90
+ "--math-data", default=None,
91
+ help="Optional path to MATH training JSONL. If provided, those solutions "
92
+ "are also extracted and added to the cache.",
93
+ )
94
+ parser.add_argument(
95
+ "--output-cache", required=True,
96
+ help="Destination JSON file for the extraction cache.",
97
+ )
98
+ parser.add_argument(
99
+ "--extractor-model", default="Qwen/Qwen2.5-0.5B-Instruct",
100
+ help="HuggingFace model ID for the step chain extractor. Default Qwen/Qwen2.5-0.5B-Instruct.",
101
+ )
102
+ parser.add_argument(
103
+ "--device", default="cuda",
104
+ help="Device for the extractor model (default: cuda).",
105
+ )
106
+ parser.add_argument(
107
+ "--batch-size", type=int, default=1,
108
+ help="Reserved for future batched extraction. Currently always 1.",
109
+ )
110
+ args = parser.parse_args()
111
+
112
+ # ── Load data ─────────────────────────────────────────────────────────────
113
+ logger.info("Loading GSM8K data from: %s", args.gsm8k_data)
114
+ gsm8k_records = load_jsonl(args.gsm8k_data)
115
+ qa_pairs = collect_qa_pairs(gsm8k_records)
116
+ logger.info("GSM8K: %d (question, solution) pairs", len(qa_pairs))
117
+
118
+ if args.math_data:
119
+ logger.info("Loading MATH data from: %s", args.math_data)
120
+ math_records = load_jsonl(args.math_data)
121
+ math_pairs = collect_qa_pairs(math_records)
122
+ logger.info("MATH: %d (question, solution) pairs", len(math_pairs))
123
+ qa_pairs += math_pairs
124
+
125
+ if not qa_pairs:
126
+ logger.error(
127
+ "No solutions found in provided files. "
128
+ "Check field names (question/problem/input + solution/output/response)."
129
+ )
130
+ sys.exit(1)
131
+
132
+ # Deduplicate by (question, solution) content
133
+ # Two different MATH problems can have identical solution text but different
134
+ # questions β€” the question+solution key keeps them distinct in the cache.
135
+ seen: set = set()
136
+ unique_pairs: List[Tuple[str, str]] = []
137
+ for q, sol in qa_pairs:
138
+ key = (q, sol)
139
+ if key not in seen:
140
+ seen.add(key)
141
+ unique_pairs.append((q, sol))
142
+
143
+ logger.info(
144
+ "Total: %d pairs (%d unique after dedup)", len(qa_pairs), len(unique_pairs)
145
+ )
146
+
147
+ # ── Load extractor ────────────────────────────────────────────────────────
148
+ sys.path.insert(0, str(pathlib.Path(__file__).parent.parent))
149
+ from src.rl.unified_accuracy import StepChainExtractor
150
+
151
+ extractor = StepChainExtractor(
152
+ model_name=args.extractor_model,
153
+ device=args.device,
154
+ cache_path=args.output_cache, # load existing cache if present (resume)
155
+ )
156
+
157
+ # ── Build cache ───────────────────────────────────────────────────────────
158
+ already_cached = len(extractor._cache)
159
+ if already_cached:
160
+ logger.info("Resuming: %d entries already in cache", already_cached)
161
+
162
+ extractor.build_cache(unique_pairs)
163
+
164
+ # ── Save ──────────────────────────────────────────────────────────────────
165
+ extractor.save_cache()
166
+ logger.info(
167
+ "Done. Cache contains %d entries β†’ %s",
168
+ len(extractor._cache),
169
+ args.output_cache,
170
+ )
171
+
172
+
173
+ if __name__ == "__main__":
174
+ main()
scripts/prepare_aqua_dataset.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Download Chinar/AQuA-RAT from HuggingFace and convert it to the same JSONL
4
+ format used by gsm8k_sft.jsonl so the GRPO training script can consume it
5
+ directly via --gsm8k-data.
6
+
7
+ Chinar/AQuA-RAT schema (processed version)
8
+ -------------------------------------------
9
+ prompt : str β€” the math question
10
+ completion : str β€” step-by-step reasoning ending with:
11
+ "The answer is X . Therefore, the correct answer is: <value>"
12
+
13
+ Output schema (messages format expected by load_gsm8k)
14
+ -------------------------------------------------------
15
+ {
16
+ "id": "aqua_<idx>",
17
+ "skill_id": "aqua_rat_algebra",
18
+ "source": "Chinar/AQuA-RAT",
19
+ "split": "train" | "validation",
20
+ "messages": [
21
+ {"role": "system", "content": SOLVER_SYSTEM_PROMPT},
22
+ {"role": "user", "content": "Solve ... Problem:\\n<question>"},
23
+ {"role": "assistant", "content": "Step 1: ...\\nFinal Answer: <value>"}
24
+ ]
25
+ }
26
+
27
+ The dataset has only a 'train' split β€” we reserve the last 500 rows as
28
+ a validation set and use the rest for training.
29
+
30
+ Usage
31
+ -----
32
+ python scripts/prepare_aqua_dataset.py
33
+ python scripts/prepare_aqua_dataset.py --val-size 300 --dry-run
34
+ """
35
+
36
+ from __future__ import annotations
37
+
38
+ import argparse
39
+ import json
40
+ import re
41
+ import sys
42
+ from pathlib import Path
43
+ from typing import Any, Optional
44
+
45
+ # ---------------------------------------------------------------------------
46
+ # Prompt constants (kept in sync with src/config/prompts.py)
47
+ # ---------------------------------------------------------------------------
48
+
49
+ SOLVER_SYSTEM_PROMPT = (
50
+ "You are a step-by-step math solver. "
51
+ "Solve the given problem one step at a time. "
52
+ "Each step must be on its own line, starting with 'Step N:'. "
53
+ "End with a line starting with 'Final Answer:'. "
54
+ "Write every mathematical expression in Python/SymPy syntax "
55
+ "so it can be verified programmatically."
56
+ )
57
+
58
+ USER_WRAPPER = (
59
+ "Solve the following problem. Show your reasoning as numbered steps, "
60
+ "then give the final numeric answer on the last line.\n\nProblem:\n{question}"
61
+ )
62
+
63
+ # ---------------------------------------------------------------------------
64
+ # Answer extraction
65
+ # ---------------------------------------------------------------------------
66
+
67
+ # The completion always ends with a variant of:
68
+ # "The answer is E . Therefore, the correct answer is: 23"
69
+ _ANSWER_TAIL = re.compile(
70
+ r"(?:The answer is\s+[A-Ea-e]\s*[.\-]?\s*)?"
71
+ r"Therefore,?\s+the correct answer is\s*:?\s*(.+)$",
72
+ re.IGNORECASE,
73
+ )
74
+
75
+
76
+ def _extract_answer_and_rationale(completion: str) -> Optional[tuple[str, str]]:
77
+ """
78
+ Split the completion into (rationale_lines, final_answer_str).
79
+ Returns None if no extractable numeric answer is found.
80
+ """
81
+ # Find the tail marker
82
+ m = _ANSWER_TAIL.search(completion)
83
+ if not m:
84
+ return None
85
+
86
+ raw_answer = m.group(1).strip()
87
+ # Everything before the tail is the rationale
88
+ rationale = completion[: m.start()].strip()
89
+ # Also strip a standalone "The answer is X ." line at the end of rationale
90
+ rationale = re.sub(r"\s*The answer is\s+[A-Ea-e]\s*[.\-]?\s*$", "", rationale, flags=re.IGNORECASE).strip()
91
+
92
+ # Normalise the answer to a clean numeric string
93
+ final_answer = _normalise_answer(raw_answer)
94
+ if final_answer is None:
95
+ return None
96
+
97
+ return rationale, final_answer
98
+
99
+
100
+ def _normalise_answer(raw: str) -> Optional[str]:
101
+ """
102
+ Extract a single numeric value from an answer string.
103
+
104
+ "23" β†’ "23"
105
+ "$ 1600" β†’ "1600"
106
+ "8 seconds" β†’ "8"
107
+ "5 and 1" β†’ None (multi-value β€” skip)
108
+ "I and II" β†’ None (non-numeric β€” skip)
109
+ "βˆ’ 3 ≀ x ≀ 4" β†’ None (inequality β€” skip)
110
+ """
111
+ text = raw.strip()
112
+
113
+ # Remove currency / whitespace
114
+ text = text.replace("$", "").replace("Rs.", "").replace("Rs", "").replace(",", "").strip()
115
+
116
+ # Handle unicode minus
117
+ text = text.replace("\u2212", "-").replace("βˆ’", "-")
118
+
119
+ # Skip if "and" still present (multi-value like "5 and 1")
120
+ if re.search(r"\band\b", text, re.IGNORECASE):
121
+ return None
122
+
123
+ # Skip inequalities / expressions with variables
124
+ if re.search(r"[a-zA-Z≀β‰₯<>]", text):
125
+ return None
126
+
127
+ # Single number (integer or decimal, optionally negative)
128
+ m = re.fullmatch(r"\s*(-?\d+(?:\.\d+)?)\s*(?:[a-zA-Z%Β°].*)?", text)
129
+ if m:
130
+ val_str = m.group(1)
131
+ try:
132
+ val = float(val_str)
133
+ return str(int(val)) if val == int(val) else val_str
134
+ except ValueError:
135
+ pass
136
+
137
+ return None
138
+
139
+
140
+ # ---------------------------------------------------------------------------
141
+ # Rationale β†’ Step N: format
142
+ # ---------------------------------------------------------------------------
143
+
144
+ def _rationale_to_steps(rationale: str) -> list[str]:
145
+ lines: list[str] = []
146
+ for raw in rationale.splitlines():
147
+ line = raw.strip()
148
+ if line:
149
+ line = line.replace("^", "**")
150
+ lines.append(line)
151
+ if not lines and rationale.strip():
152
+ sentences = re.split(r"(?<=[.!?])\s+", rationale.strip())
153
+ lines = [s.strip() for s in sentences if s.strip()]
154
+ return lines
155
+
156
+
157
+ def _build_assistant(rationale: str, final_answer: str) -> str:
158
+ steps = _rationale_to_steps(rationale)
159
+ parts = [f"Step {i}: {line}" for i, line in enumerate(steps, 1)]
160
+ body = "\n".join(parts)
161
+ return f"{body}\nFinal Answer: {final_answer}" if body else f"Final Answer: {final_answer}"
162
+
163
+
164
+ # ---------------------------------------------------------------------------
165
+ # Row conversion
166
+ # ---------------------------------------------------------------------------
167
+
168
+ def convert_row(row: dict[str, Any], idx: int, split: str) -> Optional[dict[str, Any]]:
169
+ question = (row.get("prompt") or "").strip()
170
+ completion = (row.get("completion") or "").strip()
171
+
172
+ if not question or not completion:
173
+ return None
174
+
175
+ result = _extract_answer_and_rationale(completion)
176
+ if result is None:
177
+ return None
178
+
179
+ rationale, final_answer = result
180
+ assistant_text = _build_assistant(rationale, final_answer)
181
+
182
+ return {
183
+ "id": f"aqua_{split}_{idx}",
184
+ "skill_id": "aqua_rat_algebra",
185
+ "source": "Chinar/AQuA-RAT",
186
+ "split": split,
187
+ "messages": [
188
+ {"role": "system", "content": SOLVER_SYSTEM_PROMPT},
189
+ {"role": "user", "content": USER_WRAPPER.format(question=question)},
190
+ {"role": "assistant", "content": assistant_text},
191
+ ],
192
+ }
193
+
194
+
195
+ # ---------------------------------------------------------------------------
196
+ # Main
197
+ # ---------------------------------------------------------------------------
198
+
199
+ def main() -> None:
200
+ parser = argparse.ArgumentParser()
201
+ parser.add_argument("--output-dir", default="data/sft")
202
+ parser.add_argument("--val-size", type=int, default=500,
203
+ help="How many rows from the end of the dataset to use as validation.")
204
+ parser.add_argument("--dry-run", action="store_true")
205
+ parser.add_argument("--max-samples", type=int, default=None)
206
+ args = parser.parse_args()
207
+
208
+ try:
209
+ from datasets import load_dataset
210
+ except ImportError:
211
+ print("ERROR: pip install datasets", file=sys.stderr)
212
+ sys.exit(1)
213
+
214
+ print("Downloading Chinar/AQuA-RAT …")
215
+ ds = load_dataset("Chinar/AQuA-RAT")
216
+ all_rows = list(ds["train"])
217
+ total = len(all_rows)
218
+ print(f" Total rows: {total:,}")
219
+
220
+ val_rows = all_rows[-args.val_size:]
221
+ train_rows = all_rows[: -args.val_size]
222
+
223
+ splits = {
224
+ "train": train_rows,
225
+ "validation": val_rows,
226
+ }
227
+
228
+ out_dir = Path(args.output_dir)
229
+ out_dir.mkdir(parents=True, exist_ok=True)
230
+
231
+ for split, rows in splits.items():
232
+ if args.max_samples:
233
+ rows = rows[: args.max_samples]
234
+
235
+ records: list[dict] = []
236
+ skipped = 0
237
+ for idx, row in enumerate(rows):
238
+ rec = convert_row(row, idx, split)
239
+ if rec is None:
240
+ skipped += 1
241
+ else:
242
+ records.append(rec)
243
+
244
+ skip_pct = 100.0 * skipped / max(1, len(rows))
245
+
246
+ if args.dry_run:
247
+ print(f"\n── {split}: {len(records)} valid / {skipped} skipped ({skip_pct:.1f}%) ──")
248
+ for rec in records[:3]:
249
+ print(json.dumps(rec, indent=2))
250
+ continue
251
+
252
+ out_path = out_dir / f"aqua_{split}.jsonl"
253
+ with out_path.open("w", encoding="utf-8") as f:
254
+ for rec in records:
255
+ f.write(json.dumps(rec, ensure_ascii=False) + "\n")
256
+
257
+ print(f" [{split:12s}] {len(records):6,d} valid {skipped:5,d} skipped ({skip_pct:.1f}%) β†’ {out_path}")
258
+
259
+ if not args.dry_run:
260
+ print("\nDone. Launch continuation training with:")
261
+ print(" bash launch_grpo_aqua.sh")
262
+
263
+
264
+ if __name__ == "__main__":
265
+ main()
scripts/prepare_combined_dataset.py ADDED
@@ -0,0 +1,711 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Combined dataset pipeline β€” NuminaMath-CoT + OpenMathInstruct-2
4
+ ================================================================
5
+ Downloads, filters, normalises, and merges two large math datasets into a single
6
+ JSONL file (train / val / test) that the GRPO training script can consume directly
7
+ via --gsm8k-data.
8
+
9
+ Why these two datasets
10
+ ----------------------
11
+ NuminaMath-CoT (AI-MO/NuminaMath-CoT)
12
+ 860 K problems. Clean \\boxed{} answers. 7 rich topic categories that map
13
+ directly to ZPD skill_ids. Sources span AMC, AIME, Chinese HS, olympiads,
14
+ and synthetic β€” giving natural difficulty diversity.
15
+
16
+ OpenMathInstruct-2 (nvidia/OpenMathInstruct-2)
17
+ 14 M synthetic problems with step-level CoT. `expected_answer` is pre-verified.
18
+ Diverse surface forms prevent pattern memorisation. We skip any row whose
19
+ problem_source is "gsm8k" (already in prior training).
20
+
21
+ Output schema (identical to gsm8k_sft.jsonl / aqua_train.jsonl)
22
+ ---------------------------------------------------------------
23
+ {
24
+ "id": "<source>_<split>_<idx>",
25
+ "skill_id": "<topic_slug>", ← used by ZPD CurriculumManager
26
+ "source": "<hf_dataset_name>",
27
+ "split": "train" | "val" | "test",
28
+ "difficulty": 1 | 2 | 3, ← 1=easy 2=medium 3=hard (for ZPD)
29
+ "task_type": "solve",
30
+ "messages": [
31
+ {"role": "system", "content": SOLVER_SYSTEM_PROMPT},
32
+ {"role": "user", "content": "Solve ... Problem:\\n<question>"},
33
+ {"role": "assistant", "content": "Step 1: ...\\nFinal Answer: <answer>"}
34
+ ]
35
+ }
36
+
37
+ Usage
38
+ -----
39
+ # Quick test (no download, just show stats)
40
+ python scripts/prepare_combined_dataset.py --dry-run
41
+
42
+ # Full pipeline (default caps: 20 K numina + 15 K openmath)
43
+ python scripts/prepare_combined_dataset.py
44
+
45
+ # Larger run
46
+ python scripts/prepare_combined_dataset.py --max-numina 40000 --max-openmath 30000
47
+
48
+ # Only one source
49
+ python scripts/prepare_combined_dataset.py --skip-openmath
50
+ python scripts/prepare_combined_dataset.py --skip-numina
51
+
52
+ # Custom output dir
53
+ python scripts/prepare_combined_dataset.py --output-dir data/sft/combined
54
+ """
55
+
56
+ from __future__ import annotations
57
+
58
+ import argparse
59
+ import hashlib
60
+ import json
61
+ import logging
62
+ import math
63
+ import random
64
+ import re
65
+ import sys
66
+ from collections import Counter, defaultdict
67
+ from pathlib import Path
68
+ from typing import Any, Dict, Iterator, List, Optional, Tuple
69
+
70
+ logging.basicConfig(
71
+ level=logging.INFO,
72
+ format="%(asctime)s %(levelname)-8s %(message)s",
73
+ datefmt="%H:%M:%S",
74
+ )
75
+ log = logging.getLogger(__name__)
76
+
77
+ # ---------------------------------------------------------------------------
78
+ # Constants β€” kept in sync with src/config/prompts.py
79
+ # ---------------------------------------------------------------------------
80
+
81
+ SOLVER_SYSTEM_PROMPT = (
82
+ "You are a step-by-step math solver. "
83
+ "Solve the given problem one step at a time. "
84
+ "Each step must be on its own line, starting with 'Step N:'. "
85
+ "End with a line starting with 'Final Answer:'. "
86
+ "Write every mathematical expression in Python/SymPy syntax "
87
+ "so it can be verified programmatically."
88
+ )
89
+
90
+ USER_WRAPPER = (
91
+ "Solve the following problem. Show your reasoning as numbered steps, "
92
+ "then give the final numeric answer on the last line.\n\nProblem:\n{question}"
93
+ )
94
+
95
+ # ---------------------------------------------------------------------------
96
+ # Skill-ID mappings (drives ZPD CurriculumManager per-topic mastery)
97
+ # ---------------------------------------------------------------------------
98
+
99
+ # NuminaMath-CoT `type` field β†’ skill_id
100
+ NUMINA_TYPE_TO_SKILL: Dict[str, str] = {
101
+ "algebra": "numina_algebra",
102
+ "intermediate_algebra": "numina_algebra",
103
+ "prealgebra": "numina_prealgebra",
104
+ "number_theory": "numina_number_theory",
105
+ "geometry": "numina_geometry",
106
+ "counting_and_probability": "numina_combinatorics",
107
+ "precalculus": "numina_calculus",
108
+ "calculus": "numina_calculus",
109
+ "statistics": "numina_statistics",
110
+ "probability": "numina_statistics",
111
+ # competition-source buckets (fallback when type not in map above)
112
+ "cn_k12": "numina_algebra",
113
+ "olympiads": "numina_olympiad",
114
+ "amc_aime": "numina_competition",
115
+ "synthetic_math": "numina_synthetic",
116
+ }
117
+
118
+ # NuminaMath source β†’ approximate difficulty (1=easy 2=medium 3=hard)
119
+ NUMINA_SOURCE_DIFFICULTY: Dict[str, int] = {
120
+ "cn_k12": 1,
121
+ "synthetic_math": 2,
122
+ "amc_aime": 2,
123
+ "olympiads": 3,
124
+ }
125
+
126
+ # OpenMathInstruct-2 problem_source β†’ skill_id / difficulty
127
+ OPENMATH_SOURCE_TO_SKILL: Dict[str, str] = {
128
+ "math": "openmath_algebra", # overridden per-row by subject
129
+ "amc_aime_1983_2024": "openmath_competition",
130
+ "synthetic_math": "openmath_synthetic",
131
+ "number_theory": "openmath_number_theory",
132
+ }
133
+
134
+ OPENMATH_SOURCE_DIFFICULTY: Dict[str, int] = {
135
+ "math": 2,
136
+ "amc_aime_1983_2024": 3,
137
+ "synthetic_math": 1,
138
+ }
139
+
140
+ # OpenMathInstruct MATH-subject β†’ skill_id (when problem_source == "math")
141
+ OPENMATH_MATH_SUBJECT_SKILL: Dict[str, str] = {
142
+ "Algebra": "openmath_algebra",
143
+ "Number Theory": "openmath_number_theory",
144
+ "Geometry": "openmath_geometry",
145
+ "Counting & Probability": "openmath_combinatorics",
146
+ "Intermediate Algebra": "openmath_algebra",
147
+ "Prealgebra": "openmath_prealgebra",
148
+ "Precalculus": "openmath_calculus",
149
+ "Calculus": "openmath_calculus",
150
+ }
151
+
152
+ # ---------------------------------------------------------------------------
153
+ # Answer normalisation
154
+ # ---------------------------------------------------------------------------
155
+
156
+ _BOXED_RE = re.compile(r"\\boxed\{((?:[^{}]|\{[^{}]*\})*)\}")
157
+ _LATEX_FRAC = re.compile(r"\\frac\{(\d+)\}\{(\d+)\}")
158
+ _PLAIN_FRAC = re.compile(r"^(-?\d+)\s*/\s*(\d+)$")
159
+ _CURRENCY = re.compile(r"(?:Rs\.?|USD|\$|€|Β£)\s*", re.IGNORECASE)
160
+ _UNICODE_MINUS = str.maketrans({"\u2212": "-", "βˆ’": "-"})
161
+
162
+
163
+ def extract_boxed(text: str) -> Optional[str]:
164
+ """Return the last \\boxed{} contents from a solution string."""
165
+ matches = _BOXED_RE.findall(text)
166
+ return matches[-1].strip() if matches else None
167
+
168
+
169
+ def normalise_numeric(raw: str) -> Optional[str]:
170
+ """
171
+ Convert a raw answer string to a clean numeric string.
172
+
173
+ Returns None for:
174
+ - multi-value answers ("3 and 5")
175
+ - symbolic expressions ("3\\sqrt{2}", "x+1")
176
+ - inequalities
177
+ - fractions where num/den exceed safe range
178
+ """
179
+ text = raw.strip()
180
+
181
+ # Remove currency symbols and commas in numbers
182
+ text = _CURRENCY.sub("", text)
183
+ text = text.replace(",", "").translate(_UNICODE_MINUS).strip()
184
+
185
+ # Skip if still contains words other than units
186
+ if re.search(r"\b(and|or|none|no solution|undefined)\b", text, re.IGNORECASE):
187
+ return None
188
+
189
+ # Skip if contains letters (symbolic)
190
+ if re.search(r"[a-zA-Z]", text):
191
+ return None
192
+
193
+ # Skip inequalities / ranges
194
+ if re.search(r"[≀β‰₯<>]", text):
195
+ return None
196
+
197
+ # Handle LaTeX fractions: \frac{3}{4}
198
+ m = _LATEX_FRAC.fullmatch(text)
199
+ if m:
200
+ num, den = int(m.group(1)), int(m.group(2))
201
+ if den:
202
+ v = num / den
203
+ return str(int(v)) if v == int(v) else f"{v:.4f}"
204
+ return None
205
+
206
+ # Handle plain fractions: 3/4
207
+ m = _PLAIN_FRAC.match(text)
208
+ if m:
209
+ num, den = int(m.group(1)), int(m.group(2))
210
+ if den:
211
+ v = num / den
212
+ return str(int(v)) if v == int(v) else f"{v:.4f}"
213
+ return None
214
+
215
+ # Handle percentage β†’ decimal
216
+ pct = re.fullmatch(r"(-?\d+(?:\.\d+)?)\s*%", text)
217
+ if pct:
218
+ v = float(pct.group(1))
219
+ return str(int(v)) if v == int(v) else f"{v:.4f}"
220
+
221
+ # Plain integer or decimal (possibly negative, possibly with trailing unit like "km")
222
+ m = re.match(r"^\s*(-?\d+(?:\.\d+)?)\s*(?:[^0-9.\s].*)?\s*$", text)
223
+ if m:
224
+ val_str = m.group(1)
225
+ try:
226
+ v = float(val_str)
227
+ return str(int(v)) if v == int(v) else val_str
228
+ except ValueError:
229
+ pass
230
+
231
+ return None
232
+
233
+
234
+ # ---------------------------------------------------------------------------
235
+ # Solution β†’ Step N: format
236
+ # ---------------------------------------------------------------------------
237
+
238
+ _SKIP_LINE_RE = re.compile(
239
+ r"^\s*("
240
+ r"\\boxed\{|"
241
+ r"(Therefore|Thus|Hence|So),?\s+(the\s+)?(final\s+)?answer\s+is|"
242
+ r"The\s+(final\s+)?answer\s+is|"
243
+ r"Answer\s*[:=]"
244
+ r")",
245
+ re.IGNORECASE,
246
+ )
247
+
248
+
249
+ def solution_to_steps(solution: str, final_answer: str, max_steps: int = 18) -> str:
250
+ """
251
+ Convert an arbitrary CoT solution to the pipeline's Step N: format.
252
+
253
+ Strategy:
254
+ 1. Split on newlines.
255
+ 2. Drop blank lines and lines that just announce the final answer
256
+ (those are replaced by the explicit Final Answer: line).
257
+ 3. Strip any existing "Step N:" prefix to avoid double-numbering.
258
+ 4. Re-number as "Step 1:", "Step 2:", …
259
+ 5. Append "Final Answer: <answer>".
260
+ """
261
+ raw_lines = [l.strip() for l in solution.split("\n") if l.strip()]
262
+ clean: List[str] = []
263
+ for line in raw_lines:
264
+ if _SKIP_LINE_RE.match(line):
265
+ continue
266
+ # Strip old step prefix
267
+ line = re.sub(r"^Step\s*\d+\s*[:.)]\s*", "", line)
268
+ if line:
269
+ clean.append(line)
270
+
271
+ # Cap to max_steps to keep token count reasonable
272
+ clean = clean[:max_steps]
273
+
274
+ if not clean:
275
+ return f"Final Answer: {final_answer}"
276
+
277
+ parts = [f"Step {i}: {line}" for i, line in enumerate(clean, 1)]
278
+ return "\n".join(parts) + f"\nFinal Answer: {final_answer}"
279
+
280
+
281
+ # ---------------------------------------------------------------------------
282
+ # Record builders
283
+ # ---------------------------------------------------------------------------
284
+
285
+ def build_record(
286
+ idx: int,
287
+ split: str,
288
+ source_name: str,
289
+ skill_id: str,
290
+ difficulty: int,
291
+ question: str,
292
+ solution_text: str,
293
+ final_answer: str,
294
+ ) -> Dict[str, Any]:
295
+ assistant_content = solution_to_steps(solution_text, final_answer)
296
+ return {
297
+ "id": f"{source_name.replace('/', '_')}_{split}_{idx}",
298
+ "skill_id": skill_id,
299
+ "source": source_name,
300
+ "split": split,
301
+ "difficulty": difficulty,
302
+ "task_type": "solve",
303
+ "messages": [
304
+ {"role": "system", "content": SOLVER_SYSTEM_PROMPT},
305
+ {"role": "user", "content": USER_WRAPPER.format(question=question.strip())},
306
+ {"role": "assistant", "content": assistant_content},
307
+ ],
308
+ }
309
+
310
+
311
+ # ---------------------------------------------------------------------------
312
+ # Deduplication
313
+ # ---------------------------------------------------------------------------
314
+
315
+ def problem_hash(text: str) -> str:
316
+ """Fast 16-char hash for near-dedup (exact-match on normalised text)."""
317
+ normalised = re.sub(r"\s+", " ", text.strip().lower())
318
+ return hashlib.md5(normalised.encode()).hexdigest()[:16]
319
+
320
+
321
+ # ---------------------------------------------------------------------------
322
+ # NuminaMath-CoT processing
323
+ # ---------------------------------------------------------------------------
324
+
325
+ def _numina_skill_and_difficulty(row: Dict) -> Tuple[str, int]:
326
+ topic = (row.get("type") or "").lower().strip()
327
+ source = (row.get("source") or "").lower().strip()
328
+
329
+ skill = NUMINA_TYPE_TO_SKILL.get(topic)
330
+ if skill is None:
331
+ skill = NUMINA_TYPE_TO_SKILL.get(source, "numina_general")
332
+
333
+ difficulty = NUMINA_SOURCE_DIFFICULTY.get(source, 2)
334
+ return skill, difficulty
335
+
336
+
337
+ def iter_numina(
338
+ max_samples: int,
339
+ per_skill_cap: int,
340
+ skip_olympiad: bool,
341
+ seed: int,
342
+ ) -> Iterator[Dict[str, Any]]:
343
+ """
344
+ Stream NuminaMath-CoT from HuggingFace and yield cleaned records.
345
+ Uses per-skill quota to guarantee topic diversity.
346
+ """
347
+ try:
348
+ from datasets import load_dataset # type: ignore
349
+ except ImportError:
350
+ log.error("pip install datasets huggingface_hub")
351
+ sys.exit(1)
352
+
353
+ log.info("Streaming AI-MO/NuminaMath-CoT …")
354
+ ds = load_dataset("AI-MO/NuminaMath-CoT", split="train", streaming=True,
355
+ trust_remote_code=True)
356
+
357
+ skill_counts: Counter = Counter()
358
+ seen_hashes: set = set()
359
+ total_yielded = 0
360
+
361
+ rng = random.Random(seed)
362
+
363
+ for row in ds:
364
+ if total_yielded >= max_samples:
365
+ break
366
+
367
+ problem = (row.get("problem") or "").strip()
368
+ solution = (row.get("solution") or "").strip()
369
+ if not problem or not solution:
370
+ continue
371
+
372
+ # Extract and normalise answer from \boxed{}
373
+ raw_answer = extract_boxed(solution)
374
+ if raw_answer is None:
375
+ continue
376
+ final_answer = normalise_numeric(raw_answer)
377
+ if final_answer is None:
378
+ continue
379
+
380
+ skill, difficulty = _numina_skill_and_difficulty(row)
381
+
382
+ # Optionally skip very hard olympiad problems
383
+ if skip_olympiad and skill == "numina_olympiad":
384
+ continue
385
+
386
+ # Per-skill cap to guarantee diversity
387
+ if skill_counts[skill] >= per_skill_cap:
388
+ continue
389
+
390
+ # Dedup
391
+ h = problem_hash(problem)
392
+ if h in seen_hashes:
393
+ continue
394
+ seen_hashes.add(h)
395
+
396
+ skill_counts[skill] += 1
397
+ total_yielded += 1
398
+
399
+ yield build_record(
400
+ idx=total_yielded,
401
+ split="__assign__",
402
+ source_name="AI-MO/NuminaMath-CoT",
403
+ skill_id=skill,
404
+ difficulty=difficulty,
405
+ question=problem,
406
+ solution_text=solution,
407
+ final_answer=final_answer,
408
+ )
409
+
410
+ log.info("NuminaMath-CoT: yielded %d records | skill dist: %s",
411
+ total_yielded, dict(skill_counts.most_common()))
412
+
413
+
414
+ # ---------------------------------------------------------------------------
415
+ # OpenMathInstruct-2 processing
416
+ # ---------------------------------------------------------------------------
417
+
418
+ def _openmath_skill_and_difficulty(row: Dict) -> Tuple[str, int]:
419
+ src = (row.get("problem_source") or "").lower().strip()
420
+ subj = (row.get("subject") or "").strip()
421
+
422
+ if src == "math" and subj:
423
+ skill = OPENMATH_MATH_SUBJECT_SKILL.get(subj, "openmath_algebra")
424
+ else:
425
+ skill = OPENMATH_SOURCE_TO_SKILL.get(src, "openmath_general")
426
+
427
+ difficulty = OPENMATH_SOURCE_DIFFICULTY.get(src, 2)
428
+ return skill, difficulty
429
+
430
+
431
+ def iter_openmath(
432
+ max_samples: int,
433
+ per_skill_cap: int,
434
+ skip_gsm8k: bool,
435
+ seed: int,
436
+ ) -> Iterator[Dict[str, Any]]:
437
+ """
438
+ Stream OpenMathInstruct-2 from HuggingFace and yield cleaned records.
439
+ Only yields rows where `is_correct_solution` is True (pre-verified by NVIDIA).
440
+ """
441
+ try:
442
+ from datasets import load_dataset # type: ignore
443
+ except ImportError:
444
+ log.error("pip install datasets huggingface_hub")
445
+ sys.exit(1)
446
+
447
+ log.info("Streaming nvidia/OpenMathInstruct-2 (this may take a moment) …")
448
+ ds = load_dataset(
449
+ "nvidia/OpenMathInstruct-2",
450
+ split="train",
451
+ streaming=True,
452
+ trust_remote_code=True,
453
+ )
454
+
455
+ skill_counts: Counter = Counter()
456
+ seen_hashes: set = set()
457
+ total_yielded = 0
458
+
459
+ for row in ds:
460
+ if total_yielded >= max_samples:
461
+ break
462
+
463
+ # Filter: skip gsm8k (contamination risk)
464
+ problem_src = (row.get("problem_source") or "").lower()
465
+ if skip_gsm8k and "gsm8k" in problem_src:
466
+ continue
467
+
468
+ # Filter: only verified correct solutions
469
+ if not row.get("is_correct_solution", True):
470
+ continue
471
+
472
+ problem = (row.get("problem") or "").strip()
473
+ solution = (row.get("generated_solution") or "").strip()
474
+ expected = (row.get("expected_answer") or "").strip()
475
+
476
+ if not problem or not solution or not expected:
477
+ continue
478
+
479
+ # Normalise the pre-extracted answer
480
+ final_answer = normalise_numeric(expected)
481
+ if final_answer is None:
482
+ continue
483
+
484
+ skill, difficulty = _openmath_skill_and_difficulty(row)
485
+
486
+ # Per-skill cap
487
+ if skill_counts[skill] >= per_skill_cap:
488
+ continue
489
+
490
+ # Dedup
491
+ h = problem_hash(problem)
492
+ if h in seen_hashes:
493
+ continue
494
+ seen_hashes.add(h)
495
+
496
+ skill_counts[skill] += 1
497
+ total_yielded += 1
498
+
499
+ yield build_record(
500
+ idx=total_yielded,
501
+ split="__assign__",
502
+ source_name="nvidia/OpenMathInstruct-2",
503
+ skill_id=skill,
504
+ difficulty=difficulty,
505
+ question=problem,
506
+ solution_text=solution,
507
+ final_answer=final_answer,
508
+ )
509
+
510
+ log.info("OpenMathInstruct-2: yielded %d records | skill dist: %s",
511
+ total_yielded, dict(skill_counts.most_common()))
512
+
513
+
514
+ # ---------------------------------------------------------------------------
515
+ # Dataset stats printer
516
+ # ---------------------------------------------------------------------------
517
+
518
+ def print_stats(records: List[Dict], label: str) -> None:
519
+ skill_c: Counter = Counter(r["skill_id"] for r in records)
520
+ diff_c: Counter = Counter(r["difficulty"] for r in records)
521
+ src_c: Counter = Counter(r["source"] for r in records)
522
+ split_c: Counter = Counter(r["split"] for r in records)
523
+
524
+ log.info("─── %s (%d records) ───────────────────────────────", label, len(records))
525
+ log.info(" by split: %s", dict(split_c))
526
+ log.info(" by source: %s", dict(src_c))
527
+ log.info(" by difficulty: %s", dict(sorted(diff_c.items())))
528
+ log.info(" by skill_id:")
529
+ for sk, cnt in skill_c.most_common():
530
+ log.info(" %-40s %5d", sk, cnt)
531
+
532
+
533
+ # ---------------------------------------------------------------------------
534
+ # Write JSONL
535
+ # ---------------------------------------------------------------------------
536
+
537
+ def write_jsonl(records: List[Dict], path: Path) -> None:
538
+ path.parent.mkdir(parents=True, exist_ok=True)
539
+ with path.open("w", encoding="utf-8") as f:
540
+ for rec in records:
541
+ f.write(json.dumps(rec, ensure_ascii=False) + "\n")
542
+ log.info("Wrote %d records β†’ %s", len(records), path)
543
+
544
+
545
+ # ---------------------------------------------------------------------------
546
+ # Train / val / test split (stratified by skill_id)
547
+ # ---------------------------------------------------------------------------
548
+
549
+ def stratified_split(
550
+ records: List[Dict],
551
+ train_frac: float = 0.85,
552
+ val_frac: float = 0.10,
553
+ seed: int = 42,
554
+ ) -> Tuple[List[Dict], List[Dict], List[Dict]]:
555
+ """
556
+ Stratified split by skill_id so every skill appears in all three sets.
557
+ Remaining fraction after train+val goes to test.
558
+ """
559
+ rng = random.Random(seed)
560
+
561
+ by_skill: Dict[str, List[Dict]] = defaultdict(list)
562
+ for r in records:
563
+ by_skill[r["skill_id"]].append(r)
564
+
565
+ train_, val_, test_ = [], [], []
566
+ for skill, items in by_skill.items():
567
+ rng.shuffle(items)
568
+ n = len(items)
569
+ n_train = math.floor(n * train_frac)
570
+ n_val = math.floor(n * val_frac)
571
+ train_ += items[:n_train]
572
+ val_ += items[n_train: n_train + n_val]
573
+ test_ += items[n_train + n_val:]
574
+
575
+ for r in train_: r["split"] = "train"
576
+ for r in val_: r["split"] = "val"
577
+ for r in test_: r["split"] = "test"
578
+
579
+ # Shuffle each split so skill interleaves during training
580
+ rng.shuffle(train_)
581
+ rng.shuffle(val_)
582
+ rng.shuffle(test_)
583
+
584
+ return train_, val_, test_
585
+
586
+
587
+ # ---------------------------------------------------------------------------
588
+ # Main
589
+ # ---------------------------------------------------------------------------
590
+
591
+ def parse_args() -> argparse.Namespace:
592
+ p = argparse.ArgumentParser(
593
+ description="Build combined NuminaMath + OpenMathInstruct-2 training data."
594
+ )
595
+ p.add_argument("--output-dir", default="data/sft",
596
+ help="Directory for output JSONL files.")
597
+ p.add_argument("--max-numina", type=int, default=20_000,
598
+ help="Max records from NuminaMath-CoT (default 20 000).")
599
+ p.add_argument("--max-openmath", type=int, default=15_000,
600
+ help="Max records from OpenMathInstruct-2 (default 15 000).")
601
+ p.add_argument("--per-skill-cap", type=int, default=4_000,
602
+ help="Max records per skill_id to guarantee topic diversity.")
603
+ p.add_argument("--skip-numina", action="store_true",
604
+ help="Skip NuminaMath-CoT entirely.")
605
+ p.add_argument("--skip-openmath", action="store_true",
606
+ help="Skip OpenMathInstruct-2 entirely.")
607
+ p.add_argument("--skip-olympiad", action="store_true", default=True,
608
+ help="Skip numina_olympiad problems (too hard for 1.5B; default: True).")
609
+ p.add_argument("--no-skip-olympiad", dest="skip_olympiad", action="store_false",
610
+ help="Include olympiad-level problems.")
611
+ p.add_argument("--train-frac", type=float, default=0.85)
612
+ p.add_argument("--val-frac", type=float, default=0.10)
613
+ p.add_argument("--seed", type=int, default=42)
614
+ p.add_argument("--dry-run", action="store_true",
615
+ help="Process only 500 rows from each source and show stats (no write).")
616
+ return p.parse_args()
617
+
618
+
619
+ def main() -> None:
620
+ args = parse_args()
621
+ rng = random.Random(args.seed)
622
+
623
+ if args.dry_run:
624
+ args.max_numina = min(args.max_numina, 500)
625
+ args.max_openmath = min(args.max_openmath, 500)
626
+ log.info("DRY RUN β€” capped at 500 samples per source, nothing written to disk.")
627
+
628
+ all_records: List[Dict] = []
629
+
630
+ # ── NuminaMath-CoT ────────────────────────────────────────────────────
631
+ if not args.skip_numina:
632
+ numina_recs = list(iter_numina(
633
+ max_samples = args.max_numina,
634
+ per_skill_cap = args.per_skill_cap,
635
+ skip_olympiad = args.skip_olympiad,
636
+ seed = args.seed,
637
+ ))
638
+ all_records.extend(numina_recs)
639
+ log.info("NuminaMath-CoT collected: %d records", len(numina_recs))
640
+ else:
641
+ log.info("Skipping NuminaMath-CoT (--skip-numina).")
642
+
643
+ # ── OpenMathInstruct-2 ────────────────────────────────────────────────
644
+ if not args.skip_openmath:
645
+ openmath_recs = list(iter_openmath(
646
+ max_samples = args.max_openmath,
647
+ per_skill_cap = args.per_skill_cap,
648
+ skip_gsm8k = True,
649
+ seed = args.seed,
650
+ ))
651
+ all_records.extend(openmath_recs)
652
+ log.info("OpenMathInstruct-2 collected: %d records", len(openmath_recs))
653
+ else:
654
+ log.info("Skipping OpenMathInstruct-2 (--skip-openmath).")
655
+
656
+ if not all_records:
657
+ log.error("No records collected β€” check dataset availability.")
658
+ sys.exit(1)
659
+
660
+ # ── Deduplicate across sources ─────────────────────────────────────────
661
+ seen: set = set()
662
+ deduped: List[Dict] = []
663
+ for r in all_records:
664
+ question = r["messages"][1]["content"]
665
+ h = problem_hash(question)
666
+ if h not in seen:
667
+ seen.add(h)
668
+ deduped.append(r)
669
+
670
+ log.info("After cross-source dedup: %d β†’ %d records (removed %d dupes)",
671
+ len(all_records), len(deduped), len(all_records) - len(deduped))
672
+
673
+ # ── Stratified split ──────────────────────────────────────────────────
674
+ train_recs, val_recs, test_recs = stratified_split(
675
+ deduped, args.train_frac, args.val_frac, args.seed
676
+ )
677
+
678
+ print_stats(train_recs + val_recs + test_recs, "COMBINED DATASET")
679
+
680
+ # ── Write outputs ─────────────────────────────────────────────────────
681
+ if args.dry_run:
682
+ log.info("DRY RUN complete β€” no files written.")
683
+ log.info(" would write: combined_train.jsonl (%d rows)", len(train_recs))
684
+ log.info(" would write: combined_val.jsonl (%d rows)", len(val_recs))
685
+ log.info(" would write: combined_test.jsonl (%d rows)", len(test_recs))
686
+ log.info("Sample record:")
687
+ print(json.dumps(train_recs[0], indent=2, ensure_ascii=False))
688
+ return
689
+
690
+ out = Path(args.output_dir)
691
+ write_jsonl(train_recs, out / "combined_train.jsonl")
692
+ write_jsonl(val_recs, out / "combined_val.jsonl")
693
+ write_jsonl(test_recs, out / "combined_test.jsonl")
694
+
695
+ log.info("")
696
+ log.info("╔══════════════════════════════════════════════════════════════╗")
697
+ log.info("β•‘ Pipeline complete. Next step: β•‘")
698
+ log.info("β•‘ bash launch_grpo_combined.sh β•‘")
699
+ log.info("β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•")
700
+ log.info(" train : %6d rows β†’ %s/combined_train.jsonl", len(train_recs), out)
701
+ log.info(" val : %6d rows β†’ %s/combined_val.jsonl", len(val_recs), out)
702
+ log.info(" test : %6d rows β†’ %s/combined_test.jsonl", len(test_recs), out)
703
+ log.info("")
704
+ log.info("Skill coverage (for ZPD CurriculumManager):")
705
+ skill_c = Counter(r["skill_id"] for r in train_recs)
706
+ for sk, cnt in sorted(skill_c.items()):
707
+ log.info(" %-40s %5d train samples", sk, cnt)
708
+
709
+
710
+ if __name__ == "__main__":
711
+ main()
scripts/run_grpo_training.py ADDED
The diff for this file is too large to render. See raw diff
 
scripts/run_inference.py ADDED
@@ -0,0 +1,502 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Inference pipeline: Base Qwen2.5-Math-1.5B-Instruct vs RL fine-tuned checkpoint.
4
+
5
+ For each sampled GSM8K question, both models generate a step-by-step solution.
6
+ Results are saved to reports/<run_name>/ as JSON files for the Gradio demo.
7
+
8
+ Usage
9
+ -----
10
+ # Full run (50 questions, both models):
11
+ python scripts/run_inference.py \\
12
+ --checkpoint checkpoints/grpo_run_v1 \\
13
+ --num-questions 50 \\
14
+ --run-name comparison_v1
15
+
16
+ # Quick smoke test (10 questions, no RL model):
17
+ python scripts/run_inference.py \\
18
+ --num-questions 10 \\
19
+ --base-only \\
20
+ --run-name smoke
21
+
22
+ # Custom data source:
23
+ python scripts/run_inference.py \\
24
+ --checkpoint checkpoints/grpo_run_v1 \\
25
+ --data data/sft/gsm8k_test.jsonl \\
26
+ --num-questions 30
27
+ """
28
+
29
+ from __future__ import annotations
30
+
31
+ import argparse
32
+ import json
33
+ import logging
34
+ import random
35
+ import sys
36
+ import time
37
+ from datetime import datetime
38
+ from pathlib import Path
39
+ from typing import Any, Dict, List, Optional, Tuple
40
+
41
+ import torch
42
+ from tqdm.auto import tqdm
43
+ from transformers import AutoModelForCausalLM, AutoTokenizer
44
+
45
+ sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
46
+
47
+ from src.config.prompts import create_solver_messages
48
+ from src.sft.solution_format import extract_final_answer_numeric_str
49
+ from src.utils.attn_backend import select_attn_implementation
50
+
51
+ logging.basicConfig(
52
+ level=logging.INFO,
53
+ format="%(asctime)s %(levelname)-8s %(name)s - %(message)s",
54
+ )
55
+ logger = logging.getLogger(__name__)
56
+
57
+ BASE_MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
58
+ REPORTS_DIR = Path("reports")
59
+
60
+
61
+ # ── Data loading ──────────────────────────────────────────────────────────────
62
+
63
+ def load_gsm8k_questions(
64
+ data_path: Optional[str],
65
+ num_questions: int,
66
+ seed: int = 42,
67
+ ) -> List[Dict[str, str]]:
68
+ """
69
+ Load GSM8K questions from a local JSONL file or fall back to HuggingFace.
70
+
71
+ Each returned record has keys: ``question``, ``gold_final``, ``answer``.
72
+ """
73
+ # ── Try local JSONL first ────────────────────────────────────────────────
74
+ candidates = [data_path] if data_path else []
75
+ candidates += [
76
+ "data/sft/gsm8k_test.jsonl",
77
+ "data/sft/gsm8k_sft.jsonl",
78
+ ]
79
+
80
+ for path in candidates:
81
+ if path and Path(path).exists():
82
+ logger.info("Loading GSM8K from local file: %s", path)
83
+ rows: List[Dict] = []
84
+ with open(path, encoding="utf-8") as f:
85
+ for line in f:
86
+ line = line.strip()
87
+ if line:
88
+ rows.append(json.loads(line))
89
+ rng = random.Random(seed)
90
+ sample = rng.sample(rows, min(num_questions, len(rows)))
91
+ logger.info("Sampled %d / %d questions.", len(sample), len(rows))
92
+ return sample
93
+
94
+ # ── Fall back to HuggingFace datasets ────────────────────────────────────
95
+ logger.info("No local file found β€” downloading GSM8K from HuggingFace…")
96
+ try:
97
+ from datasets import load_dataset
98
+ ds = load_dataset("openai/gsm8k", "main", split="test")
99
+ except Exception as e:
100
+ raise RuntimeError(
101
+ "Could not load GSM8K. Provide --data or install datasets: pip install datasets"
102
+ ) from e
103
+
104
+ rows = []
105
+ for item in ds:
106
+ q = item["question"].strip()
107
+ a = item["answer"].strip()
108
+ # GSM8K answers end with "#### <number>"
109
+ gold = a.split("####")[-1].strip() if "####" in a else ""
110
+ rows.append({"question": q, "gold_final": gold, "answer": a})
111
+
112
+ rng = random.Random(seed)
113
+ sample = rng.sample(rows, min(num_questions, len(rows)))
114
+ logger.info("Sampled %d questions from HF GSM8K test split.", len(sample))
115
+ return sample
116
+
117
+
118
+ # ── Model loading ─────────────────────────────────────────────────────────────
119
+
120
+ def load_base_model(
121
+ device: torch.device,
122
+ attn_impl: str,
123
+ ) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
124
+ logger.info("Loading base model: %s", BASE_MODEL_ID)
125
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
126
+ if tokenizer.pad_token is None:
127
+ tokenizer.pad_token = tokenizer.eos_token
128
+ tokenizer.padding_side = "left"
129
+
130
+ model = AutoModelForCausalLM.from_pretrained(
131
+ BASE_MODEL_ID,
132
+ torch_dtype=torch.bfloat16,
133
+ device_map={"": device},
134
+ trust_remote_code=True,
135
+ attn_implementation=attn_impl,
136
+ )
137
+ model.eval()
138
+ logger.info("Base model loaded.")
139
+ return model, tokenizer
140
+
141
+
142
+ def load_rl_model(
143
+ checkpoint: str,
144
+ base_model: AutoModelForCausalLM,
145
+ base_tokenizer: AutoTokenizer,
146
+ device: torch.device,
147
+ attn_impl: str,
148
+ ) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
149
+ """
150
+ Load the RL fine-tuned checkpoint for comparison against the raw base model.
151
+
152
+ Two checkpoint formats are supported:
153
+
154
+ PEFT / LoRA adapter (has adapter_config.json)
155
+ The already-loaded base model weights are deep-copied in CPU memory,
156
+ the adapter is applied on top, then merged and unloaded.
157
+ This avoids downloading the 1.5B base weights from HuggingFace a
158
+ second time β€” the base model is downloaded only once per run.
159
+
160
+ Full saved model (has config.json, no adapter_config.json)
161
+ Loaded directly from disk with from_pretrained.
162
+ """
163
+ import copy
164
+
165
+ ckpt_path = Path(checkpoint)
166
+ if not ckpt_path.exists():
167
+ raise FileNotFoundError(f"Checkpoint not found: {checkpoint}")
168
+
169
+ is_peft = (ckpt_path / "adapter_config.json").exists()
170
+
171
+ if is_peft:
172
+ logger.info(
173
+ "Loading PEFT adapter from %s (reusing base weights β€” no second HF download)",
174
+ checkpoint,
175
+ )
176
+ from peft import PeftModel
177
+
178
+ # Deep-copy the already-loaded base model so the base remains untouched
179
+ # for side-by-side comparison. For a 1.5B bfloat16 model this takes
180
+ # ~1-2 s and avoids re-downloading ~3 GB from HuggingFace.
181
+ base_copy = copy.deepcopy(base_model)
182
+ model = PeftModel.from_pretrained(base_copy, checkpoint)
183
+ model = model.merge_and_unload()
184
+ model = model.to(device)
185
+ else:
186
+ logger.info("Loading full model checkpoint from %s", checkpoint)
187
+ model = AutoModelForCausalLM.from_pretrained(
188
+ checkpoint,
189
+ torch_dtype=torch.bfloat16,
190
+ device_map={"": device},
191
+ trust_remote_code=True,
192
+ attn_implementation=attn_impl,
193
+ )
194
+
195
+ # Patch chat_template from base tokenizer if missing
196
+ tokenizer = AutoTokenizer.from_pretrained(
197
+ checkpoint if (ckpt_path / "tokenizer_config.json").exists() else BASE_MODEL_ID,
198
+ trust_remote_code=True,
199
+ )
200
+ if tokenizer.pad_token is None:
201
+ tokenizer.pad_token = tokenizer.eos_token
202
+ tokenizer.padding_side = "left"
203
+ if tokenizer.chat_template is None and base_tokenizer.chat_template:
204
+ tokenizer.chat_template = base_tokenizer.chat_template
205
+
206
+ model.eval()
207
+ logger.info("RL model loaded.")
208
+ return model, tokenizer
209
+
210
+
211
+ # ── Inference ─────────────────────────────────────────────────────────────────
212
+
213
+ def generate_solution(
214
+ model: AutoModelForCausalLM,
215
+ tokenizer: AutoTokenizer,
216
+ question: str,
217
+ device: torch.device,
218
+ max_new_tokens: int = 512,
219
+ temperature: float = 0.1,
220
+ ) -> Tuple[str, float]:
221
+ """
222
+ Generate a step-by-step solution for ``question``.
223
+
224
+ Returns ``(solution_text, elapsed_seconds)``.
225
+ Low temperature (0.1) for deterministic, greedy-like output during eval.
226
+ """
227
+ messages = create_solver_messages(question)
228
+ prompt = tokenizer.apply_chat_template(
229
+ messages, tokenize=False, add_generation_prompt=True
230
+ )
231
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
232
+ inputs = {k: v.to(device) for k, v in inputs.items()}
233
+ prompt_len = inputs["input_ids"].shape[1]
234
+
235
+ stop_ids = [tokenizer.eos_token_id]
236
+ im_end = tokenizer.convert_tokens_to_ids("<|im_end|>")
237
+ if isinstance(im_end, int) and im_end not in stop_ids:
238
+ stop_ids.append(im_end)
239
+
240
+ t0 = time.time()
241
+ with torch.no_grad():
242
+ output = model.generate(
243
+ **inputs,
244
+ max_new_tokens=max_new_tokens,
245
+ do_sample=temperature > 0.05,
246
+ temperature=temperature if temperature > 0.05 else None,
247
+ top_p=0.95 if temperature > 0.05 else None,
248
+ eos_token_id=stop_ids,
249
+ pad_token_id=tokenizer.pad_token_id,
250
+ use_cache=True,
251
+ )
252
+ elapsed = time.time() - t0
253
+
254
+ response_ids = output[0][prompt_len:]
255
+ solution = tokenizer.decode(response_ids, skip_special_tokens=True).strip()
256
+ return solution, elapsed
257
+
258
+
259
+ def score_answer(solution: str, gold_final: str) -> Dict[str, Any]:
260
+ """
261
+ Extract the predicted final answer and compare with gold.
262
+ Returns a dict with ``predicted``, ``gold``, ``correct``, ``match_type``.
263
+ """
264
+ predicted_raw = extract_final_answer_numeric_str(solution)
265
+
266
+ if predicted_raw is None:
267
+ return {
268
+ "predicted": None,
269
+ "gold": gold_final,
270
+ "correct": False,
271
+ "match_type": "no_answer_found",
272
+ }
273
+
274
+ # Normalise: strip whitespace, remove commas (e.g. "1,200" β†’ "1200")
275
+ def _norm(s: str) -> str:
276
+ return s.strip().replace(",", "").rstrip(".").lower()
277
+
278
+ pred_n = _norm(predicted_raw)
279
+ gold_n = _norm(gold_final)
280
+
281
+ # Direct string match
282
+ if pred_n == gold_n:
283
+ return {
284
+ "predicted": predicted_raw,
285
+ "gold": gold_final,
286
+ "correct": True,
287
+ "match_type": "exact",
288
+ }
289
+
290
+ # Numeric match (handles float/int equivalence)
291
+ try:
292
+ pred_f = float(pred_n)
293
+ gold_f = float(gold_n)
294
+ if abs(pred_f - gold_f) < 1e-6:
295
+ return {
296
+ "predicted": predicted_raw,
297
+ "gold": gold_final,
298
+ "correct": True,
299
+ "match_type": "numeric",
300
+ }
301
+ except (ValueError, TypeError):
302
+ pass
303
+
304
+ return {
305
+ "predicted": predicted_raw,
306
+ "gold": gold_final,
307
+ "correct": False,
308
+ "match_type": "wrong",
309
+ }
310
+
311
+
312
+ # ── Report serialisation ──────────────────────────────────────────────────────
313
+
314
+ def save_question_report(
315
+ report_dir: Path,
316
+ idx: int,
317
+ question: str,
318
+ gold_final: str,
319
+ base_result: Dict[str, Any],
320
+ rl_result: Optional[Dict[str, Any]],
321
+ ) -> Path:
322
+ record = {
323
+ "idx": idx,
324
+ "question": question,
325
+ "gold_final": gold_final,
326
+ "base_model": base_result,
327
+ "rl_model": rl_result,
328
+ }
329
+ out = report_dir / f"q_{idx:04d}.json"
330
+ out.write_text(json.dumps(record, indent=2, ensure_ascii=False), encoding="utf-8")
331
+ return out
332
+
333
+
334
+ def save_summary(
335
+ report_dir: Path,
336
+ run_name: str,
337
+ checkpoint: Optional[str],
338
+ base_correct: int,
339
+ rl_correct: Optional[int],
340
+ total: int,
341
+ total_time_s: float,
342
+ args_dict: Dict,
343
+ ) -> None:
344
+ summary = {
345
+ "run_name": run_name,
346
+ "timestamp": datetime.now().isoformat(),
347
+ "base_model": BASE_MODEL_ID,
348
+ "rl_checkpoint": checkpoint,
349
+ "num_questions": total,
350
+ "base_accuracy": round(base_correct / total, 4) if total else 0,
351
+ "rl_accuracy": round(rl_correct / total, 4) if (rl_correct is not None and total) else None,
352
+ "base_correct": base_correct,
353
+ "rl_correct": rl_correct,
354
+ "total_time_s": round(total_time_s, 1),
355
+ "args": args_dict,
356
+ }
357
+ out = report_dir / "summary.json"
358
+ out.write_text(json.dumps(summary, indent=2), encoding="utf-8")
359
+ logger.info("Summary saved β†’ %s", out)
360
+
361
+
362
+ # ── Main ──────────────────────────────────────────────────────────────────────
363
+
364
+ def parse_args() -> argparse.Namespace:
365
+ p = argparse.ArgumentParser(description="Run inference: base vs RL model on GSM8K")
366
+ p.add_argument("--checkpoint", type=str, default=None,
367
+ help="Path to RL fine-tuned model or PEFT adapter. "
368
+ "If omitted, only the base model is run.")
369
+ p.add_argument("--data", type=str, default=None,
370
+ help="Path to local GSM8K JSONL file. "
371
+ "Defaults to data/sft/gsm8k_test.jsonl or HuggingFace.")
372
+ p.add_argument("--num-questions", type=int, default=50)
373
+ p.add_argument("--seed", type=int, default=42)
374
+ p.add_argument("--max-new-tokens", type=int, default=512)
375
+ p.add_argument("--temperature", type=float, default=0.1)
376
+ p.add_argument("--run-name", type=str, default=None,
377
+ help="Report sub-folder name. Defaults to timestamp.")
378
+ p.add_argument("--base-only", action="store_true",
379
+ help="Skip RL model; only run the base model.")
380
+ p.add_argument("--reports-dir", type=str, default="reports")
381
+ return p.parse_args()
382
+
383
+
384
+ def main() -> None:
385
+ args = parse_args()
386
+
387
+ run_name = args.run_name or f"run_{datetime.now():%Y%m%d_%H%M%S}"
388
+ report_dir = Path(args.reports_dir) / run_name
389
+ report_dir.mkdir(parents=True, exist_ok=True)
390
+ logger.info("Reports β†’ %s", report_dir)
391
+
392
+ # ── Device ────────────────────────────────────────────────────────────────
393
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
394
+ attn_impl = select_attn_implementation()
395
+ logger.info("Device: %s | attn: %s", device, attn_impl)
396
+ if torch.cuda.is_available():
397
+ g = torch.cuda.get_device_properties(0)
398
+ logger.info("GPU: %s | %.1f GB", g.name, g.total_memory / 1e9)
399
+
400
+ # ── Data ──────────────────────────────────────────────────────────────────
401
+ questions = load_gsm8k_questions(args.data, args.num_questions, args.seed)
402
+
403
+ # ── Models ────────────────────────────────────────────────────────────────
404
+ base_model, base_tokenizer = load_base_model(device, attn_impl)
405
+
406
+ rl_model, rl_tokenizer = None, None
407
+ if not args.base_only and args.checkpoint:
408
+ rl_model, rl_tokenizer = load_rl_model(
409
+ args.checkpoint, base_model, base_tokenizer, device, attn_impl
410
+ )
411
+ elif not args.base_only and not args.checkpoint:
412
+ logger.warning("No --checkpoint provided. Running base model only.")
413
+
414
+ # ── Inference loop ────────────────────────────────────────────────────────
415
+ base_correct = 0
416
+ rl_correct = 0 if rl_model else None
417
+ t_total_start = time.time()
418
+
419
+ for idx, row in enumerate(tqdm(questions, desc="Inference")):
420
+ question = row["question"]
421
+ gold_final = row.get("gold_final", "").strip()
422
+
423
+ # Base model
424
+ base_solution, base_time = generate_solution(
425
+ base_model, base_tokenizer, question, device,
426
+ args.max_new_tokens, args.temperature,
427
+ )
428
+ base_score = score_answer(base_solution, gold_final)
429
+ if base_score["correct"]:
430
+ base_correct += 1
431
+
432
+ base_result = {
433
+ "solution": base_solution,
434
+ "predicted": base_score["predicted"],
435
+ "correct": base_score["correct"],
436
+ "match_type": base_score["match_type"],
437
+ "time_s": round(base_time, 2),
438
+ "num_tokens": len(base_tokenizer.encode(base_solution)),
439
+ }
440
+
441
+ # RL model
442
+ rl_result = None
443
+ if rl_model is not None:
444
+ rl_solution, rl_time = generate_solution(
445
+ rl_model, rl_tokenizer, question, device,
446
+ args.max_new_tokens, args.temperature,
447
+ )
448
+ rl_score = score_answer(rl_solution, gold_final)
449
+ if rl_score["correct"]:
450
+ rl_correct += 1
451
+
452
+ rl_result = {
453
+ "solution": rl_solution,
454
+ "predicted": rl_score["predicted"],
455
+ "correct": rl_score["correct"],
456
+ "match_type": rl_score["match_type"],
457
+ "time_s": round(rl_time, 2),
458
+ "num_tokens": len(rl_tokenizer.encode(rl_solution)),
459
+ }
460
+
461
+ save_question_report(report_dir, idx, question, gold_final, base_result, rl_result)
462
+
463
+ # Live progress log every 10 questions
464
+ if (idx + 1) % 10 == 0 or idx == len(questions) - 1:
465
+ done = idx + 1
466
+ b_acc = base_correct / done
467
+ log_str = f"[{done}/{len(questions)}] Base acc: {b_acc:.1%}"
468
+ if rl_correct is not None:
469
+ log_str += f" | RL acc: {rl_correct / done:.1%}"
470
+ logger.info(log_str)
471
+
472
+ total_time = time.time() - t_total_start
473
+
474
+ # ── Summary ───────────────────────────────────────────────────────────────
475
+ save_summary(
476
+ report_dir=report_dir,
477
+ run_name=run_name,
478
+ checkpoint=args.checkpoint,
479
+ base_correct=base_correct,
480
+ rl_correct=rl_correct,
481
+ total=len(questions),
482
+ total_time_s=total_time,
483
+ args_dict=vars(args),
484
+ )
485
+
486
+ logger.info("=" * 60)
487
+ logger.info("Run complete: %s", run_name)
488
+ logger.info("Base accuracy : %d / %d = %.1f%%",
489
+ base_correct, len(questions), 100 * base_correct / len(questions))
490
+ if rl_correct is not None:
491
+ logger.info("RL accuracy : %d / %d = %.1f%%",
492
+ rl_correct, len(questions), 100 * rl_correct / len(questions))
493
+ delta = rl_correct - base_correct
494
+ sign = "+" if delta >= 0 else ""
495
+ logger.info("Delta : %s%d questions (%s%.1f%%)",
496
+ sign, delta, sign, 100 * delta / len(questions))
497
+ logger.info("Reports : %s", report_dir)
498
+ logger.info("=" * 60)
499
+
500
+
501
+ if __name__ == "__main__":
502
+ main()