| """ | |
| GRPO training for self-improvement math environment. | |
| Group Relative Policy Optimization (GRPO) is dramatically simpler and more | |
| stable than PPO for LLM fine-tuning on math tasks: | |
| - No value function / critic needed | |
| - No GAE, no gamma, no lambda | |
| - No KL instability from per-step advantage collapse | |
| - Advantages computed as within-group z-scores: A_i = (R_i - mean_R) / std_R | |
| - Proven on math RL: DeepSeek-Math, Qwen-Math, DAPO all use GRPO variants | |
| The algorithm per question: | |
| 1. Generate K solutions (default K=4) | |
| 2. Score each with the existing reward pipeline (PRM + SymPy + format) | |
| 3. A_i = (R_i - mean(R)) / (std(R) + eps) | |
| 4. policy_loss = -mean_i [ A_i * sum_t log pi(a_t | s_{<t}) / T_i ] | |
| 5. Skip the group if all rewards are identical (zero gradient signal) | |
| Expected improvement curve: | |
| - Iterations 1-5: reward mean rising, policy learning to avoid R=0 outputs | |
| - Iterations 5-15: GSM8K accuracy starts moving (+2-5%) | |
| - Iterations 15-30: continued improvement toward ~70-75%+ from 63.6% baseline | |
| Usage: | |
| python scripts/run_grpo_training.py \\ | |
| --base-model checkpoints/dual_task_v1 \\ | |
| --gsm8k-data data/sft/gsm8k_sft.jsonl \\ | |
| --num-iterations 30 \\ | |
| --group-size 4 \\ | |
| --questions-per-iter 16 | |
| # Faster smoke test (no PRM, 3 iters): | |
| python scripts/run_grpo_training.py \\ | |
| --base-model checkpoints/dual_task_v1 \\ | |
| --num-iterations 3 --group-size 4 --questions-per-iter 8 \\ | |
| --no-prm --skip-initial-eval --run-name smoke_grpo | |
| """ | |
from __future__ import annotations

import argparse
import atexit
import copy
import csv
import json
import logging
import random
import re
import shutil
import sys
import time
import types
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm.auto import tqdm

sys.path.insert(0, str(Path(__file__).parent.parent))

from scripts.convert_gsm8k_to_sft import parse_gsm8k_answer
from scripts.eval_sft_inference import evaluate_gsm8k
from src.rl.prm_scorer import ProcessRewardScorer
from src.sft.solution_format import extract_final_answer_numeric_str
from src.utils.attn_backend import select_attn_implementation
from src.rl.math_environment_curriculum import CurriculumMathEnvironment
from src.config.prompts import create_generator_messages

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)-8s %(name)s - %(message)s",
)
logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Logging infrastructure
# ---------------------------------------------------------------------------

class TeeStream:
    """Mirrors every write to a terminal stream into a log file.

    Wrapping sys.stdout and sys.stderr with this object ensures that *all*
    output (bare print() calls, tqdm bars, third-party library writes) lands
    in the run log file in addition to the terminal.

    A separate FileHandler on the root logger (see _add_file_logging) captures
    the Python logging subsystem independently, because logging.StreamHandler
    stores a reference to the stream at creation time and therefore bypasses
    any later sys.stderr reassignment. Both mechanisms together guarantee that
    nothing escapes the log file.
    """

    def __init__(self, primary, secondary):
        self.primary = primary
        self.secondary = secondary

    def write(self, data: str) -> int:
        self.primary.write(data)
        self.secondary.write(data)
        return len(data)

    def flush(self) -> None:
        self.primary.flush()
        self.secondary.flush()

    def isatty(self) -> bool:
        return getattr(self.primary, "isatty", lambda: False)()

    def fileno(self) -> int:
        return self.primary.fileno()
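# Illustrative usage (a minimal sketch; the real wiring happens in main()):
#
#     log_file = open("run.log", "a", buffering=1)
#     sys.stdout = TeeStream(sys.stdout, log_file)   # print() now hits both
#     sys.stderr = TeeStream(sys.stderr, log_file)   # tqdm writes to stderr
#
# The original streams are restored (and the file closed) on exit; main()
# does this via atexit in _teardown_logging().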
def _add_file_logging(log_path: Path) -> logging.FileHandler:
    """Attach a FileHandler to the root logger.

    Every logger.info / logger.warning / ... call, from any module, will be
    written to ``log_path`` in addition to the terminal. This complements
    TeeStream: TeeStream captures bare print() / sys.stderr writes; this
    handler captures the logging subsystem, which uses its own internal stream
    reference that TeeStream cannot intercept.
    """
    fh = logging.FileHandler(log_path, mode="a", encoding="utf-8")
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(logging.Formatter(
        "%(asctime)s %(levelname)-8s %(name)s - %(message)s"
    ))
    logging.getLogger().addHandler(fh)
    return fh


if torch.cuda.is_available():
    torch.set_float32_matmul_precision("high")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True  # auto-tune fastest conv algo per shape


# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------

def _infer_eval_dataset_name(data_path: str) -> str:
    """Derive a short human-readable label from the eval data file path."""
    stem = Path(data_path).stem.lower()
    if "aqua" in stem:
        return "AQuA-RAT"
    if "math" in stem:
        return "MATH"
    if "gsm" in stem:
        return "GSM8K"
    return Path(data_path).stem


def load_gsm8k(path: str) -> List[Dict[str, str]]:
    """Return list of {"question": ..., "gold_final": ...} from a JSONL file."""
    pairs: List[Dict[str, str]] = []
    p = Path(path)
    if not p.exists():
        logger.warning("Training data not found at %s", path)
        return pairs
    with p.open(encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                continue
            question = ""
            gold = ""
            if "question" in rec and "answer" in rec:
                question = rec["question"].strip()
                _, gold = parse_gsm8k_answer(str(rec["answer"]))
            elif "messages" in rec:
                user_text = ""
                asst_text = ""
                for msg in rec["messages"]:
                    if msg.get("role") == "user" and not user_text:
                        user_text = msg.get("content", "").strip()
                    elif msg.get("role") == "assistant" and not asst_text:
                        asst_text = msg.get("content", "")
                if "Problem:" in user_text:
                    question = user_text.split("Problem:", 1)[1].strip()
                else:
                    question = user_text
                answer_str = extract_final_answer_numeric_str(asst_text) or ""
                gold = answer_str.strip()
            if question and gold:
                pairs.append({"question": question, "gold_final": gold})
    logger.info("Loaded %d QA pairs from %s", len(pairs), path)
    return pairs
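# The loader accepts either of two record schemas per JSONL line (the examples
# below are illustrative, not taken from the real dataset files):
#
#   {"question": "Tom has 3 apples ...", "answer": "... #### 7"}
#   {"messages": [{"role": "user", "content": "Problem: Tom has ..."},
#                 {"role": "assistant", "content": "... Final Answer: 7"}]}
#
# The first is raw GSM8K (gold parsed from the "#### N" marker by
# parse_gsm8k_answer); the second is the SFT chat format (gold recovered via
# extract_final_answer_numeric_str).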
# ---------------------------------------------------------------------------
# MATH harder dataset
# ---------------------------------------------------------------------------

def _extract_boxed(text: str) -> Optional[str]:
    r"""Extract the content of the first ``\boxed{...}`` in *text*."""
    m = re.search(r"\\boxed\{([^}]*)\}", text)
    return m.group(1).strip() if m else None


def _boxed_to_numeric(answer: str) -> Optional[str]:
    """
    Convert a ``\\boxed{...}`` answer to a plain numeric string.

    Returns a string of the form ``"42"`` or ``"3.5000"`` when the answer
    is a recognisable integer, decimal, or simple fraction (``3/4`` or
    ``\\frac{3}{4}``). Returns ``None`` for symbolic / multi-part answers
    like ``3\\sqrt{2}`` or ``(1, 2)``.
    """
    ans = answer.strip()
    # Direct integer
    try:
        return str(int(ans))
    except ValueError:
        pass
    # Direct float (includes "3.5", "0.75", etc.)
    try:
        v = float(ans)
        return str(int(v)) if v == int(v) else f"{v:.4f}"
    except ValueError:
        pass
    # LaTeX fraction \frac{num}{den}
    m = re.fullmatch(r"\\frac\{(\d+)\}\{(\d+)\}", ans)
    if m:
        num, den = int(m.group(1)), int(m.group(2))
        if den:
            v = num / den
            return str(int(v)) if v == int(v) else f"{v:.4f}"
    # Plain fraction num/den
    m = re.fullmatch(r"(\d+)/(\d+)", ans)
    if m:
        num, den = int(m.group(1)), int(m.group(2))
        if den:
            v = num / den
            return str(int(v)) if v == int(v) else f"{v:.4f}"
    return None
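# Example conversions (illustrative):
#   _boxed_to_numeric("42")           -> "42"
#   _boxed_to_numeric("3.5")          -> "3.5000"
#   _boxed_to_numeric(r"\frac{3}{4}") -> "0.7500"
#   _boxed_to_numeric("3/4")          -> "0.7500"
#   _boxed_to_numeric(r"3\sqrt{2}")   -> None   (symbolic: filtered out)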
def load_math_dataset(
    local_path: Optional[str] = None,
    cache_path: str = "data/math/math_numeric.jsonl",
    max_difficulty: int = 3,
) -> List[Dict[str, str]]:
    """
    Load a subset of the MATH competition dataset filtered to problems with
    numerically-verifiable answers (integers, decimals, simple fractions).

    Loading order
    -------------
    1. ``local_path`` if provided and the file exists.
    2. ``cache_path`` if that file exists (written on first HF download).
    3. HuggingFace ``competition_math`` dataset; filtered + written to
       ``cache_path`` for subsequent runs.

    Only problems with ``Level <= max_difficulty`` are included. Difficulty
    1-2 ≈ AMC-8 level (comparable to hard GSM8K); difficulty 3 ≈ AMC-10.
    Levels 4-5 are olympiad-hard and usually too difficult for a 1.5B model
    to get any reward signal from (win_rate ≈ 0 → skipped groups every iter).
    """
    for candidate in filter(None, [local_path, cache_path]):
        p = Path(candidate)
        if p.exists():
            pairs: List[Dict[str, str]] = []
            with p.open(encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if line:
                        try:
                            pairs.append(json.loads(line))
                        except json.JSONDecodeError:
                            pass
            if pairs:
                logger.info("Loaded %d MATH pairs from %s", len(pairs), p)
                return pairs

    # Download from HuggingFace
    logger.info(
        "MATH dataset not found locally; downloading from HuggingFace "
        "(qwedsacf/competition_math, difficulty <= %d, numeric answers only)...",
        max_difficulty,
    )
    # Try HF sources in priority order. Only keep sources confirmed reachable;
    # lighteval/MATH and hendrycks/competition_math have network/naming issues.
    _HF_SOURCES = [
        ("qwedsacf/competition_math", {}),             # reliable public mirror
        ("lighteval/MATH-Hard", {"name": "default"}),  # hard subset
    ]
    ds = None
    for hf_name, hf_kwargs in _HF_SOURCES:
        try:
            from datasets import load_dataset  # type: ignore

            ds = load_dataset(hf_name, split="train", trust_remote_code=True, **hf_kwargs)
            logger.info("Loaded HuggingFace dataset: %s (%d items)", hf_name, len(ds))
            break
        except Exception as exc:
            logger.warning("Could not load %s: %s; trying next source.", hf_name, exc)
    if ds is None:
        logger.warning(
            "All MATH dataset sources failed. Proceeding with GSM8K only. "
            "To load offline: download from https://github.com/hendrycks/math "
            "and pass --math-data <path_to_jsonl>."
        )
        return []

    pairs = []
    for item in ds:
        level_str = item.get("level", "Level 5")
        try:
            level = int(level_str.split()[-1])
        except (ValueError, IndexError):
            level = 5
        if level > max_difficulty:
            continue
        question = item.get("problem", "").strip()
        solution = item.get("solution", "")
        boxed = _extract_boxed(solution)
        if not boxed:
            continue
        numeric = _boxed_to_numeric(boxed)
        if not numeric:
            continue
        pairs.append({"question": question, "gold_final": numeric})

    if pairs:
        out_p = Path(cache_path)
        out_p.parent.mkdir(parents=True, exist_ok=True)
        with out_p.open("w", encoding="utf-8") as f:
            for p_item in pairs:
                f.write(json.dumps(p_item) + "\n")
        logger.info("Cached %d MATH numeric pairs to %s", len(pairs), out_p)
    else:
        logger.warning("No MATH pairs passed the numeric filter; check the dataset.")
    return pairs


# ---------------------------------------------------------------------------
# Reward
# ---------------------------------------------------------------------------

# ---------------------------------------------------------------------------
# Self-play verification cascade
# ---------------------------------------------------------------------------
# Routes each self-play group to the right verification tool based on
# problem type and difficulty, then gates the GRPO update on the result.
# Returns False (= skip group) when no tool can verify cleanly, preventing
# circular PRM-only reward from anchoring the training signal.

_FINAL_ANSWER_RE = re.compile(r"final answer[:\s]*([^\n]+)", re.I)

# Problem-type routing tables
_PAL_TOPICS = frozenset({"arithmetic", "algebra", "prealgebra", "grounded"})
_SYMPY_TOPICS = frozenset({
    "number_theory", "intermediate_algebra", "precalculus",
    "counting_and_probability",
})
_EXCLUDE_TOPICS = frozenset({"geometry"})  # spatial reasoning; cannot verify programmatically


def _extract_final_answer(solution: str) -> Optional[str]:
    """Extract the text after 'Final Answer:' from a solution string."""
    m = _FINAL_ANSWER_RE.search(solution)
    return m.group(1).strip() if m else None


def _pal_eval(answer_str: str) -> Optional[float]:
    """Tier 1: arithmetic / basic algebra via safe eval (no builtins, no names)."""
    try:
        val = eval(answer_str, {"__builtins__": {}}, {})  # noqa: S307
        f = float(val)
        return None if f != f else f  # NaN guard
    except Exception:
        return None


def _sympy_eval(answer_str: str) -> Optional[float]:
    """Tier 2: symbolic evaluation via SymPy for algebra, number theory, etc."""
    try:
        from sympy import sympify, N as _N  # type: ignore

        f = float(_N(sympify(answer_str), 15))
        return None if f != f else f  # NaN guard
    except Exception:
        return None
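# Illustrative tier behaviour (values assume standard Python / SymPy semantics):
#   _pal_eval("3 * (4 + 5)")  -> 27.0    # plain arithmetic
#   _pal_eval("x + 1")        -> None    # names are blocked -> NameError
#   _sympy_eval("sqrt(16)/2") -> 2.0     # symbolic, evaluated numerically
#   _sympy_eval("x + 1")      -> None    # stays symbolic -> float() raises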
def _verify_self_play_answer(
    solutions: List[str],
    target_topic: str,
    target_difficulty: float,
) -> bool:
    """
    Tiered verification cascade for self-play groups.

    Returns True only when a majority of solutions agree on an answer that an
    independent tool (PAL eval or SymPy) can verify as a finite number.

    Returns False (drop this group, no gradient) when:
      * topic is geometry (spatial reasoning, can't verify programmatically)
      * difficulty >= 4.0 (should have been blocked at generation, guard here too)
      * no tool can parse a consistent numerical answer
      * fewer than half of solutions agree on the majority answer

    Coverage for GSM8K + MATH:
      GSM8K                                     → PAL tier, ~95%+ verified
      MATH L1-L2 algebra                        → PAL + SymPy fallback, ~80% verified
      MATH number theory / intermediate algebra → SymPy primary, ~70% verified
      MATH geometry                             → excluded entirely (~3-5% of MATH)
      MATH L4-L5                                → excluded at generation time (see call site)
    """
    topic = target_topic.lower().replace(" ", "_")
    # Hard exclusions (guard even if called after generation-time check)
    if topic in _EXCLUDE_TOPICS or target_difficulty >= 4.0:
        return False

    answers: List[float] = []
    for sol in solutions:
        raw = _extract_final_answer(sol)
        if raw is None:
            continue
        # Explicit `is None` fallbacks: a legitimate answer of 0.0 is falsy,
        # so `or`-chaining the two tiers would mis-route zero answers.
        val: Optional[float]
        if topic in _PAL_TOPICS or target_difficulty <= 2:
            val = _pal_eval(raw)
            if val is None:
                val = _sympy_eval(raw)
        elif topic in _SYMPY_TOPICS:
            val = _sympy_eval(raw)
            if val is None:
                val = _pal_eval(raw)
        else:
            # Unknown topic: try PAL first, then SymPy.
            val = _pal_eval(raw)
            if val is None:
                val = _sympy_eval(raw)
        if val is not None:
            answers.append(round(val, 6))

    if not answers:
        return False  # Tier 4: cannot verify -> exclude
    majority = max(set(answers), key=answers.count)
    return answers.count(majority) >= max(1, len(solutions) // 2)
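# Illustrative gate, assuming a group of K=4 decoded solutions whose
# "Final Answer:" lines parse to [7.0, 7.0, 12.0, <unparseable>]:
#   answers  = [7.0, 7.0, 12.0]   (the unparseable one is dropped)
#   majority = 7.0, count 2 >= max(1, 4 // 2) = 2  -> True (keep group)
# With parses [7.0, 12.0, None, None] the majority count is 1 < 2 -> False.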
def compute_grounded_reward(
    question: str,
    solution: str,
    gold_final: str,
    math_env: CurriculumMathEnvironment,
) -> Dict[str, float]:
    """Score a solution against a known gold answer (grounded path).

    Returns a dict with:
      combined_score  : 0.50×correct + 0.40×process(prm_final, prm_mean) + 0.10×fmt
      step_accuracy   : fraction of PRM steps rated > 0.5 (the core process metric)
      prm_mean_score  : PRM mean across all steps
      prm_final_score : PRM score on the final reasoning step
      gt_match        : bool, whether pred matches gold
      format_score    : format compliance score
    """
    result = math_env.compute_grounded_reward(
        question=question,
        solution=solution,
        gold_final=gold_final,
    )
    return {
        "combined_score": float(result.get("combined_score", 0.0)),
        "step_accuracy": float(result.get("step_accuracy", 0.0)),
        "lccp": float(result.get("lccp", 0.0)),
        "prm_mean_score": float(result.get("prm_mean_score", 0.0)),
        "prm_final_score": float(result.get("prm_final_score", 0.0)),
        "gt_match": bool(result.get("gt_match", False)),
        "format_score": float(result.get("format_score", 0.0)),
    }
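# Worked example of the combined score (numbers are illustrative): a correct
# answer (1.0) with prm_final=0.9, prm_mean=0.8 and perfect format (1.0):
#   process  = 0.60*0.9 + 0.40*0.8              = 0.86
#   combined = 0.50*1.0 + 0.40*0.86 + 0.10*1.0  = 0.944
# A wrong answer with the same process/format scores drops to 0.444, so the
# binary correctness term dominates but step quality still separates groups.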
def compute_self_play_reward(
    question: str,
    solution: str,
    target_topic: str,
    target_difficulty: float,
    math_env: CurriculumMathEnvironment,
) -> Tuple[float, float, float, Dict]:
    """Score a self-generated question + solution (self-play path).

    Returns (combined_reward, question_reward, solution_reward, q_metrics).

    Reward breakdown: R = 0.40×question_quality + 0.60×solution_quality,
    where question_quality captures topic match, difficulty fit, clarity,
    novelty, and solvability, completing the Theme #4 self-improvement loop
    where the model is rewarded for generating *good challenges*, not only
    for solving them.

    q_metrics contains the full question quality breakdown:
      topic_match, difficulty_fit, clarity, novelty, solvability, overall_score
    """
    result = math_env.compute_reward(
        question=question,
        solution=solution,
        target_topic=target_topic,
        target_difficulty=target_difficulty,
    )
    combined = float(result["combined_score"])
    sol_score = result.get("solution_metrics", {})
    s_reward = float(sol_score.get("overall_score", 0.0)) if isinstance(sol_score, dict) else 0.0
    # question_reward is NOT a top-level key in compute_reward()'s return dict.
    # The question quality score lives inside question_metrics["overall_score"].
    # Key mapping from QuestionEvalResult.to_dict():
    #   overall_score     -> scalar (overall question quality)
    #   topic_match       -> scalar
    #   difficulty_score  -> scalar (fit to target difficulty; named _score, not _fit)
    #   clarity           -> scalar
    #   solvability_score -> scalar (the dict version is under "solvability"; don't use that)
    #   novelty_combined  -> scalar (the dict version is under "novelty"; don't use that)
    q_metrics_raw = result.get("question_metrics", {}) or {}
    # Use the gated question reward (zeroed when the solution is invalid); this
    # is what actually contributed to combined_score, not the raw overall_score.
    q_reward = float(result.get("effective_question_reward", q_metrics_raw.get("overall_score", 0.0)))
    q_metrics: Dict = {
        "overall_score": q_reward,
        "topic_match": float(q_metrics_raw.get("topic_match", 0.0)),
        "difficulty_fit": float(q_metrics_raw.get("difficulty_score", 0.0)),
        "clarity": float(q_metrics_raw.get("clarity", 0.0)),
        "novelty": float(q_metrics_raw.get("novelty_combined", 0.0)),
        "solvability": float(q_metrics_raw.get("solvability_score", 0.0)),
        # Chain integrity score from the Phase 2+ unified calculator (None if inactive)
        "sp_chain_integrity_score": result.get("sp_chain_integrity_score"),
    }
    return combined, q_reward, s_reward, q_metrics
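# Worked example of the self-play mix (numbers are illustrative): with
# question_quality = 0.7 and solution_quality = 0.9,
#   R = 0.40*0.7 + 0.60*0.9 = 0.28 + 0.54 = 0.82
# so a sloppy question caps the reward even when the solution is strong.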
def generate_question(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    instruction: str,
    max_new_tokens: int,
    device: torch.device,
    temperature: float = 0.85,
) -> str:
    """Generate a math question from a curriculum instruction.

    Uses centralized prompts from src/config/prompts.py to ensure consistency
    across SFT training, GRPO, PPO, and inference.

    Returns the raw decoded question text (no special tokens).
    """
    # Use centralized prompt configuration
    messages = create_generator_messages(instruction)
    try:
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        # Fallback if chat template is missing
        system = messages[0]["content"]
        user = messages[1]["content"]
        prompt = f"{system}\n\n{user}\n"
    enc = tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=512
    ).to(device)
    prompt_len = enc["input_ids"].shape[1]
    stop_ids: List[int] = []
    if tokenizer.eos_token_id is not None:
        stop_ids.append(tokenizer.eos_token_id)
    im_end = tokenizer.convert_tokens_to_ids("<|im_end|>")
    if isinstance(im_end, int) and im_end not in stop_ids:
        stop_ids.append(im_end)
    out = model.generate(
        input_ids=enc["input_ids"],
        attention_mask=enc["attention_mask"],
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=0.95,
        pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
        eos_token_id=stop_ids or None,
        use_cache=True,
    )
    return tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True).strip()
# ---------------------------------------------------------------------------
# Generation
# ---------------------------------------------------------------------------

def generate_questions_batched(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    instruction: str,
    K_q: int,
    max_new_tokens: int,
    temperature: float,
    device: torch.device,
) -> Tuple[List[str], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
    """
    Generate K_q question candidates from a single curriculum instruction in
    one batched model.generate() call. Returns the same four-tuple as
    ``generate_solutions_batched`` so the question token IDs can be passed
    directly to ``grpo_loss_for_group`` for the question-level GRPO update.

    Uses the same centralized prompts (``create_generator_messages``) as
    ``generate_question()`` so the chat format is identical whether running
    single-question or batched two-phase generation.

    Returns:
        questions      : K_q decoded question strings
        input_ids_list : K_q full (prompt+response) token ID tensors
        response_masks : K_q bool masks (True = non-pad response token)
        old_log_probs  : K_q scalar tensors (sum log π_old over response),
                         no_grad; used as denominator in the IS ratio.
    """
    messages = create_generator_messages(instruction)
    try:
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        # Fallback if chat template is missing (mirrors generate_question())
        system = messages[0]["content"]
        user = messages[1]["content"]
        prompt = f"{system}\n\n{user}\n"
    stop_ids = _build_stop_token_ids(tokenizer)
    pad_id: int = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )
    enc = tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=512
    ).to(device)
    prompt_len: int = enc["input_ids"].shape[1]
    input_ids_batch = enc["input_ids"].expand(K_q, -1).contiguous()
    attn_mask_batch = enc["attention_mask"].expand(K_q, -1).contiguous()
    model.eval()
    with torch.no_grad():
        out = model.generate(
            input_ids=input_ids_batch,
            attention_mask=attn_mask_batch,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=0.95,
            pad_token_id=pad_id,
            eos_token_id=stop_ids,
            use_cache=True,
        )
    questions: List[str] = []
    input_ids_list: List[torch.Tensor] = []
    response_masks: List[torch.Tensor] = []
    pad_id_t = torch.tensor(pad_id, device=device, dtype=out.dtype)
    for i in range(K_q):
        full_ids = out[i]
        response_section = full_ids[prompt_len:]
        mask = torch.zeros(full_ids.shape[0], dtype=torch.bool, device=device)
        mask[prompt_len:] = response_section != pad_id_t
        question = tokenizer.decode(response_section, skip_special_tokens=True).strip()
        questions.append(question)
        input_ids_list.append(full_ids)
        response_masks.append(mask)
    # Single batched forward pass for all K_q old log-probs (same trick as solutions).
    old_log_probs: List[torch.Tensor] = []
    with torch.no_grad():
        attn_mask_lp = (out != pad_id_t)
        attn_mask_lp[:, :prompt_len] = True
        batch_logits = model(
            input_ids=out,
            attention_mask=attn_mask_lp.long(),
            use_cache=False,
            return_dict=True,
        ).logits  # [K_q, total_len, vocab]
        for i in range(K_q):
            full_ids = out[i]
            mask = response_masks[i]
            shift_logits = batch_logits[i, :-1]
            shift_labels = full_ids[1:]
            shift_mask = mask[1:]
            lp_tokens = F.log_softmax(shift_logits, dim=-1)[
                torch.arange(shift_logits.size(0), device=device),
                shift_labels,
            ]
            resp_lps = lp_tokens[shift_mask]
            old_log_probs.append(
                resp_lps.sum().detach() if resp_lps.numel() > 0
                else torch.tensor(0.0, device=device)
            )
    return questions, input_ids_list, response_masks, old_log_probs
def _build_stop_token_ids(tokenizer: AutoTokenizer) -> List[int]:
    """
    Return a list of token IDs that should stop generation.

    Qwen2.5-chat models end turns with <|im_end|> (ID 151645). If that
    token is not the same as eos_token_id, we include both so that .generate()
    halts cleanly instead of running to max_new_tokens and emitting repetitive
    garbage.
    """
    stop_ids: List[int] = []
    if tokenizer.eos_token_id is not None:
        stop_ids.append(tokenizer.eos_token_id)
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    if isinstance(im_end_id, int) and im_end_id not in stop_ids:
        stop_ids.append(im_end_id)
    return stop_ids or None  # type: ignore[return-value]
def generate_solutions_batched(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt: str,
    K: int,
    max_new_tokens: int,
    temperature: float,
    device: torch.device,
) -> Tuple[List[str], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
    """
    Generate K solutions for a prompt in a **single batched** model.generate() call.

    Batching all K sequences together achieves near-100% GPU utilisation vs
    the old sequential loop (which was <20% utilised). On an A100 with K=8,
    this is typically 4-8× faster than K sequential calls.

    ``prompt`` must come from ``math_env.format_solution_prompt(question)``
    so the chat-template system/user wrapping exactly matches the SFT
    training format.

    Returns:
        solutions      : K decoded strings (prompt stripped, specials removed)
        input_ids_list : K full (prompt+response) token ID tensors
        response_masks : K bool masks (True = non-pad response token)
        old_log_probs  : K scalar tensors, sum(log π_old(token)) over response,
                         computed no_grad; used for the IS clip ratio in the loss.
    """
    stop_ids = _build_stop_token_ids(tokenizer)
    pad_id: int = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )
    enc = tokenizer(
        prompt,
        return_tensors="pt",
        padding=False,
        truncation=True,
        max_length=1024,
    ).to(device)
    prompt_len: int = enc["input_ids"].shape[1]
    # Expand prompt K times along the batch dimension (no data copy).
    input_ids_batch = enc["input_ids"].expand(K, -1).contiguous()
    attn_mask_batch = enc["attention_mask"].expand(K, -1).contiguous()
    model.eval()
    with torch.no_grad():
        out = model.generate(
            input_ids=input_ids_batch,
            attention_mask=attn_mask_batch,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=0.9,
            pad_token_id=pad_id,
            eos_token_id=stop_ids,
            use_cache=True,
        )
    # out: [K, prompt_len + padded_response_len]

    # ── 1. Build masks and decode solutions ──────────────────────────────────
    solutions: List[str] = []
    input_ids_list: List[torch.Tensor] = []
    response_masks: List[torch.Tensor] = []
    pad_id_t = torch.tensor(pad_id, device=device, dtype=out.dtype)
    for i in range(K):
        full_ids = out[i]
        response_section = full_ids[prompt_len:]
        mask = torch.zeros(full_ids.shape[0], dtype=torch.bool, device=device)
        mask[prompt_len:] = response_section != pad_id_t
        solution = tokenizer.decode(response_section, skip_special_tokens=True)
        solutions.append(solution)
        input_ids_list.append(full_ids)
        response_masks.append(mask)

    # ── 2. Batched old_log_probs: ONE forward pass for all K sequences ───────
    # The old sequential approach called compute_sequence_log_prob K times
    # (K separate CPU->GPU round-trips + K forward passes). A single batched
    # forward pass over out[K, total_len] gives the same result K× faster.
    #
    # Attention mask: always attend to prompt tokens; attend to response tokens
    # only where they are non-pad. This matches what the model saw during
    # model.generate() and prevents padding from distorting log probs.
    old_log_probs: List[torch.Tensor] = []
    with torch.no_grad():
        attn_mask_lp = (out != pad_id_t)     # [K, total_len]
        attn_mask_lp[:, :prompt_len] = True  # prompt always attended
        batch_logits = model(
            input_ids=out,
            attention_mask=attn_mask_lp.long(),
            use_cache=False,
            return_dict=True,
        ).logits  # [K, total_len, vocab]
        for i in range(K):
            full_ids = out[i]
            mask = response_masks[i]
            shift_logits = batch_logits[i, :-1]  # [total_len-1, vocab]
            shift_labels = full_ids[1:]          # [total_len-1]
            shift_mask = mask[1:]                # [total_len-1]
            lp_tokens = F.log_softmax(shift_logits, dim=-1)[
                torch.arange(shift_logits.size(0), device=device),
                shift_labels,
            ]  # [total_len-1]
            resp_lps = lp_tokens[shift_mask]
            old_log_probs.append(
                resp_lps.sum().detach() if resp_lps.numel() > 0
                else torch.tensor(0.0, device=device)
            )
    return solutions, input_ids_list, response_masks, old_log_probs
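# Shape sketch of the generated batch (illustrative: K=3, prompt_len=5,
# longest response 4 tokens; P = prompt token, r = response token, . = pad):
#   out[0]: P P P P P r r r r      response_mask: 0 0 0 0 0 1 1 1 1
#   out[1]: P P P P P r r . .      response_mask: 0 0 0 0 0 1 1 0 0
#   out[2]: P P P P P r r r .      response_mask: 0 0 0 0 0 1 1 1 0
# Right-padding of the shorter responses is why the mask is built from
# `response_section != pad_id` rather than a fixed length.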
def compute_sequence_log_prob(
    model: AutoModelForCausalLM,
    input_ids: torch.Tensor,
    response_mask: torch.Tensor,
) -> torch.Tensor:
    """
    Forward pass through the model to get the sum of log probs for response tokens.

    Returns a scalar tensor (differentiable).
    """
    # input_ids: [seq_len] -> unsqueeze to [1, seq_len]
    ids = input_ids.unsqueeze(0)
    # Causal LM: logits[i] predicts token[i+1]
    outputs = model(input_ids=ids, use_cache=False, return_dict=True)
    logits = outputs.logits[0]  # [seq_len, vocab]
    # Shift: predict token t+1 from logit at position t
    shift_logits = logits[:-1]      # [seq_len-1, vocab]
    shift_labels = input_ids[1:]    # [seq_len-1]
    shift_mask = response_mask[1:]  # [seq_len-1] (response tokens)
    log_probs = F.log_softmax(shift_logits, dim=-1)  # [seq_len-1, vocab]
    token_log_probs = log_probs[
        torch.arange(shift_logits.size(0), device=shift_logits.device),
        shift_labels,
    ]  # [seq_len-1]
    # Sum log probs over response tokens only
    response_log_probs = token_log_probs[shift_mask]
    if response_log_probs.numel() == 0:
        return torch.tensor(0.0, requires_grad=True, device=input_ids.device)
    return response_log_probs.sum()
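# Alignment sketch for the shift-by-one indexing above (illustrative):
#   position:    0    1    2    3
#   input_ids:  t0   t1   t2   t3
#   logits[i] predicts t(i+1), so logits[0..2] are scored against t1..t3.
# shift_mask comes from response_mask[1:] because the label at shifted index
# i is input_ids[i+1]: a position counts as "response" when its *label* is a
# response token, regardless of whether the context ends in prompt or response.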
# ---------------------------------------------------------------------------
# GRPO update for one question group
# ---------------------------------------------------------------------------

def grpo_loss_for_group(
    model: AutoModelForCausalLM,
    input_ids_list: List[torch.Tensor],
    response_masks: List[torch.Tensor],
    rewards: List[float],
    old_log_probs: List[torch.Tensor],
    clip_eps: float = 0.2,
    kl_coef: float = 0.0,
    ref_model: Optional[AutoModelForCausalLM] = None,
    eps: float = 1e-8,
) -> Optional[torch.Tensor]:
    """
    Compute the GRPO loss for a group of K solutions to the same question.

    IS clip (``clip_eps > 0``):
        ratio  = π_θ(response) / π_old(response)   [sequence level]
        L_GRPO = -min(ratio × A, clip(ratio, 1-ε, 1+ε) × A) / T

    Reference-policy KL penalty (``kl_coef > 0``, ``ref_model`` required):
        KL(π_θ ‖ π_ref) ≈ (log π_θ - log π_ref) / T   per sequence
        L_total = L_GRPO + β × KL

    The KL term acts as an anchor: it prevents the policy from drifting so
    far from its starting point that it forgets the SFT knowledge baked in
    during dual_task_v1 fine-tuning. β=0.04 is a conservative starting
    value (matches the DeepSeekMath GRPO default).

    Returns None if all rewards are identical (zero gradient signal).
    """
    rewards_arr = np.array(rewards, dtype=np.float32)
    std_r = rewards_arr.std()
    if std_r < eps:
        return None
    mean_r = rewards_arr.mean()
    advantages = (rewards_arr - mean_r) / (std_r + eps)
    advantages = np.clip(advantages, -5.0, 5.0)

    _device = next(model.parameters()).device
    group_loss = torch.tensor(0.0, device=_device)
    n_valid = 0
    model.train()
    for ids, mask, adv, old_lp in zip(
        input_ids_list, response_masks, advantages, old_log_probs
    ):
        new_lp = compute_sequence_log_prob(model, ids, mask)  # differentiable
        n_response = int(mask[1:].sum().item())
        if n_response == 0:
            continue
        adv_t = torch.tensor(adv, dtype=new_lp.dtype, device=_device)
        # ── GRPO surrogate (with optional IS clip) ───────────────────────
        if clip_eps > 0:
            ratio = torch.exp(new_lp - old_lp.to(_device).detach())
            surr_unclipped = ratio * adv_t / n_response
            surr_clipped = (
                torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
                * adv_t / n_response
            )
            loss_i = -torch.min(surr_unclipped, surr_clipped)
        else:
            loss_i = -(adv_t * new_lp / n_response)
        # ── Reference-policy KL penalty ──────────────────────────────────
        # KL(π_θ ‖ π_ref) ≈ mean_token(log π_θ - log π_ref)
        # Adding +β×KL to the minimisation objective penalises drift from
        # the reference (frozen) checkpoint. This is differentiable through
        # new_lp; ref_lp is always detached (no grad through the frozen model).
        if kl_coef > 0.0 and ref_model is not None:
            with torch.no_grad():
                ref_lp = compute_sequence_log_prob(ref_model, ids, mask)
            kl_per_token = (new_lp - ref_lp.to(_device).detach()) / n_response
            loss_i = loss_i + kl_coef * kl_per_token
        group_loss = group_loss + loss_i
        n_valid += 1
    if n_valid == 0:
        return None
    return group_loss / n_valid
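# Worked example of the clipped surrogate (illustrative numbers, ε=0.2):
# suppose A = +1 and the policy has drifted so ratio = π_θ/π_old = 1.5.
#   unclipped: 1.5 × 1 = 1.5     clipped: clamp(1.5, 0.8, 1.2) × 1 = 1.2
#   surrogate = min(1.5, 1.2) = 1.2, so the extra drift earns no extra reward.
# For A = -1 the min picks the *unclipped* -1.5, so moving further toward a
# bad sample is penalised in full; the clip only caps the upside.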
# ---------------------------------------------------------------------------
# Evaluation helpers
# ---------------------------------------------------------------------------

def _log_eval_result(label: str, res: Dict, best: Optional[float]) -> None:
    """Print a structured evaluation summary that mirrors the training objective."""
    cs = float(res.get("combined_score", 0.0))
    cr = float(res.get("correct_rate", 0.0))
    step_a = float(res.get("step_accuracy", 0.0))
    lccp = float(res.get("lccp", 0.0))
    prm = float(res.get("prm_mean", 0.0))
    prm_fin = float(res.get("prm_final", 0.0))
    fmt = float(res.get("format_mean", 0.0))
    n_sc = int(res.get("n_scored", res.get("total", 0)))
    fa_acc = float(res.get("final_answer_accuracy", cr))
    pak = res.get("pass_at_k")
    pak_k = int(res.get("pass_at_k_k", 4))
    best_str = f" (best={best:.4f})" if best is not None else ""
    logger.info(
        "Training Score [%s]: %.4f%s | n=%d",
        label, cs, best_str, n_sc,
    )
    logger.info(
        "  Components    : 0.50×correct(%.1f%%) + 0.40×process + 0.10×fmt(%.3f)",
        100 * cr, fmt,
    )
    logger.info(
        "  Process score : prm_mean=%.3f prm_final=%.3f -> weighted=%.3f",
        prm, prm_fin, 0.60 * prm_fin + 0.40 * prm,
    )
    logger.info(
        "  Step accuracy : %.1f%% (bag-of-steps: fraction of steps PRM >0.5)",
        100 * step_a,
    )
    logger.info(
        "  Chain integrity (LCCP): %.1f%% = fraction of steps before first failure\n"
        "    [LCCP=100%% means all steps correct; LCCP=0%% means the first step is wrong]",
        100 * lccp,
    )
    if pak is not None:
        logger.info(
            "  pass@%d (T=0.8): %.1f%% | greedy correct: %.1f%% "
            "(ceiling vs floor gap)",
            pak_k, 100 * pak, 100 * cr,
        )
    logger.info(
        "  (debug) final-answer accuracy: %.1f%%",
        100 * fa_acc,
    )
def evaluate_policy(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    eval_data_path: str,
    max_samples: int,
    max_new_tokens: int,
    math_env: Optional[Any] = None,
    pass_at_k: int = 4,
) -> Dict[str, object]:
    """Run GSM8K evaluation using the SAME reward formula as GRPO training.

    When *math_env* is supplied, a ``reward_fn`` is constructed that calls
    ``math_env.compute_grounded_reward(question, solution, gold)``. This
    returns ``combined_score = 0.50×correct + 0.40×process(0.60×prm_final
    + 0.40×prm_mean) + 0.10×format``, making the eval metric IDENTICAL to
    the GRPO training objective. Any improvement in step quality, chain
    integrity, or format compliance shows up immediately in the accuracy
    number instead of being hidden behind the coarse binary final-answer
    signal.
    """
    if not Path(eval_data_path).exists():
        return {"accuracy": 0.0, "combined_score": 0.0, "total": 0}
    model.eval()
    reward_fn = None
    if math_env is not None:
        import logging as _log_mod

        _mec_logger = _log_mod.getLogger("src.rl.math_environment_curriculum")
        _prm_logger = _log_mod.getLogger("src.rl.prm_scorer")

        def reward_fn(question: str, solution: str, gold: str) -> Dict:
            """Thin wrapper that silences per-sample INFO logs during eval."""
            _old_mec = _mec_logger.level
            _old_prm = _prm_logger.level
            _mec_logger.setLevel(_log_mod.WARNING)
            _prm_logger.setLevel(_log_mod.WARNING)
            try:
                return math_env.compute_grounded_reward(question, solution, gold)
            finally:
                _mec_logger.setLevel(_old_mec)
                _prm_logger.setLevel(_old_prm)

    results = evaluate_gsm8k(
        model=model,
        tokenizer=tokenizer,
        data_path=eval_data_path,
        max_samples=max_samples,
        max_new_tokens=max_new_tokens,
        reward_fn=reward_fn,
        pass_at_k=pass_at_k,
        dataset_name=_infer_eval_dataset_name(eval_data_path),
    )
    model.train()
    return results
# ---------------------------------------------------------------------------
# Main training loop
# ---------------------------------------------------------------------------

def main() -> None:
    parser = argparse.ArgumentParser(description="GRPO training for self-improvement math")
    parser.add_argument("--base-model", default="checkpoints/dual_task_v1")
    parser.add_argument("--output-dir", default="checkpoints/grpo")
    parser.add_argument("--gsm8k-data", default="data/sft/gsm8k_sft.jsonl")
    parser.add_argument("--eval-data-path", default="data/sft/dual_task_val.jsonl")
    parser.add_argument("--num-iterations", type=int, default=30)
    parser.add_argument(
        "--group-size", type=int, default=4,
        help="K: number of solutions per question per GRPO group (default 4).",
    )
    parser.add_argument(
        "--q-group-size", type=int, default=1,
        help="K_q: question candidates per self-play group (default 1 = disabled). "
             "When >= 2, a second question-level GRPO update is added: K_q questions are "
             "sampled from the same instruction, each solved group-size times; the "
             "per-question reward (mean of its M solution rewards) drives a GRPO update "
             "on the question tokens. Recommended: 2 with --group-size 4 to keep "
             "total self-play compute the same as K_q=1 with group-size 8.",
    )
    parser.add_argument(
        "--questions-per-iter", type=int, default=16,
        help="Number of questions per training iteration (default 16).",
    )
    parser.add_argument("--learning-rate", type=float, default=5e-6)
    parser.add_argument("--max-new-tokens", type=int, default=400)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--eval-every", type=int, default=5)
    parser.add_argument("--eval-max-samples", type=int, default=250)
    parser.add_argument("--eval-max-new-tokens", type=int, default=512)
    parser.add_argument(
        "--eval-pass-at-k", type=int, default=0,
        help="Number of sampled solutions per eval problem for pass@k (0 to disable). "
             "Makes eval directly comparable to training batch_acc (both K samples at T=0.8). "
             "Disabled by default; enable with e.g. --eval-pass-at-k 4 for demo runs only "
             "(adds K×eval_samples extra forward passes).",
    )
    parser.add_argument("--use-prm", dest="use_prm", action="store_true", default=True)
    parser.add_argument("--no-prm", dest="use_prm", action="store_false")
    parser.add_argument("--prm-model", default="Qwen/Qwen2.5-Math-PRM-7B")
    parser.add_argument("--skip-initial-eval", action="store_true")
    parser.add_argument("--run-name", default=None)
    parser.add_argument("--max-grad-norm", type=float, default=1.0)
    parser.add_argument(
        "--kl-coef", type=float, default=0.04,
        help="Reference-policy KL penalty coefficient β. 0 = disabled. Default 0.04.",
    )
    parser.add_argument(
        "--math-data", type=str, default=None,
        help="Path to MATH dataset JSONL. If absent, downloads from HuggingFace "
             "(competition_math) and caches to data/math/math_numeric.jsonl.",
    )
    parser.add_argument(
        "--math-mix-ratio", type=float, default=0.3,
        help="Fraction of each question batch drawn from MATH (vs GSM8K). "
             "0 = GSM8K only, 1 = MATH only. Default 0.3.",
    )
    parser.add_argument(
        "--math-mix-ratio-late", type=float, default=None,
        help="If set, ramp the MATH fraction from --math-mix-ratio to this value "
             "starting at iter 15 (linear ramp over the next 10 iters). "
             "Example: --math-mix-ratio 0.3 --math-mix-ratio-late 0.5 "
             "raises difficulty progressively once the policy is stable.",
    )
    parser.add_argument(
        "--math-ramp-start", type=int, default=15,
        help="Iteration at which to begin the MATH ratio ramp. Default 15.",
    )
    parser.add_argument(
        "--math-max-difficulty", type=int, default=3,
        help="Maximum MATH difficulty level to include (1-5). Default 3.",
    )
    parser.add_argument(
        "--clip-eps", type=float, default=0.2,
        help="Importance-sampling clip ratio ε (PPO-style clip applied inside GRPO). "
             "0 = disabled (plain GRPO). Default 0.2.",
    )
    parser.add_argument(
        "--warmup-iters", type=int, default=3,
        help="Number of linear LR warmup iterations before cosine decay. Default 3.",
    )
    parser.add_argument(
        "--min-lr-ratio", type=float, default=0.1,
        help="Cosine decay floor as a fraction of peak LR (default 0.1 = 10%%).",
    )
    parser.add_argument(
        "--difficulty-alpha", type=float, default=2.0,
        help="Sharpness of difficulty-weighted question sampling. "
             "Higher = stronger preference for on-the-margin questions (win_rate ≈ 0.5). "
             "0 = uniform random (default behaviour). Default 2.0.",
    )
    parser.add_argument(
        "--overlong-filter", dest="overlong_filter",
        action="store_true", default=True,
        help="Skip solutions that hit max-new-tokens (truncated = no Final Answer). Default on.",
    )
    parser.add_argument(
        "--no-overlong-filter", dest="overlong_filter", action="store_false",
        help="Disable overlong-response filtering.",
    )
    parser.add_argument(
        "--save-every", type=int, default=1,
        help="Save a full checkpoint every N iterations (default 1 = every iter). "
             "Best-policy is always saved when accuracy improves, independently of this flag.",
    )
    parser.add_argument(
        "--keep-last", type=int, default=0,
        help="Keep only the last K iter_* checkpoints on disk (0 = keep all). "
             "best_policy/ is never pruned.",
    )
    parser.add_argument(
        "--self-play-ratio", type=float, default=0.3,
        help="Fraction of each question batch that uses SELF-PLAY (the model generates "
             "the question from a curriculum instruction, then solves it, rewarded on "
             "0.40 × question_quality + 0.60 × solution_quality). "
             "The remaining (1 - ratio) uses GROUNDED questions from GSM8K / MATH with "
             "gold-answer reward. "
             "0.0 = fully grounded (original behaviour), 1.0 = fully self-play. "
             "Default 0.3; mirrors the PPO default of 30%% grounded / 70%% self-play "
             "(inverted here because grounded is our primary accuracy signal).",
    )
    # ── Phase-curriculum parameters ──────────────────────────────────────────
    parser.add_argument(
        "--min-warmup", type=int, default=10,
        help="Minimum iterations in Phase 1 (grounded-only) before considering graduation "
             "to Phase 2 (self-play ramp). Prevents graduating on a lucky early batch. "
             "Default 10.",
    )
    parser.add_argument(
        "--selfplay-gt-thresh", type=float, default=0.55,
        help="gt_match_rate threshold required to graduate from Phase 1 to Phase 2. "
             "Measures raw answer correctness (SymPy exact match), not reward-gamed "
             "combined_score. Default 0.55.",
    )
    parser.add_argument(
        "--selfplay-grounded-thresh", type=float, default=0.60,
        help="grounded_accuracy (combined_score > 0.5) threshold for Phase 1 graduation. "
             "Default 0.60.",
    )
    parser.add_argument(
        "--selfplay-step-thresh", type=float, default=0.65,
        help="step_accuracy (PRM steps rated > 0.5) threshold for Phase 1 graduation. "
             "Ensures the model has learned clean step format before entering self-play. "
             "Default 0.65.",
    )
    parser.add_argument(
        "--selfplay-ramp-iters", type=int, default=20,
        help="Number of iterations to ramp the self-play ratio from ~0%% to --self-play-ratio "
             "(Phase 2). The grounded anchor stays at >= 30%% throughout. Default 20.",
    )
    parser.add_argument(
        "--grounded-floor", type=float, default=0.50,
        help="Minimum gt_match_rate to maintain during Phase 3. If it falls below this "
             "value, self-play is suspended until grounded performance recovers. "
             "Should be slightly below --selfplay-gt-thresh. Default 0.50.",
    )
    # ── Unified accuracy calculator parameters ───────────────────────────────
    parser.add_argument(
        "--extractor-model", default="Qwen/Qwen2.5-0.5B-Instruct",
        help="Small model used for step chain extraction in the unified accuracy "
             "calculator (Phase 2+). Loaded in 4-bit to minimise VRAM. "
             "Default Qwen/Qwen2.5-0.5B-Instruct.",
    )
    parser.add_argument(
        "--extraction-cache", default=None,
        help="Path to a pre-built JSON extraction cache from "
             "scripts/precompute_extraction_cache.py. When provided, grounded-data "
             "extractions are served from cache instead of calling the extractor LLM "
             "at training time. Only novel self-play solutions require live extraction. "
             "Default None (extraction always uses the LLM).",
    )
    args = parser.parse_args()
    # ── Run identity ─────────────────────────────────────────────────────────
    # Establish run_name first; everything that follows (including log paths)
    # derives from it.
    run_name = args.run_name or f"grpo_{datetime.now():%Y%m%d_%H%M%S}"
    out_dir = Path(args.output_dir) / run_name
    out_dir.mkdir(parents=True, exist_ok=True)

    # ── Log directory ────────────────────────────────────────────────────────
    # One canonical directory for ALL run artefacts that are not model weights:
    #   console_output.log : full terminal mirror (logger.* + print + tqdm)
    #   config.json        : serialised CLI args for reproducibility
    #   metrics.csv        : one row per iteration, written live
    #   summary.json       : written at the end of training
    log_dir = Path("logs") / "grpo" / run_name
    log_dir.mkdir(parents=True, exist_ok=True)

    # ── Console log file ─────────────────────────────────────────────────────
    console_log_path = log_dir / "console_output.log"
    _console_log_file = console_log_path.open("a", encoding="utf-8", buffering=1)
    # 1) FileHandler on the root logger: every logger.*() call goes to file.
    #    This is necessary because logging.StreamHandler stores a reference to
    #    sys.stderr at *creation* time (inside logging.basicConfig above), so
    #    reassigning sys.stderr later has no effect on existing handlers.
    _file_handler = _add_file_logging(console_log_path)
    # 2) TeeStream on sys.stdout / sys.stderr: every print() / tqdm bar /
    #    library write also goes to file. Both together cover 100% of output.
    _original_stdout = sys.stdout
    _original_stderr = sys.stderr
    sys.stdout = TeeStream(_original_stdout, _console_log_file)
    sys.stderr = TeeStream(_original_stderr, _console_log_file)

    logger.info("=" * 70)
    logger.info("GRPO run: %s", run_name)
    logger.info("Checkpoints : %s", out_dir)
    logger.info("Logs        : %s", log_dir)
    logger.info("Console log : %s", console_log_path)
    logger.info("=" * 70)

    # ── Persist config for reproducibility ───────────────────────────────────
    (log_dir / "config.json").write_text(
        json.dumps(vars(args), indent=2, default=str), encoding="utf-8"
    )

    # ── Live CSV metrics writer ──────────────────────────────────────────────
    # Written one row per iteration so you can tail / open in Excel mid-run.
    _metrics_csv_path = log_dir / "metrics.csv"
    _csv_file: Optional[Any] = None
    _csv_writer: Optional[Any] = None

    def _append_metrics_csv(row: Dict[str, Any]) -> None:
        """Append one metrics row to metrics.csv; writes the header on first call."""
        nonlocal _csv_file, _csv_writer
        # Normalise floats to fixed precision so the CSV is human-readable.
        flat = {
            k: (f"{v:.6f}" if isinstance(v, float) else v)
            for k, v in row.items()
        }
        if _csv_writer is None:
            _csv_file = _metrics_csv_path.open("w", newline="", encoding="utf-8")
            _csv_writer = csv.DictWriter(
                _csv_file,
                fieldnames=list(flat.keys()),
                extrasaction="ignore",
            )
            _csv_writer.writeheader()
        _csv_writer.writerow(flat)
        _csv_file.flush()  # type: ignore[union-attr]

    # ── Teardown: restore streams and close files on any exit path ───────────
    # atexit runs unconditionally: on normal completion, keyboard interrupt,
    # unhandled exception, or OOM crash. This is equivalent to a finally block
    # without requiring the entire training body to be re-indented.
    def _teardown_logging() -> None:
        sys.stdout = _original_stdout
        sys.stderr = _original_stderr
        logging.getLogger().removeHandler(_file_handler)
        if not getattr(_file_handler.stream, "closed", False):
            _file_handler.close()
        if _csv_file is not None and not getattr(_csv_file, "closed", False):
            _csv_file.close()
        if not _console_log_file.closed:
            _console_log_file.close()

    atexit.register(_teardown_logging)
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    attn_impl = select_attn_implementation()
    logger.info("Device: %s | attn: %s", device, attn_impl)
    if torch.cuda.is_available():
        _gpu = torch.cuda.get_device_properties(0)
        logger.info(
            "GPU: %s | %.1f GB VRAM | capability sm_%d%d",
            _gpu.name, _gpu.total_memory / 1e9, _gpu.major, _gpu.minor,
        )
    logger.info(
        "Run config: K=%d K_q=%d N=%d lr=%.1e T=%.2f max_new=%d | "
        "clip_eps=%.2f kl_coef=%.4f warmup=%d | diff_alpha=%.1f | "
        "self_play=%.0f%% grounded=%.0f%% | "
        "math_mix=%.0f%% math_maxdiff=%d | overlong_filter=%s | "
        "eval_every=%d eval_N=%d | grad_clip=%.2f save_every=%d keep_last=%d | "
        "question_GRPO=%s",
        args.group_size, args.q_group_size, args.questions_per_iter, args.learning_rate,
        args.temperature, args.max_new_tokens,
        args.clip_eps, args.kl_coef, args.warmup_iters,
        args.difficulty_alpha,
        100 * args.self_play_ratio, 100 * (1 - args.self_play_ratio),
        100 * args.math_mix_ratio, args.math_max_difficulty,
        args.overlong_filter,
        args.eval_every, args.eval_max_samples,
        args.max_grad_norm, args.save_every, args.keep_last,
        f"ENABLED (K_q={args.q_group_size})" if args.q_group_size > 1 else "disabled",
    )
| # ββ Load model ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| logger.info("Loading model from %s ...", args.base_model) | |
| tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True) | |
| if tokenizer.pad_token is None: | |
| tokenizer.pad_token = tokenizer.eos_token | |
| tokenizer.padding_side = "right" | |
| # SFT adapter checkpoints often don't save chat_template, which causes | |
| # tokenizer.apply_chat_template() to raise an error inside evaluate_gsm8k | |
| # β silently swallowed there, giving 0% accuracy even for a capable model. | |
| if tokenizer.chat_template is None: | |
| _base_model_name = "Qwen/Qwen2.5-Math-1.5B-Instruct" | |
| _meta_file = Path(args.base_model) / "pipeline_meta.json" | |
| if _meta_file.exists(): | |
| _meta = json.loads(_meta_file.read_text(encoding="utf-8")) | |
| _base_model_name = _meta.get("base_model", _base_model_name) | |
| logger.info( | |
| "Tokenizer has no chat_template; loading from base model %s", _base_model_name | |
| ) | |
| try: | |
| _base_tok = AutoTokenizer.from_pretrained(_base_model_name, trust_remote_code=True) | |
| if _base_tok.chat_template is not None: | |
| tokenizer.chat_template = _base_tok.chat_template | |
| logger.info("Chat template loaded successfully.") | |
| except Exception as _e: | |
| logger.warning("Could not load chat template from base model: %s", _e) | |
    # PEFT <= 0.12 crashes inside merge_and_unload() when the
    # transformers.integrations.tensor_parallel module is missing.
    if "transformers.integrations.tensor_parallel" not in sys.modules:
        sys.modules["transformers.integrations.tensor_parallel"] = types.ModuleType(
            "tensor_parallel"
        )
    model_path = Path(args.base_model)
    is_adapter = (model_path / "adapter_config.json").exists()
    load_kwargs = dict(
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        device_map={"": device},
        trust_remote_code=True,
        attn_implementation=attn_impl,
    )
    if is_adapter:
        # Determine actual base model from pipeline_meta.json (written by SFT pipeline).
        _meta_path = model_path / "pipeline_meta.json"
        _base_for_weights = "Qwen/Qwen2.5-Math-1.5B-Instruct"
        if _meta_path.exists():
            _base_for_weights = json.loads(
                _meta_path.read_text(encoding="utf-8")
            ).get("base_model", _base_for_weights)
        logger.info("Detected PEFT adapter → loading base %s then merging %s",
                    _base_for_weights, args.base_model)
        _base = AutoModelForCausalLM.from_pretrained(_base_for_weights, **load_kwargs)
        model = PeftModel.from_pretrained(_base, args.base_model).merge_and_unload()
        model = model.to(device)
    else:
        model = AutoModelForCausalLM.from_pretrained(args.base_model, **load_kwargs)

    # PEFT.merge_and_unload() leaves requires_grad=False on every param.
    # Re-enable unconditionally so GRPO's optimizer actually updates weights.
    params_before = sum(p.numel() for p in model.parameters() if p.requires_grad)
    for p in model.parameters():
        p.requires_grad_(True)
    params_after = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if params_before == 0 and params_after > 0:
        logger.warning(
            "All parameters were frozen on load (PEFT merge_and_unload bug). "
            "Re-enabled requires_grad; any prior frozen runs were training nothing."
        )
    # Flash-Attn 2 turns attention memory from O(T²) to O(T), so gradient
    # checkpointing gives almost no extra saving while costing ~30% more
    # backward time. Disable it when Flash is active (mirrors PPO runner).
    # gradient_checkpointing_enable requires use_reentrant=False on modern
    # PyTorch: the default True is deprecated and causes silent issues.
    # Also set use_cache=False: HF models can't use KV cache together with
    # gradient checkpointing (incompatible memory management).
    flash_active = attn_impl == "flash_attention_2"
    if not flash_active:
        model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )
        if hasattr(model, "config"):
            model.config.use_cache = False
        logger.info("Gradient checkpointing ENABLED (use_reentrant=False, use_cache=False).")
    else:
        logger.info(
            "Flash-Attn 2 active → gradient checkpointing OFF "
            "(Flash already gives O(T) attention memory)."
        )
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    n_total = sum(p.numel() for p in model.parameters())
    logger.info(
        "Trainable parameters: %s / %s (%.1f%%)",
        f"{n_params:,}", f"{n_total:,}", 100.0 * n_params / max(n_total, 1),
    )

    # ── Reference policy (frozen copy) ───────────────────────────────────────
    # A deep copy of the policy at t=0, kept frozen forever. Used in the KL
    # penalty to anchor the policy against catastrophic forgetting of SFT
    # knowledge: L += β × (log π_θ - log π_ref) / T.
    # Memory cost: ~3 GB (1.5B × 2 bytes BF16), negligible on 80 GB.
    ref_model: Optional[AutoModelForCausalLM] = None
    if args.kl_coef > 0.0:
        logger.info(
            "Creating frozen reference policy (kl_coef=%.4f, ~%.1f GB VRAM)...",
            args.kl_coef, sum(p.numel() for p in model.parameters()) * 2 / 1e9,
        )
        ref_model = copy.deepcopy(model)
        ref_model.requires_grad_(False)
        ref_model.eval()
        logger.info("Reference policy ready.")
    else:
        logger.info("KL coef = 0 → no reference policy created.")
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=args.learning_rate,
        fused=torch.cuda.is_available(),
    )

    # ── LR schedule: linear warmup → cosine decay ────────────────────────────
    # Linear warmup avoids the large initial gradient spike when the policy
    # starts updating from an SFT checkpoint. Cosine decay then smoothly
    # reduces LR toward min_lr as training progresses (standard in RLHF runs).
    from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
    _n_warmup = max(1, args.warmup_iters)
    _n_total = max(1, args.num_iterations)
    _n_decay = max(1, _n_total - _n_warmup)
    _min_lr = args.learning_rate * args.min_lr_ratio
    _warmup_sched = LinearLR(
        optimizer,
        start_factor=0.1,
        end_factor=1.0,
        total_iters=_n_warmup,
    )
    _cosine_sched = CosineAnnealingLR(
        optimizer,
        T_max=_n_decay,
        eta_min=_min_lr,
    )
    scheduler = SequentialLR(
        optimizer,
        schedulers=[_warmup_sched, _cosine_sched],
        milestones=[_n_warmup],
    )
    logger.info(
        "LR schedule: %.1e warmup(%d iters) → cosine decay(%d iters, min=%.1e)",
        args.learning_rate, _n_warmup, _n_decay, _min_lr,
    )
    # ── Load data ────────────────────────────────────────────────────────────
    gsm8k_pairs = load_gsm8k(args.gsm8k_data)
    if not gsm8k_pairs:
        logger.error("No training data found at %s; cannot train. Exiting.", args.gsm8k_data)
        sys.exit(1)

    # Optional MATH dataset mixing
    math_pairs: List[Dict[str, str]] = []
    if args.math_mix_ratio > 0.0:
        math_pairs = load_math_dataset(
            local_path=args.math_data,
            max_difficulty=args.math_max_difficulty,
        )
        if math_pairs:
            logger.info(
                "MATH mixing: %.0f%% MATH (%d problems) + %.0f%% GSM8K (%d problems)",
                100 * args.math_mix_ratio, len(math_pairs),
                100 * (1 - args.math_mix_ratio), len(gsm8k_pairs),
            )
        else:
            logger.warning("No MATH pairs loaded; using GSM8K only.")
    # Combined pool used for difficulty sampling; kept separate for VRAM-aware
    # batch construction (sampler draws from each pool proportionally).
    qa_pairs = gsm8k_pairs  # for reward env (all GSM8K gold answers needed)

    # ── Load PRM (optional) ──────────────────────────────────────────────────
    prm: Optional[ProcessRewardScorer] = None
    if args.use_prm:
        try:
            prm = ProcessRewardScorer(
                model_name=args.prm_model,
                device=device,
                load_in_4bit=True,
            )
            logger.info("PRM loaded: %s (4-bit)", args.prm_model)
        except Exception as exc:
            logger.warning("PRM load failed (%s); running without PRM.", exc)

    # Build a minimal math_env just for its reward utilities (compute_grounded_reward).
    # value_model=None is safe: it's only stored as self.value and never invoked on
    # the grounded-reward path, so GRPO avoids the ~3 GB ValueHead backbone entirely.
    from src.rl.unified_accuracy import StepChainExtractor, UnifiedAccuracyCalculator
    _extractor = StepChainExtractor(
        model_name=args.extractor_model,
        device=str(device),
        cache_path=args.extraction_cache,
    )
    _unified_calc = UnifiedAccuracyCalculator(extractor=_extractor, question_evaluator=None)
    logger.info(
        "Unified accuracy calculator ready (extractor=%s, cache=%s)",
        args.extractor_model,
        args.extraction_cache or "none",
    )
    # Eagerly load the extractor model now to avoid a 30-60 s stall on the
    # first training iteration that triggers live (non-cached) extraction.
    logger.info("Warming up step-chain extractor (eager load)...")
    _extractor.warmup()
    logger.info("Extractor warmup complete")

    # ── LLM-backed question classifier (replaces keyword regex) ──────────────
    # Uses the already-loaded policy model for topic classification during
    # self-play reward computation. ~60-120 ms per call, cached, falls back
    # to regex on any error. Dramatically more accurate than keyword matching
    # for geometry, calculus, competition_math, and statistics.
    from src.rl.llm_question_classifier import LLMQuestionClassifier
    _llm_classifier = LLMQuestionClassifier(
        model=model,
        tokenizer=tokenizer,
        device=device,
        cache_size=10_000,
    )
    math_env = CurriculumMathEnvironment(
        policy_model=model,
        value_model=None,
        tokenizer=tokenizer,
        # Feed all training questions as the novelty reference set so
        # session_novelty is measured against the actual training distribution:
        # a self-play question that mimics a dataset question gets low novelty.
        reference_questions=[p["question"] for p in gsm8k_pairs],
        grounded_qa_pairs=qa_pairs,
        prm_scorer=prm,
        max_solution_tokens=args.max_new_tokens,
        device=device,
        unified_accuracy_calc=_unified_calc,
    )
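    # The grounded-reward path below assumes compute_grounded_reward returns a
    # dict with at least these keys (inferred from how they are consumed in the
    # training loop, not from a separate spec): combined_score, step_accuracy,
    # lccp, gt_match, the optional chain_arith_score / chain_dep_score /
    # chain_integrity_score entries, and the prm_final_score / prm_mean_score
    # components used for shadow calibration.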
    # Inject LLM classifier into the question quality evaluator
    math_env.question_evaluator.classifier = _llm_classifier
    # Wire the question_evaluator into the unified calc after math_env is available
    _unified_calc.question_evaluator = math_env.question_evaluator

    # Bootstrap curriculum from dataset skill_ids when the training data
    # contains structured records (NuminaMath / OpenMathInstruct format).
    # Falls back to the keyword-classifier path for plain GSM8K.
    _raw_records: list = []
    _train_path = Path(args.gsm8k_data)
    if _train_path.exists():
        with _train_path.open(encoding="utf-8") as _f:
            for _line in _f:
                _line = _line.strip()
                if _line:
                    try:
                        _raw_records.append(json.loads(_line))
                    except json.JSONDecodeError:
                        pass
    if any("skill_id" in r for r in _raw_records[:20]):
        logger.info(
            "Detected structured dataset (%d records) → bootstrapping "
            "curriculum from skill_ids instead of keyword classifier.",
            len(_raw_records),
        )
        math_env.curriculum_manager.initialize_from_dataset(_raw_records)
    else:
        logger.info("Plain dataset detected → using keyword-classifier bootstrap.")
    # ── Difficulty-adaptive sampling state ───────────────────────────────────
    # Track per-question win-rate. Questions where the model scores correctly
    # 20-80% of the time are "on the margin" and provide the richest gradient
    # signal. Questions it always gets right (win_rate ≈ 1) or always gets
    # wrong (win_rate ≈ 0) contribute little after the first few iterations.
    from collections import defaultdict

    _q_wins: Dict[str, int] = defaultdict(int)
    _q_attempts: Dict[str, int] = defaultdict(int)

    def _question_key(q: str) -> str:
        """Stable hash fingerprint, collision-resistant for any pool size."""
        import hashlib
        return hashlib.md5(q.encode(), usedforsecurity=False).hexdigest()

    def _sample_by_difficulty(
        pool: List[Dict[str, str]], n: int, alpha: float
    ) -> List[Dict[str, str]]:
        """
        Sample ``n`` questions from ``pool``, weighting by how informative each is.

        Informativeness = 1 - |win_rate - 0.5| × 2  ∈ [0, 1]
          win_rate = 0.0 or 1.0 → informativeness = 0 (model already knows / lost cause)
          win_rate = 0.5        → informativeness = 1 (most uncertain = best signal)

        ``alpha`` sharpens the weighting (higher = stronger preference for win_rate ≈ 0.5).
        Unseen questions get weight 0.75 to encourage exploration.
        A 5% floor prevents any question from being permanently excluded.
        """
        if alpha <= 0.0:
            return random.sample(pool, min(n, len(pool)))
        weights = []
        for qa in pool:
            key = _question_key(qa["question"])
            att = _q_attempts[key]
            if att == 0:
                w = 0.75
            else:
                win_rate = _q_wins[key] / att
                info = 1.0 - abs(win_rate - 0.5) * 2.0  # ∈ [0, 1]
                w = max(info ** alpha, 0.05)
            weights.append(w)
        total_w = sum(weights)
        probs = [w / total_w for w in weights]
        chosen = np.random.choice(
            len(pool), size=min(n, len(pool)), replace=False, p=probs
        )
        return [pool[i] for i in chosen]
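    # Worked example of the weighting (hypothetical numbers, for intuition):
    # with alpha=2, a question at win_rate=0.5 gets weight 1.0; one at
    # win_rate=0.9 gets (1 - 0.8)**2 = 0.04, floored to 0.05; an unseen one
    # gets 0.75. Sampling therefore concentrates on the 20-80% band while the
    # floor keeps every question reachable.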
    # ── Metrics log ──────────────────────────────────────────────────────────
    metrics_log: List[Dict] = []

    # ── Initial eval ─────────────────────────────────────────────────────────
    if not args.skip_initial_eval:
        logger.info("=" * 70)
        logger.info("INITIAL EVALUATION (Iteration 0)")
        logger.info("=" * 70)
        initial_eval = evaluate_policy(
            model, tokenizer,
            args.eval_data_path, args.eval_max_samples, args.eval_max_new_tokens,
            math_env=math_env,
            pass_at_k=args.eval_pass_at_k,
        )
        # accuracy == combined_score = 0.50×correct + 0.40×process(prm_final, prm_mean) + 0.10×fmt
        # This is identical to the GRPO training objective.
        _log_eval_result("INITIAL (iter 0)", initial_eval, best=None)
        metrics_log.append({"iteration": 0, **initial_eval})
        best_accuracy = float(initial_eval.get("accuracy", 0.0))
        best_combined = float(initial_eval.get("combined_score", 0.0))
        best_prm_mean = float(initial_eval.get("prm_mean", 0.0))
    else:
        best_accuracy = 0.0
        best_combined = 0.0
        best_prm_mean = 0.0

    # ── Training curriculum phase FSM ────────────────────────────────────────
    # Phase 1, GROUNDED_ONLY: self-play ratio is forced to 0 until the model
    #   has established reliable answer correctness (gt_match_rate) and step
    #   quality (step_accuracy) on grounded data.
    # Phase 2, SELFPLAY_RAMP: self-play ratio ramps from ~0 → self_play_ratio
    #   ceiling over selfplay_ramp_iters, keeping ≥30% grounded as an anchor.
    # Phase 3, CONTINUOUS: ratio holds at ceiling; grounded floor is monitored
    #   and self-play is suspended whenever gt_match_rate drops below the floor.
    from enum import Enum, auto as _auto

    class _Phase(Enum):
        GROUNDED_ONLY = _auto()
        SELFPLAY_RAMP = _auto()
        CONTINUOUS = _auto()

    _phase: _Phase = _Phase.GROUNDED_ONLY
    _selfplay_iterations: int = 0  # iterations spent in Phase 2+
    _selfplay_suspended: bool = False
    _effective_sp_ratio: float = 0.0  # computed each iteration from phase

    # ── Chain scoring calibration state ──────────────────────────────────────
    # During Phase 2 SELFPLAY_RAMP the extractor runs in shadow mode (computing
    # scores but NOT affecting rewards) to build a rolling calibration window.
    # use_chain_scoring only flips True when both the chain-PRM correlation AND
    # the extraction success rate cross their thresholds: a data-driven gate,
    # not a schedule-driven one.
    _use_chain_as_primary: bool = False  # True once calibration passes
    _chain_prm_correlation: float = 0.0  # rolling Pearson r (chain vs PRM)
    _extraction_success_rate: float = 0.0  # rolling extraction success fraction
    # Cross-iteration rolling window (up to 200 paired samples)
    _rolling_chain_scores: List[float] = []
    _rolling_prm_scores: List[float] = []
    _rolling_successes: List[int] = []  # 1 = successful extraction, 0 = failed
    _CALIB_WINDOW = 50  # minimum samples before computing correlation
    _CALIB_MAX = 200  # cap rolling lists at this length
    # Throttle shadow extraction: only run the extractor on every Nth grounded
    # solution during calibration. Reduces overhead ~4× while still reaching
    # the 50-sample window within a few iterations.
    _SHADOW_EVERY = 4
    _shadow_extract_counter: int = 0
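    # Rough fill-rate arithmetic (assuming the defaults from the module
    # docstring, N=16 questions × K=4 solutions = 64 solutions/iter, and a
    # hypothetical ~70% grounded share during the ramp): ~45 grounded solutions
    # per iteration pass the 1-in-4 throttle gate as ~11 shadow extractions, so
    # the 50-sample calibration window fills in roughly 4-5 iterations.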
    # ── Training ─────────────────────────────────────────────────────────────
    for iteration in range(1, args.num_iterations + 1):
        iter_start = time.perf_counter()
        logger.info("=" * 70)
        logger.info("GRPO ITERATION %d/%d", iteration, args.num_iterations)
        logger.info("=" * 70)

        # Sample questions: difficulty-weighted from the mixed pool.
        # When math_pairs is non-empty, draw proportionally: N*ratio from MATH
        # and N*(1-ratio) from GSM8K. The difficulty sampler handles each pool
        # independently so MATH problems get their own win-rate tracking.
        #
        # MATH ratio ramp: once past --math-ramp-start, linearly increase the
        # MATH fraction toward --math-mix-ratio-late over the next 10 iterations.
        # This progressively raises difficulty after the policy has stabilised.
        _effective_math_ratio = args.math_mix_ratio
        if args.math_mix_ratio_late is not None and iteration > args.math_ramp_start:
            _ramp_progress = min(1.0, (iteration - args.math_ramp_start) / 10.0)
            _effective_math_ratio = (
                args.math_mix_ratio
                + _ramp_progress * (args.math_mix_ratio_late - args.math_mix_ratio)
            )
        if math_pairs and _effective_math_ratio > 0.0:
            n_math = max(1, round(args.questions_per_iter * _effective_math_ratio))
            n_gsm8k = max(1, args.questions_per_iter - n_math)
            math_batch = _sample_by_difficulty(math_pairs, n_math, alpha=args.difficulty_alpha)
            gsm8k_batch = _sample_by_difficulty(gsm8k_pairs, n_gsm8k, alpha=args.difficulty_alpha)
            questions_batch = math_batch + gsm8k_batch
            random.shuffle(questions_batch)
        else:
            questions_batch = _sample_by_difficulty(
                gsm8k_pairs, args.questions_per_iter, alpha=args.difficulty_alpha
            )
        cur_lr = optimizer.param_groups[0]["lr"]
        # Temperature annealing: linearly decay T from peak → min_temp over the run.
        # Early iterations need high T for exploration; later ones need lower T
        # to consolidate learned strategies (and close the training/eval gap).
        _anneal_frac = min(1.0, (iteration - 1) / max(1, args.num_iterations - 1))
        _annealed_temp = args.temperature * (1.0 - 0.5 * _anneal_frac)  # 0.8 → 0.4
        logger.info(
            "LR this iteration: %.2e | T=%.3f | MATH ratio=%.0f%%",
            cur_lr, _annealed_temp, 100 * _effective_math_ratio,
        )
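        # Concretely (assuming the default --temperature 0.8 and a 30-iteration
        # run): iteration 1 samples at T=0.80, the midpoint at roughly T≈0.60,
        # and the final iteration at T=0.40, half the starting temperature.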
        all_rewards: List[float] = []
        all_q_rewards: List[float] = []
        _grounded_rewards: List[float] = []
        _sp_rewards: List[float] = []
        _grounded_step_accs: List[float] = []
        _grounded_lccps: List[float] = []
        _grounded_gt_matches: List[bool] = []
        # Chain scoring accumulators (populated only in Phase 2+ when
        # math_env.use_chain_scoring is True)
        _chain_arith_scores: List[float] = []
        _chain_dep_scores: List[float] = []
        _chain_integrity_scores: List[float] = []
        _sp_chain_scores: List[float] = []  # self-play chain integrity
        _skipped_zero_var: int = 0  # groups skipped due to zero reward variance
        # Per-component question quality accumulators
        _qc_topic: List[float] = []
        _qc_diff: List[float] = []
        _qc_clarity: List[float] = []
        _qc_novelty: List[float] = []
        _qc_solvability: List[float] = []
        skipped = 0
        n_groups = 0
        n_self_play = 0
        q_gen_attempts = 0  # total generate_question() calls
        q_gen_valid = 0  # non-empty questions produced (len > 10 chars)
        q_quality_good = 0  # self-play groups where question_reward > 0.5
        total_loss_val = 0.0

        # Determine how many of this iteration's groups use self-play question
        # generation vs grounded (dataset) questions.
        # Phase-driven ratio: Phase 1 forces 0; Phase 2 ramps from 0 to ceiling;
        # Phase 3 holds at ceiling (args.self_play_ratio). Grounded floor recovery
        # (computed at end of previous iteration) overrides to 0 regardless of phase.
        if _phase == _Phase.GROUNDED_ONLY:
            _effective_sp_ratio = 0.0
        elif _phase == _Phase.SELFPLAY_RAMP:
            _grounded_anchor = max(0.30, 1.0 - (_selfplay_iterations / max(1, args.selfplay_ramp_iters)))
            _effective_sp_ratio = 1.0 - _grounded_anchor
        else:  # CONTINUOUS
            _effective_sp_ratio = args.self_play_ratio
        if _selfplay_suspended:
            _effective_sp_ratio = 0.0  # grounded floor recovery pass
        n_self_play_target = int(round(len(questions_batch) * _effective_sp_ratio))
        # Build a random set of group indices that will use self-play.
        # Random interleaving distributes self-play uniformly across the batch
        # instead of front-loading all self-play groups, which would cause the
        # gradient to shift mid-batch as the objective changes character.
        _all_indices = list(range(len(questions_batch)))
        random.shuffle(_all_indices)
        _self_play_indices = set(_all_indices[:n_self_play_target])

        # Zero gradients once before the loop: we accumulate them via
        # per-group .backward() calls instead of building one giant graph.
        # Keeping all K*N forward passes alive until a single backward()
        # at the end would hold O(K*N) computation graphs in GPU memory
        # simultaneously (64 graphs at K=4, N=16), risking OOM. Calling
        # .backward() immediately after each group frees that graph right
        # away; gradients accumulate in .grad tensors without extra memory.
        optimizer.zero_grad()

        pbar = tqdm(questions_batch, desc=f"Iter {iteration} GRPO groups", unit="q")
        for _group_idx, qa in enumerate(pbar):
            # ── Decide: self-play (model generates question) or grounded ─────
            # Random interleaving: self-play slots chosen before the loop.
            use_self_play = _group_idx in _self_play_indices
            if use_self_play:
                # ── SELF-PLAY BRANCH ─────────────────────────────────────────
                # 1. Sample a curriculum instruction (topic + difficulty target)
                instruction, target_topic, target_difficulty = math_env.sample_instruction()
                # MATH L4-L5: exclude from self-play generation; problems at this
                # difficulty produce unanchored reward because the verification
                # cascade cannot reliably confirm answers. Fall back to grounded.
                if target_difficulty >= 4.0:
                    use_self_play = False
            # Re-check after the difficulty guard: a demoted group must take the
            # grounded branch below (with the dataset gold answer) instead of
            # continuing through question generation with no gold to score against.
            if use_self_play:
                # 2. Model generates the question from the instruction.
                #    This is the "proposer" role in Theme #4 self-improvement:
                #    the model creates its own challenge.
                q_gen_attempts += 1
                # ── TWO-PHASE QUESTION GRPO (when --q-group-size ≥ 2) ────────
                # Phase 1: sample K_q question candidates, store their token
                #   IDs for a question-level GRPO update.
                # Phase 2: for each candidate, generate M=group_size solutions,
                #   score them, and run a solution-level GRPO update.
                # The per-question reward (mean solution reward) is then used
                # to run GRPO on the question tokens: gradients flow back
                # through the question tokens for the first time.
                if args.q_group_size > 1:
                    _q_temp = min(0.90, _annealed_temp + 0.05)
                    q_cands, q_ids_all, q_masks_all, q_olps_all = generate_questions_batched(
                        model=model,
                        tokenizer=tokenizer,
                        instruction=instruction,
                        K_q=args.q_group_size,
                        max_new_tokens=128,
                        temperature=_q_temp,
                        device=device,
                    )
                    # Keep only candidates with enough substance
                    _valid_q = [
                        (q, ids, mask, olp)
                        for q, ids, mask, olp
                        in zip(q_cands, q_ids_all, q_masks_all, q_olps_all)
                        if len(q.strip()) >= 10
                    ]
                    if not _valid_q:
                        logger.debug("Two-phase SP: all %d question candidates too short, skipping.", args.q_group_size)
                        skipped += 1
                        continue
                    q_gen_valid += 1
                    n_self_play += 1
                    # Phase 2: score solutions for each valid question candidate
                    _question_agg_rewards: List[float] = []  # one per valid candidate
                    _q_total_loss_val: float = 0.0
                    for _q_text, _q_ids, _q_mask, _q_olp in _valid_q:
                        solution_prompt = math_env.format_solution_prompt(_q_text)
                        sols_q, ids_q, masks_q, olps_q = generate_solutions_batched(
                            model=model,
                            tokenizer=tokenizer,
                            prompt=solution_prompt,
                            K=args.group_size,
                            max_new_tokens=args.max_new_tokens,
                            temperature=_annealed_temp,
                            device=device,
                        )
                        # Overlong filter
                        if args.overlong_filter:
                            _vf = [
                                t for t in zip(sols_q, ids_q, masks_q, olps_q)
                                if int(t[2].sum().item()) < args.max_new_tokens
                            ]
                            if _vf:
                                sols_q, ids_q, masks_q, olps_q = map(list, zip(*_vf))  # type: ignore
                            else:
                                skipped += 1
                                _question_agg_rewards.append(0.0)
                                continue
                        # Score solutions
                        _sol_rewards: List[float] = []
                        for _sol in sols_q:
                            _r, _q_rew, _, _q_met = compute_self_play_reward(
                                question=_q_text,
                                solution=_sol,
                                target_topic=target_topic,
                                target_difficulty=target_difficulty,
                                math_env=math_env,
                            )
                            _sol_rewards.append(_r)
                            all_q_rewards.append(_q_rew)
                            _qc_topic.append(_q_met["topic_match"])
                            _qc_diff.append(_q_met["difficulty_fit"])
                            _qc_clarity.append(_q_met["clarity"])
                            _qc_novelty.append(_q_met["novelty"])
                            _qc_solvability.append(_q_met["solvability"])
                        all_rewards.extend(_sol_rewards)
                        _sp_rewards.extend(_sol_rewards)
                        # Aggregate question reward = mean of its solution rewards
                        _q_agg = float(np.mean(_sol_rewards))
                        _question_agg_rewards.append(_q_agg)
                        # ── Solution-level GRPO update ────────────────────────
                        _sol_loss = grpo_loss_for_group(
                            model=model,
                            input_ids_list=ids_q,
                            response_masks=masks_q,
                            rewards=_sol_rewards,
                            old_log_probs=olps_q,
                            clip_eps=args.clip_eps,
                            kl_coef=args.kl_coef,
                            ref_model=ref_model,
                        )
                        if _sol_loss is not None:
                            _sol_loss.backward()
                            total_loss_val += _sol_loss.item()
                            _q_total_loss_val += _sol_loss.item()
                            n_groups += 1
                        else:
                            skipped += 1
                            _skipped_zero_var += 1
                    # ── Question-level GRPO update ────────────────────────────
                    # Advantages are computed over the K_q question-reward
                    # scalars. The IS ratio is exp(new_lp_question - old_lp_question).
                    # kl_coef=0 here: there is no reference distribution for questions.
                    _q_ids_v = [t[1] for t in _valid_q]
                    _q_masks_v = [t[2] for t in _valid_q]
                    _q_olps_v = [t[3] for t in _valid_q]
                    _q_loss = grpo_loss_for_group(
                        model=model,
                        input_ids_list=_q_ids_v,
                        response_masks=_q_masks_v,
                        rewards=_question_agg_rewards,
                        old_log_probs=_q_olps_v,
                        clip_eps=args.clip_eps,
                        kl_coef=0.0,  # no ref model for question tokens
                        ref_model=None,
                    )
                    if _q_loss is not None:
                        _q_loss.backward()
                        logger.debug(
                            "Q-GRPO: loss=%.4f q_rewards=%s (variance=%.4f)",
                            _q_loss.item(),
                            [f"{r:.3f}" for r in _question_agg_rewards],
                            float(np.var(_question_agg_rewards)),
                        )
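                    # Numeric sketch (hypothetical values): with K_q=3 candidates
                    # whose mean solution rewards are [0.62, 0.35, 0.35], the
                    # group mean is 0.44 and std ≈ 0.127, so the z-score
                    # advantages are roughly [+1.41, -0.71, -0.71]: the question
                    # whose solutions scored best is reinforced, the others are
                    # pushed down.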
                    # Group-level quality: at least one candidate scored > 0.5
                    if any(r > 0.5 for r in _question_agg_rewards):
                        q_quality_good += 1
                    # pbar update then skip to next group (all done above)
                    _mean_r_sp = float(np.mean(all_rewards[-len(_valid_q)*args.group_size:])) if all_rewards else 0.0
                    _q_acc_pct = 100.0 * q_quality_good / max(1, n_self_play)
                    pbar.set_postfix(
                        loss=f"{_q_total_loss_val / max(1, len(_valid_q)):.4f}",
                        mean_r=f"{_mean_r_sp:.3f}",
                        q_acc=f"{_q_acc_pct:.0f}%",
                        q_rew=f"{float(np.mean(all_q_rewards)):.3f}" if all_q_rewards else "n/a",
                        skip=skipped,
                    )
                    continue  # everything handled above; jump to next group

                # ── K_q=1: original single-question path (no question GRPO) ──
                question = generate_question(
                    model=model,
                    tokenizer=tokenizer,
                    instruction=instruction,
                    max_new_tokens=128,  # questions are short
                    device=device,
                    # Slightly warmer than solution temperature for diversity,
                    # but anneals with the same schedule to stay consistent.
                    temperature=min(0.90, _annealed_temp + 0.05),
                )
                # A valid question must have at least some substance.
                # Reject single-word, empty, or nonsensical outputs.
                if len(question.strip()) < 10:
                    logger.debug(
                        "Self-play: generated question too short (%d chars), skipping group.",
                        len(question.strip()),
                    )
                    skipped += 1
                    continue
                q_gen_valid += 1
                n_self_play += 1
                gold = None  # no gold answer; rewarded on question quality
            else:
                # ── GROUNDED BRANCH ──────────────────────────────────────────
                # Use pre-existing dataset question with known gold answer.
                question = qa["question"]
                gold = qa["gold_final"]
                target_topic = "grounded"
                target_difficulty = 0.5
            # --- Generate K solutions (batched: single model.generate call) ---
            solution_prompt = math_env.format_solution_prompt(question)
            solutions, input_ids_list, response_masks, old_log_probs_list = (
                generate_solutions_batched(
                    model=model,
                    tokenizer=tokenizer,
                    prompt=solution_prompt,
                    K=args.group_size,
                    max_new_tokens=args.max_new_tokens,
                    temperature=_annealed_temp,
                    device=device,
                )
            )
            # --- Overlong filter: drop truncated solutions (no Final Answer) ---
            # A response that hit max_new_tokens was cut off mid-generation;
            # it almost certainly didn't produce a valid "Final Answer: X" line,
            # so its reward is unreliable noise. Dropping it keeps the group
            # advantage estimates clean.
            if args.overlong_filter:
                _valid = [
                    (sol, ids, mask, olp)
                    for sol, ids, mask, olp
                    in zip(solutions, input_ids_list, response_masks, old_log_probs_list)
                    if int(mask.sum().item()) < args.max_new_tokens
                ]
                if _valid:
                    solutions, input_ids_list, response_masks, old_log_probs_list = (
                        zip(*_valid)  # type: ignore[assignment]
                    )
                    solutions = list(solutions)
                    input_ids_list = list(input_ids_list)
                    response_masks = list(response_masks)
                    old_log_probs_list = list(old_log_probs_list)
                else:
                    # All K solutions were truncated → skip group.
                    skipped += 1
                    continue
            # --- Score each solution (self-play: Q+S reward; grounded: S only) ---
            rewards = []
            _sp_q_rew_this_group: List[float] = []
            for sol in solutions:
                if use_self_play:
                    # compute_reward = 0.40×question_quality + 0.60×solution_quality
                    # This is the core Theme #4 signal: the model is rewarded
                    # for generating a well-formed, appropriately difficult,
                    # solvable question AND for solving it correctly.
                    r, q_rew, _, q_met = compute_self_play_reward(
                        question=question,
                        solution=sol,
                        target_topic=target_topic,
                        target_difficulty=target_difficulty,
                        math_env=math_env,
                    )
                    _sp_q_rew_this_group.append(q_rew)
                    all_q_rewards.append(q_rew)
                    # Collect per-component breakdown (same question, so all K
                    # solutions get the same q_metrics; average to reduce noise).
                    _qc_topic.append(q_met["topic_match"])
                    _qc_diff.append(q_met["difficulty_fit"])
                    _qc_clarity.append(q_met["clarity"])
                    _qc_novelty.append(q_met["novelty"])
                    _qc_solvability.append(q_met["solvability"])
                    # Self-play chain integrity (Phase 2+ only; None in Phase 1)
                    _sp_ci = q_met.get("sp_chain_integrity_score")
                    if _sp_ci is not None:
                        _sp_chain_scores.append(float(_sp_ci))
                else:
                    r_dict = compute_grounded_reward(
                        question=question,
                        solution=sol,
                        gold_final=gold,
                        math_env=math_env,
                    )
                    r = r_dict["combined_score"]
                    _grounded_step_accs.append(r_dict["step_accuracy"])
                    _grounded_lccps.append(r_dict["lccp"])
                    _grounded_gt_matches.append(bool(r_dict["gt_match"]))
                    if r_dict.get("chain_arith_score") is not None:
                        _chain_arith_scores.append(float(r_dict["chain_arith_score"]))
                    if r_dict.get("chain_dep_score") is not None:
                        _chain_dep_scores.append(float(r_dict["chain_dep_score"]))
                    if r_dict.get("chain_integrity_score") is not None:
                        _chain_integrity_scores.append(float(r_dict["chain_integrity_score"]))
                    # Shadow extraction for calibration: during SELFPLAY_RAMP,
                    # run the chain extractor even before use_chain_scoring is
                    # activated so we can measure chain-PRM correlation. These
                    # scores do NOT affect the reward; they only feed the
                    # calibration window that decides when to flip use_chain_scoring.
                    # Throttled to every _SHADOW_EVERY solutions to avoid making
                    # each iteration ~10× slower (the extractor adds ~8 s per call).
                    _shadow_extract_counter += 1
                    if (
                        _phase == _Phase.SELFPLAY_RAMP
                        and not _use_chain_as_primary
                        and _unified_calc is not None
                        and _shadow_extract_counter % _SHADOW_EVERY == 0
                    ):
                        _prm_ps = (
                            0.60 * r_dict.get("prm_final_score", 0.0)
                            + 0.40 * r_dict.get("prm_mean_score", 0.0)
                        )
                        try:
                            _shadow = _unified_calc.compute(
                                solution=sol,
                                gold_answer=gold,
                                question=question,
                                topic=target_topic,
                                phase="grounded",
                            )
                            _rolling_chain_scores.append(_shadow.chain_integrity_score)
                            _rolling_prm_scores.append(_prm_ps)
                            _rolling_successes.append(1 if _shadow.extraction_succeeded else 0)
                        except Exception:
                            _rolling_successes.append(0)
                rewards.append(r)
            all_rewards.extend(rewards)
            # Route to path-specific accumulators for separate batch_acc reporting
            if use_self_play:
                _sp_rewards.extend(rewards)
            else:
                _grounded_rewards.extend(rewards)
            # A self-play group is "accurate" if the question it generated scored
            # above 0.5 on question quality, meaning it was clear, on-topic,
            # appropriately difficult, and solvable.
            if use_self_play and _sp_q_rew_this_group:
                if float(np.mean(_sp_q_rew_this_group)) > 0.5:
                    q_quality_good += 1
            # --- PAL/SymPy verification gate (self-play only) ---
            # Drop the group if the tiered cascade cannot confirm a consistent,
            # independently-verifiable answer. This prevents circular PRM reward
            # from being the sole correctness anchor on self-play examples.
            if use_self_play:
                if not _verify_self_play_answer(solutions, target_topic, target_difficulty):
                    skipped += 1
                    continue  # no gradient for this group
            # --- Update difficulty stats (grounded questions only: self-play
            # questions are ephemeral and have no stable key) ---
            if not use_self_play:
                _key = _question_key(question)
                _q_attempts[_key] += len(solutions)
                # Win = reward in the top half of THIS group, not an absolute 0.5 threshold.
                # Using a relative threshold avoids the case where all solutions score 0.55
                # (all "wins" → easy) or all score 0.45 (all "losses" → impossible) when the
                # rewards are actually similar and carry no difficulty information.
                _group_median = float(np.median(rewards))
                _q_wins[_key] += sum(1 for r in rewards if r > _group_median)
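                # Worked example (hypothetical rewards): [0.55, 0.55, 0.56, 0.54]
                # has median 0.55, so only the 0.56 solution counts as a win
                # (win_rate 1/4), whereas an absolute 0.5 cutoff would have
                # recorded 4/4 wins and wrongly marked the question "easy".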
            # --- GRPO loss (IS clip + optional KL penalty) + immediate backward ---
            # Skip near-uniform groups early: when reward std < 0.02 (on a [0,1]
            # scale) all advantages collapse to ~0 and the gradient contribution
            # is negligible, i.e. wasted compute. This is a stricter
            # guard than the eps=1e-8 inside grpo_loss_for_group, which only
            # catches exactly-equal rewards (e.g. all 0.998 passes through it).
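            # Example (hypothetical rewards): [0.51, 0.49, 0.50, 0.50] has
            # std ≈ 0.007 < 0.02, so the group is skipped here, while
            # [0.9, 0.9, 0.9, 0.1] (std ≈ 0.35) proceeds to the loss below.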
            _reward_std = float(np.std(rewards))
            if _reward_std < 0.02:
                skipped += 1
                _skipped_zero_var += 1
                _pf_zv: Dict = dict(mean_r=f"{np.mean(rewards):.3f}", skip=skipped, loss="0var")
                pbar.set_postfix(**_pf_zv)
                continue
            group_loss = grpo_loss_for_group(
                model=model,
                input_ids_list=input_ids_list,
                response_masks=response_masks,
                rewards=rewards,
                old_log_probs=old_log_probs_list,
                clip_eps=args.clip_eps,
                kl_coef=args.kl_coef,
                ref_model=ref_model,
            )
            if group_loss is None:
                skipped += 1
                _skipped_zero_var += 1
                _pf: Dict = dict(mean_r=f"{np.mean(rewards):.3f}", skip=skipped, loss="skip")
                if n_self_play > 0 and all_q_rewards:
                    _q_acc_pct = 100.0 * q_quality_good / max(1, n_self_play)
                    _pf["q_acc"] = f"{_q_acc_pct:.0f}%"
                pbar.set_postfix(**_pf)
                continue
            # Backprop immediately: frees this group's computation graph.
            # Gradients from all valid groups accumulate in param.grad.
            group_loss.backward()
            total_loss_val += group_loss.item()
            n_groups += 1
            _pf = dict(
                mean_r=f"{np.mean(rewards):.3f}",
                loss=f"{group_loss.item():.4f}",
                skip=skipped,
            )
            if n_self_play > 0 and all_q_rewards:
                # Show live question-gen accuracy in the tqdm bar.
                # q_acc = fraction of self-play groups whose generated question
                # scored > 0.5 on quality (clear, on-topic, solvable).
                _q_acc_pct = 100.0 * q_quality_good / max(1, n_self_play)
                _pf["q_acc"] = f"{_q_acc_pct:.0f}%"
                _pf["q_rew"] = f"{float(np.mean(all_q_rewards)):.3f}"
            pbar.set_postfix(**_pf)
        # --- Gradient step: normalise accumulated grads then step ---
        if n_groups > 0:
            # Divide accumulated grads by n_groups to get the true average
            # (equivalent to averaging the group losses before backward).
            if n_groups > 1:
                for p in model.parameters():
                    if p.grad is not None:
                        p.grad.div_(n_groups)
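            # Why this is exact: backward() adds each group's gradient into
            # .grad, so after G groups .grad holds sum_g grad(L_g); dividing by
            # G yields grad((1/G) * sum_g L_g), identical to having averaged
            # the losses into one scalar before a single backward() call.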
            torch.nn.utils.clip_grad_norm_(
                [p for p in model.parameters() if p.requires_grad],
                args.max_grad_norm,
            )
            optimizer.step()
            loss_val = total_loss_val / n_groups
        else:
            loss_val = 0.0
        scheduler.step()
        iter_time = time.perf_counter() - iter_start
        mean_r = float(np.mean(all_rewards)) if all_rewards else 0.0
        std_r = float(np.std(all_rewards)) if all_rewards else 0.0
        acc_r = float(np.mean([r > 0.5 for r in all_rewards])) if all_rewards else 0.0
        grounded_acc_r = (
            float(np.mean([r > 0.5 for r in _grounded_rewards]))
            if _grounded_rewards else 0.0
        )
        mean_step_acc = (
            float(np.mean(_grounded_step_accs))
            if _grounded_step_accs else 0.0
        )
        mean_lccp = (
            float(np.mean(_grounded_lccps))
            if _grounded_lccps else 0.0
        )
        mean_q_r = float(np.mean(all_q_rewards)) if all_q_rewards else 0.0
        # Chain scoring batch means (non-None only in Phase 2+)
        mean_chain_arith = float(np.mean(_chain_arith_scores)) if _chain_arith_scores else None
        mean_chain_dep = float(np.mean(_chain_dep_scores)) if _chain_dep_scores else None
        mean_chain_integrity = float(np.mean(_chain_integrity_scores)) if _chain_integrity_scores else None
        mean_sp_chain = float(np.mean(_sp_chain_scores)) if _sp_chain_scores else None
        # ── gt_match_rate: raw answer-correctness on grounded examples ────────
        # This is the primary Phase-1 graduation signal: unlike grounded_acc_r,
        # which is (combined_score > 0.5), gt_match_rate is the direct SymPy
        # exact-match fraction and cannot be gamed by a high PRM/format score.
        gt_match_rate = (
            float(sum(_grounded_gt_matches) / len(_grounded_gt_matches))
            if _grounded_gt_matches else 0.0
        )
        # ── Phase FSM transitions ─────────────────────────────────────────────
        if _phase == _Phase.GROUNDED_ONLY:
            _graduation_ready = (
                gt_match_rate >= args.selfplay_gt_thresh
                and grounded_acc_r >= args.selfplay_grounded_thresh
                and mean_step_acc >= args.selfplay_step_thresh
                and iteration >= args.min_warmup
            )
            if _graduation_ready:
                _phase = _Phase.SELFPLAY_RAMP
                logger.info(
                    "PHASE → SELFPLAY_RAMP at iter %d "
                    "(gt_match=%.2f grounded_acc=%.2f step_acc=%.2f) → "
                    "shadow extraction active; chain scoring deferred until "
                    "calibration passes (corr ≥ 0.70, success_rate ≥ 0.80)",
                    iteration, gt_match_rate, grounded_acc_r, mean_step_acc,
                )
                # NOTE: do NOT set math_env.use_chain_scoring = True here.
                # The extractor runs in shadow mode first; use_chain_scoring
                # flips to True below once calibration thresholds are met.
        elif _phase in (_Phase.SELFPLAY_RAMP, _Phase.CONTINUOUS):
            _selfplay_iterations += 1
            if _phase == _Phase.SELFPLAY_RAMP and _selfplay_iterations >= args.selfplay_ramp_iters:
                _phase = _Phase.CONTINUOUS
                logger.info(
                    "PHASE → CONTINUOUS at iter %d (ramp complete after %d iters)",
                    iteration, _selfplay_iterations,
                )
            # ── Data-driven chain scoring activation ──────────────────────────
            # Trim rolling window to _CALIB_MAX before computing correlation.
            if len(_rolling_chain_scores) > _CALIB_MAX:
                _rolling_chain_scores = _rolling_chain_scores[-_CALIB_MAX:]
                _rolling_prm_scores = _rolling_prm_scores[-_CALIB_MAX:]
                _rolling_successes = _rolling_successes[-_CALIB_MAX:]
            if not _use_chain_as_primary and len(_rolling_chain_scores) >= _CALIB_WINDOW:
                from scipy.stats import pearsonr  # noqa: PLC0415
                try:
                    _r, _ = pearsonr(
                        _rolling_chain_scores[-_CALIB_WINDOW:],
                        _rolling_prm_scores[-_CALIB_WINDOW:],
                    )
                    _chain_prm_correlation = float(_r)
                except Exception:
                    _chain_prm_correlation = 0.0
                _rolling_n = len(_rolling_successes[-_CALIB_WINDOW:])
                _extraction_success_rate = (
                    sum(_rolling_successes[-_CALIB_WINDOW:]) / _rolling_n
                    if _rolling_n > 0 else 0.0
                )
                if (
                    _chain_prm_correlation >= 0.70
                    and _extraction_success_rate >= 0.80
                ):
                    _use_chain_as_primary = True
                    math_env.use_chain_scoring = True
                    logger.info(
                        "CHAIN PRIMARY activated at iter %d: "
                        "corr=%.2f extraction_rate=%.2f (window=%d) → "
                        "unified calculator now drives reward scoring",
                        iteration, _chain_prm_correlation,
                        _extraction_success_rate, _CALIB_WINDOW,
                    )
                else:
                    logger.debug(
                        "Chain calibration: corr=%.2f success_rate=%.2f "
                        "(need corr ≥ 0.70, success ≥ 0.80; window=%d/%d)",
                        _chain_prm_correlation, _extraction_success_rate,
                        len(_rolling_chain_scores), _CALIB_WINDOW,
                    )
        # Grounded floor monitoring: suspend self-play if answer correctness
        # drops below the floor set at graduation minus 5 pp. Self-play
        # resumes automatically next iteration if performance recovers.
        _prev_suspended = _selfplay_suspended
        _selfplay_suspended = (
            bool(_grounded_gt_matches) and gt_match_rate < args.grounded_floor
        )
        if _selfplay_suspended and not _prev_suspended:
            logger.warning(
                "GROUNDED FLOOR: gt_match_rate=%.2f fell below floor=%.2f → "
                "suspending self-play for recovery",
                gt_match_rate, args.grounded_floor,
            )
        elif not _selfplay_suspended and _prev_suspended:
            logger.info(
                "GROUNDED FLOOR: gt_match_rate=%.2f recovered above floor=%.2f → "
                "resuming self-play",
                gt_match_rate, args.grounded_floor,
            )
        # Question generation accuracy metrics (self-play only)
        q_gen_valid_rate = (q_gen_valid / q_gen_attempts) if q_gen_attempts > 0 else 0.0
        q_quality_rate = (q_quality_good / n_self_play) if n_self_play > 0 else 0.0
        # Per-component averages (all non-empty across K solutions × groups)
        mean_q_topic = float(np.mean(_qc_topic)) if _qc_topic else 0.0
        mean_q_diff = float(np.mean(_qc_diff)) if _qc_diff else 0.0
        mean_q_clarity = float(np.mean(_qc_clarity)) if _qc_clarity else 0.0
        mean_q_novelty = float(np.mean(_qc_novelty)) if _qc_novelty else 0.0
        mean_q_solvab = float(np.mean(_qc_solvability)) if _qc_solvability else 0.0
        _cur_lr = optimizer.param_groups[0]["lr"]
        # ── LLM classifier stats (every 5 iters to avoid log spam) ────────────
        if iteration % 5 == 0:
            _llm_classifier.log_stats()
        # ── Primary summary line ──────────────────────────────────────────────
        logger.info(
            "Iter %d | loss=%.4f | reward mean=%.3f std=%.3f | "
            "gt_match=%.1f%% | grounded_acc=%.1f%% | step_acc=%.1f%% | lccp=%.1f%% | "
            "batch_acc=%.1f%% | phase=%s sp_ratio=%.0f%% | "
            "groups=%d skipped=%d(0var=%d) | lr=%.2e | %.1fs",
            iteration, loss_val, mean_r, std_r,
            100 * gt_match_rate,
            100 * grounded_acc_r,
            100 * mean_step_acc,
            100 * mean_lccp,
            100 * acc_r,
            _phase.name, 100 * _effective_sp_ratio,
            n_groups, skipped, _skipped_zero_var, _cur_lr, iter_time,
        )
        # Starvation warning: if >30% of groups were skipped due to zero reward
        # variance (all K solutions same score), the curriculum difficulty is
        # mis-calibrated: either too easy (all correct) or too hard (all wrong).
        _total_attempted = n_groups + skipped
        if _total_attempted > 0 and _skipped_zero_var / _total_attempted > 0.30:
            logger.warning(
                "STARVATION: %.0f%% of groups skipped (zero variance). "
                "grounded_acc=%.1f%% suggests curriculum is %s. "
                "Consider adjusting --difficulty-alpha.",
                100 * _skipped_zero_var / _total_attempted,
                100 * grounded_acc_r,
                "too easy (raise alpha)" if grounded_acc_r > 0.75 else "too hard (lower alpha)",
            )
        # ── Question-generation accuracy line (only when self-play is active) ─
        if n_self_play > 0:
            logger.info(
                " Question generation: %d/%d valid (%.0f%%) | "
                "q_reward=%.3f | q_acc=%.1f%% (>0.5 quality) | "
                "topic=%.2f diff=%.2f clarity=%.2f novelty=%.2f solvability=%.2f",
                q_gen_valid, q_gen_attempts, 100 * q_gen_valid_rate,
                mean_q_r, 100 * q_quality_rate,
                mean_q_topic, mean_q_diff, mean_q_clarity,
                mean_q_novelty, mean_q_solvab,
            )
        iter_metrics: Dict = {
            "iteration": iteration,
            "loss": loss_val,
            "mean_reward": mean_r,
            "std_reward": std_r,
            "batch_accuracy": acc_r,
            "grounded_accuracy": grounded_acc_r,
            "gt_match_rate": round(gt_match_rate, 4),
            "step_accuracy": mean_step_acc,
            "lccp": mean_lccp,
            "n_groups": n_groups,
            "skipped_groups": skipped,
            "learning_rate": _cur_lr,
            "iter_time_s": iter_time,
            # ── Phase curriculum metrics ──────────────────────────────────────
            "training_phase": _phase.name,
            "effective_sp_ratio": round(_effective_sp_ratio, 3),
            "selfplay_suspended": int(_selfplay_suspended),
            # ── Chain scoring metrics (Phase 2+, None in Phase 1) ─────────────
            "chain_arith_score": round(mean_chain_arith, 4) if mean_chain_arith is not None else None,
            "chain_dep_score": round(mean_chain_dep, 4) if mean_chain_dep is not None else None,
            "chain_integrity_score": round(mean_chain_integrity, 4) if mean_chain_integrity is not None else None,
            "sp_chain_integrity_score": round(mean_sp_chain, 4) if mean_sp_chain is not None else None,
            # ── Chain calibration metrics (populated during SELFPLAY_RAMP shadow mode)
            "chain_prm_correlation": round(_chain_prm_correlation, 3),
            "extraction_success_rate": round(_extraction_success_rate, 3),
            "chain_scoring_active": int(_use_chain_as_primary),
            # ── Question-generation metrics ───────────────────────────────────
            "n_self_play_groups": n_self_play,
            "q_gen_attempts": q_gen_attempts,
            "q_gen_valid": q_gen_valid,
            "q_gen_valid_rate": round(q_gen_valid_rate, 4),
            "mean_question_reward": round(mean_q_r, 4),
            "q_quality_rate": round(q_quality_rate, 4),
            "q_topic_match": round(mean_q_topic, 4),
            "q_difficulty_fit": round(mean_q_diff, 4),
            "q_clarity": round(mean_q_clarity, 4),
            "q_novelty": round(mean_q_novelty, 4),
            "q_solvability": round(mean_q_solvab, 4),
        }
        # --- Eval ---
        if iteration % args.eval_every == 0:
            _eval_ds_label = _infer_eval_dataset_name(args.eval_data_path)
            logger.info("Evaluating %s (%d samples)...", _eval_ds_label, args.eval_max_samples)
            eval_res = evaluate_policy(
                model, tokenizer,
                args.eval_data_path, args.eval_max_samples, args.eval_max_new_tokens,
                math_env=math_env,
                pass_at_k=args.eval_pass_at_k,
            )
            # accuracy == combined_score: 0.50×correct + 0.40×process(prm_final, prm_mean) + 0.10×fmt
            cur_combined = float(eval_res.get("combined_score", best_combined))
            cur_prm_mean = float(eval_res.get("prm_mean", best_prm_mean))
            _log_eval_result(f"iter {iteration}", eval_res, best=best_combined)
            # ── Checkpoint: save when combined_score strictly improves ────────
            # combined_score is a continuous variable; any improvement in
            # correctness, PRM quality, SymPy, or format moves it.
            if cur_combined > best_combined + 1e-4:
                reason = f"combined {cur_combined:.4f} > {best_combined:.4f}"
                best_combined = cur_combined
                best_prm_mean = max(best_prm_mean, cur_prm_mean)
                best_accuracy = best_combined
                best_path = out_dir / "best_policy"
                model.save_pretrained(str(best_path))
                tokenizer.save_pretrained(str(best_path))
                logger.info("New best saved → %s (%s)", best_path, reason)
            iter_metrics.update(eval_res)
        # --- Save checkpoint (respect --save-every / --keep-last) ---
        is_last_iter = iteration == args.num_iterations
        should_save = is_last_iter or (
            args.save_every > 0 and iteration % args.save_every == 0
        )
        if should_save:
            ckpt_path = out_dir / f"iter_{iteration:04d}"
            ckpt_path.mkdir(exist_ok=True)
            model.save_pretrained(str(ckpt_path))
            tokenizer.save_pretrained(str(ckpt_path))
            # Prune older iter_* checkpoints beyond the rolling window.
            if args.keep_last and args.keep_last > 0:
                existing = sorted(
                    p for p in out_dir.iterdir()
                    if p.is_dir() and p.name.startswith("iter_")
                )
                to_remove = existing[: -args.keep_last]
                for old in to_remove:
                    try:
                        shutil.rmtree(old)
                        logger.info("Pruned old checkpoint: %s", old.name)
                    except OSError as exc:
                        logger.warning("Could not prune %s: %s", old.name, exc)
        # ── Write metrics to both JSONL (full history) and CSV (live row) ─────
        metrics_log.append(iter_metrics)
        (out_dir / "metrics.jsonl").write_text(
            "\n".join(json.dumps(m) for m in metrics_log), encoding="utf-8"
        )
        # CSV: one row per iteration, flushed immediately so you can
        # `tail -f logs/grpo/<run>/metrics.csv` or open it in Excel mid-run.
        # `iter_metrics.update(eval_res)` overwrites step_accuracy/lccp on eval iters.
        # We capture the is_eval flag here for clarity.
        _is_eval_iter = "combined_score" in iter_metrics
        _append_metrics_csv({
            "iteration": iter_metrics["iteration"],
            "timestamp": datetime.now().isoformat(timespec="seconds"),
            # ── Per-iteration training signal ─────────────────────────────────
            "loss": iter_metrics.get("loss", 0.0),
            "mean_reward": iter_metrics.get("mean_reward", 0.0),
            "std_reward": iter_metrics.get("std_reward", 0.0),
            "batch_accuracy": iter_metrics.get("batch_accuracy", 0.0),
            "grounded_acc": iter_metrics.get("grounded_accuracy", 0.0),
            "gt_match_rate": iter_metrics.get("gt_match_rate", 0.0),
            # step_accuracy / lccp: training value on non-eval iters,
            # eval value on eval iters (update() overwrites them).
            "step_accuracy": iter_metrics.get("step_accuracy", 0.0),
            "lccp": iter_metrics.get("lccp", 0.0),
            "n_groups": iter_metrics.get("n_groups", 0),
            "skipped_groups": iter_metrics.get("skipped_groups", 0),
            "n_sp_groups": iter_metrics.get("n_self_play_groups", 0),
            "sp_ratio": iter_metrics.get("effective_sp_ratio", 0.0),
            "sp_suspended": iter_metrics.get("selfplay_suspended", 0),
            "training_phase": iter_metrics.get("training_phase", ""),
            "learning_rate": iter_metrics.get("learning_rate", 0.0),
            "iter_time_s": iter_metrics.get("iter_time_s", 0.0),
            # ── Question-generation quality ───────────────────────────────────
            "q_reward": iter_metrics.get("mean_question_reward", ""),
            "q_valid_rate": iter_metrics.get("q_gen_valid_rate", ""),
            "q_novelty": iter_metrics.get("q_novelty", ""),
            "q_solvability": iter_metrics.get("q_solvability", ""),
            # ── Chain scoring calibration ─────────────────────────────────────
            "chain_prm_corr": iter_metrics.get("chain_prm_correlation", ""),
            "chain_scoring_on": iter_metrics.get("chain_scoring_active", ""),
            # ── Eval checkpoint metrics (every eval_every iters) ──────────────
            "eval_combined": iter_metrics.get("combined_score", "") if _is_eval_iter else "",
            "eval_correct_rt": iter_metrics.get("correct_rate", "") if _is_eval_iter else "",
            "eval_prm": iter_metrics.get("prm_mean", "") if _is_eval_iter else "",
            "eval_step_acc": iter_metrics.get("step_accuracy", "") if _is_eval_iter else "",
            "eval_lccp": iter_metrics.get("lccp", "") if _is_eval_iter else "",
            "eval_format": iter_metrics.get("format_mean", "") if _is_eval_iter else "",
            "eval_n_scored": iter_metrics.get("n_scored", "") if _is_eval_iter else "",
            "eval_final_ans": iter_metrics.get("final_answer_accuracy", "") if _is_eval_iter else "",
        })
| logger.info("=" * 70) | |
| logger.info("GRPO training complete.") | |
| logger.info( | |
| "Best training-objective score : %.4f " | |
| "(0.50Γcorrect + 0.40Γprocess[0.60Γprm_final+0.40Γprm_mean] + 0.10Γfmt)", | |
| best_combined, | |
| ) | |
| logger.info("Best PRM component mean : %.3f", best_prm_mean) | |
| logger.info("Checkpoints : %s", out_dir) | |
| logger.info("Logs : %s", log_dir) | |
| logger.info("Console log : %s", console_log_path) | |
| logger.info("=" * 70) | |
| # ββ Final summary βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| summary: Dict[str, Any] = { | |
| "run_name": run_name, | |
| "best_accuracy": best_combined, # accuracy == combined_score | |
| "best_combined": best_combined, | |
| "best_prm_mean": best_prm_mean, | |
| "total_iterations": args.num_iterations, | |
| "checkpoints_dir": str(out_dir), | |
| "log_dir": str(log_dir), | |
| "console_log": str(console_log_path), | |
| "metrics_csv": str(_metrics_csv_path), | |
| "metrics_jsonl": str(out_dir / "metrics.jsonl"), | |
| } | |
| (log_dir / "summary.json").write_text( | |
| json.dumps(summary, indent=2, default=str), encoding="utf-8" | |
| ) | |
| logger.info("Summary written to %s", log_dir / "summary.json") | |
    # ── Auto-generate demo plots ──────────────────────────────────────────────
    _metrics_jsonl = out_dir / "metrics.jsonl"
    if _metrics_jsonl.exists():
        try:
            # `import importlib` alone does not expose the `util` submodule;
            # import it explicitly so find_spec() is guaranteed to resolve.
            import importlib.util
            if importlib.util.find_spec("matplotlib") is None:
                logger.warning(
                    "matplotlib not installed; skipping auto-plot. "
                    "Install with: pip install matplotlib then run: "
                    "python scripts/plot_grpo_run.py %s",
                    _metrics_jsonl,
                )
            else:
                from scripts.plot_grpo_run import generate_plots as _gen_plots
                _plot_dir = _gen_plots(_metrics_jsonl)
                logger.info("Plots saved → %s", _plot_dir)
        except Exception as _plot_exc:
            logger.warning(
                "Plot generation failed (%s: %s). "
                "Run manually: python scripts/plot_grpo_run.py %s",
                type(_plot_exc).__name__, _plot_exc, _metrics_jsonl,
            )
    # Explicit teardown (atexit is the safety net for crashes; calling here
    # ensures everything is flushed and closed before the process returns
    # normally; atexit won't double-close because _teardown_logging is
    # idempotent via the .closed checks).
    _teardown_logging()


if __name__ == "__main__":
    main()