{ "cells": [ { "cell_type": "markdown", "id": "md-header", "metadata": {}, "source": [ "# AxiomForgeAI — GRPO Training\n", "\n", "Training loop structured around the classic RL interface:\n", "\n", "```\n", "env.reset(qa) → start episode, receive question\n", "env.step(action)→ submit solution, receive reward + feedback\n", "env.state → inspect episode metadata\n", "env.close() → persist curriculum, release resources\n", "```\n", "\n", "All scoring, curriculum management, and reward computation are handled inside\n", "`AxiomforgeaiEnvironment`. The notebook owns model loading, solution generation,\n", "GRPO loss, and optimisation." ] }, { "cell_type": "code", "execution_count": null, "id": "cell-imports", "metadata": {}, "outputs": [], "source": [ "# ── Standard library ──────────────────────────────────────────────────────────\n", "from __future__ import annotations\n", "\n", "import argparse, copy, csv, hashlib, json, logging, random, re\n", "import shutil, sys, time, types\n", "from collections import defaultdict\n", "from datetime import datetime\n", "from enum import Enum, auto as _auto\n", "from pathlib import Path\n", "from typing import Any, Dict, List, Optional, Tuple\n", "\n", "# ── Third-party ───────────────────────────────────────────────────────────────\n", "import numpy as np\n", "import torch\n", "import torch.nn.functional as F\n", "from peft import PeftModel\n", "from tqdm.auto import tqdm\n", "from transformers import AutoModelForCausalLM, AutoTokenizer\n", "\n", "# Ensure the repo root is always on sys.path regardless of the kernel's cwd.\n", "_REPO_ROOT = Path(__file__).resolve().parent if \"__file__\" in dir() else Path.cwd()\n", "if str(_REPO_ROOT) not in sys.path:\n", " sys.path.insert(0, str(_REPO_ROOT))\n", "\n", "# ── RL Environment (reset / step / state / close) ───────────────────────────\n", "from server.AxiomForgeAI_environment import AxiomforgeaiEnvironment\n", "from models import AxiomforgeaiAction\n", "\n", "# ── Existing utilities from scripts/ and src/ ────────────────────────────────\n", "from scripts.convert_gsm8k_to_sft import parse_gsm8k_answer\n", "from scripts.eval_sft_inference import evaluate_gsm8k\n", "from src.rl.prm_scorer import ProcessRewardScorer\n", "from src.rl.math_environment_curriculum import CurriculumMathEnvironment\n", "from src.rl.unified_accuracy import StepChainExtractor, UnifiedAccuracyCalculator\n", "from src.rl.llm_question_classifier import LLMQuestionClassifier\n", "from src.config.prompts import create_generator_messages\n", "from src.sft.solution_format import extract_final_answer_numeric_str\n", "from src.utils.attn_backend import select_attn_implementation\n", "from src.utils.csv_logger import CSVLogger\n", "\n", "logging.basicConfig(\n", " level=logging.INFO,\n", " format=\"%(asctime)s %(levelname)-8s %(name)s - %(message)s\",\n", ")\n", "logger = logging.getLogger(__name__)\n", "\n", "if torch.cuda.is_available():\n", " torch.set_float32_matmul_precision(\"high\")\n", " torch.backends.cuda.matmul.allow_tf32 = True\n", " torch.backends.cudnn.allow_tf32 = True\n", " torch.backends.cudnn.benchmark = True" ] }, { "cell_type": "code", "execution_count": null, "id": "cell-config", "metadata": {}, "outputs": [], "source": [ "# ── Training configuration ────────────────────────────────────────────────────\n", "# Edit these values before running. 
Every key matches the corresponding\n",
"# CLI flag in scripts/run_grpo_training.py for compatibility.\n",
"\n",
"args = argparse.Namespace(\n",
"    # ── Paths ─────────────────────────────────────────────────────────────────\n",
"    base_model              = \"checkpoints/dual_task_v1\",\n",
"    output_dir              = \"checkpoints/grpo\",\n",
"    gsm8k_data              = \"data/sft/gsm8k_sft.jsonl\",\n",
"    eval_data_path          = \"data/sft/gsm8k_test.jsonl\",\n",
"    math_data               = None,\n",
"    extraction_cache        = \"data/extraction_cache.json\",\n",
"    run_name                = None,   # auto-set to grpo_<timestamp>\n",
"\n",
"    # ── Training scale ────────────────────────────────────────────────────────\n",
"    num_iterations          = 60,\n",
"    questions_per_iter      = 20,\n",
"    group_size              = 10,     # K solutions per question (GRPO group)\n",
"    q_group_size            = 2,      # K_q question candidates for two-phase self-play\n",
"\n",
"    # ── Optimiser ─────────────────────────────────────────────────────────────\n",
"    learning_rate           = 5e-6,\n",
"    max_grad_norm           = 0.5,\n",
"    kl_coef                 = 0.06,\n",
"    clip_eps                = 0.2,\n",
"    warmup_iters            = 8,\n",
"    min_lr_ratio            = 0.1,\n",
"\n",
"    # ── Generation ────────────────────────────────────────────────────────────\n",
"    max_new_tokens          = 1000,\n",
"    temperature             = 0.8,\n",
"    overlong_filter         = True,\n",
"\n",
"    # ── Dataset mixing (GSM8K → MATH curriculum ramp) ─────────────────────────\n",
"    math_mix_ratio          = 0.30,   # MATH fraction at ramp start\n",
"    math_mix_ratio_late     = 0.50,   # MATH fraction after ramp\n",
"    math_ramp_start         = 18,     # iteration at which MATH mix starts increasing\n",
"    math_max_difficulty     = 3,\n",
"    difficulty_alpha        = 3.5,    # win-rate weighting exponent; higher → sharper focus on ~50%-win-rate questions\n",
"\n",
"    # ── Evaluation ────────────────────────────────────────────────────────────\n",
"    eval_every              = 5,\n",
"    eval_max_samples        = 150,\n",
"    eval_max_new_tokens     = 1000,\n",
"    eval_pass_at_k          = 0,\n",
"    skip_initial_eval       = False,\n",
"\n",
"    # ── PRM (Process Reward Model) ────────────────────────────────────────────\n",
"    use_prm                 = True,\n",
"    prm_model               = \"Qwen/Qwen2.5-Math-PRM-7B\",\n",
"\n",
"    # ── Chain / unified accuracy extractor ───────────────────────────────────\n",
"    extractor_model         = \"Qwen/Qwen2.5-0.5B-Instruct\",\n",
"\n",
"    # ── Checkpointing ─────────────────────────────────────────────────────────\n",
"    save_every              = 5,\n",
"    keep_last               = 4,\n",
"\n",
"    # ── Self-play phase curriculum ────────────────────────────────────────────\n",
"    # Phase 1 (GROUNDED_ONLY): grounded-only until min_warmup iters pass AND\n",
"    #   gt-match rate ≥ selfplay_gt_thresh AND grounded accuracy ≥\n",
"    #   selfplay_grounded_thresh AND step accuracy ≥ selfplay_step_thresh\n",
"    # Phase 2 (SELFPLAY_RAMP): linearly ramp self-play over selfplay_ramp_iters\n",
"    # Phase 3 (CONTINUOUS): stable mix; falls back to grounded if quality drops\n",
"    self_play_ratio         = 0.70,   # target self-play fraction in Phase 3\n",
"    min_warmup              = 12,     # minimum grounded-only iterations before SP\n",
"    selfplay_gt_thresh      = 0.65,   # gt_match_rate required to unlock self-play\n",
"    selfplay_grounded_thresh= 0.65,   # grounded accuracy required to unlock self-play\n",
"    selfplay_step_thresh    = 0.68,   # step-level accuracy threshold\n",
"    selfplay_ramp_iters     = 28,     # iterations to ramp from 0 → self_play_ratio\n",
"    grounded_floor          = 0.55,   # below this grounded acc → suspend self-play\n",
")" ] }, { "cell_type": "code", "execution_count": null, "id": "cell-infra", "metadata": {}, "outputs": [], "source": [ "# ── Run identity + directory layout ──────────────────────────────────────────\n", "run_name = args.run_name or f\"grpo_{datetime.now():%Y%m%d_%H%M%S}\"\n", "out_dir = 
Path(args.output_dir) / run_name\n", "log_dir = Path(\"logs\") / \"grpo\" / run_name\n", "out_dir.mkdir(parents=True, exist_ok=True)\n", "log_dir.mkdir(parents=True, exist_ok=True)\n", "\n", "# ── Console mirror (TeeStream) ────────────────────────────────────────────────\n", "class TeeStream:\n", " \"\"\"Mirrors every write to a terminal stream into a log file.\"\"\"\n", " def __init__(self, primary, secondary):\n", " self.primary, self.secondary = primary, secondary\n", " def write(self, data):\n", " self.primary.write(data); self.secondary.write(data); return len(data)\n", " def flush(self):\n", " self.primary.flush(); self.secondary.flush()\n", " def isatty(self):\n", " return getattr(self.primary, \"isatty\", lambda: False)()\n", " def fileno(self):\n", " return self.primary.fileno()\n", "\n", "console_log_path = log_dir / \"console_output.log\"\n", "_console_log_file = console_log_path.open(\"a\", encoding=\"utf-8\", buffering=1)\n", "\n", "def _add_file_logging(path: Path) -> logging.FileHandler:\n", " fh = logging.FileHandler(path, mode=\"a\", encoding=\"utf-8\")\n", " fh.setLevel(logging.DEBUG)\n", " fh.setFormatter(logging.Formatter(\"%(asctime)s %(levelname)-8s %(name)s - %(message)s\"))\n", " logging.getLogger().addHandler(fh)\n", " return fh\n", "\n", "_file_handler = _add_file_logging(console_log_path)\n", "_orig_stdout = sys.stdout\n", "_orig_stderr = sys.stderr\n", "sys.stdout = TeeStream(_orig_stdout, _console_log_file)\n", "sys.stderr = TeeStream(_orig_stderr, _console_log_file)\n", "\n", "# ── Live CSV metrics writer (via CSVLogger) ───────────────────────────────────\n", "# CSVLogger writes key metrics to metrics.csv and full metrics as per-step JSON\n", "# under logs/grpo//detailed_metrics/step_NNNN.json\n", "_csv_logger = CSVLogger(\n", " project=\"grpo\",\n", " run_name=run_name,\n", " log_dir=\"logs\",\n", " config=vars(args),\n", " log_detailed=True,\n", ")\n", "\n", "def _append_metrics_csv(row: Dict[str, Any], step: Optional[int] = None) -> None:\n", " \"\"\"Write one row of metrics via CSVLogger (key metrics → CSV, all → JSON).\"\"\"\n", " _csv_logger.log(row, step=step)\n", "\n", "# ── Teardown (atexit + explicit) ──────────────────────────────────────────────\n", "def _teardown() -> None:\n", " sys.stdout = _orig_stdout\n", " sys.stderr = _orig_stderr\n", " logging.getLogger().removeHandler(_file_handler)\n", " if not getattr(_file_handler.stream, \"closed\", False): _file_handler.close()\n", " if not _console_log_file.closed: _console_log_file.close()\n", " _csv_logger.finish()\n", "\n", "import atexit; atexit.register(_teardown)\n", "\n", "random.seed(42); np.random.seed(42); torch.manual_seed(42)\n", "\n", "logger.info(\"Run: %s | out=%s | log=%s\", run_name, out_dir, log_dir)" ] }, { "cell_type": "code", "execution_count": null, "id": "cell-model", "metadata": {}, "outputs": [], "source": [ "# ── Device + attention backend ────────────────────────────────────────────────\n", "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "attn_impl = select_attn_implementation()\n", "logger.info(\"Device: %s | attn: %s\", device, attn_impl)\n", "if torch.cuda.is_available():\n", " _g = torch.cuda.get_device_properties(0)\n", " logger.info(\"GPU: %s | %.1f GB | sm_%d%d\", _g.name, _g.total_memory/1e9, _g.major, _g.minor)\n", "\n", "# ── Policy model ──────────────────────────────────────────────────────────────\n", "logger.info(\"Loading model from %s ...\", args.base_model)\n", "tokenizer = AutoTokenizer.from_pretrained(args.base_model, 
trust_remote_code=True)\n", "if tokenizer.pad_token is None:\n", " tokenizer.pad_token = tokenizer.eos_token\n", "tokenizer.padding_side = \"right\"\n", "\n", "# Patch missing chat_template (common in SFT adapter checkpoints)\n", "if tokenizer.chat_template is None:\n", " _base_name = \"Qwen/Qwen2.5-Math-1.5B-Instruct\"\n", " _meta = Path(args.base_model) / \"pipeline_meta.json\"\n", " if _meta.exists():\n", " _base_name = json.loads(_meta.read_text(encoding=\"utf-8\")).get(\"base_model\", _base_name)\n", " try:\n", " _bt = AutoTokenizer.from_pretrained(_base_name, trust_remote_code=True)\n", " if _bt.chat_template: tokenizer.chat_template = _bt.chat_template\n", " logger.info(\"Chat template loaded from %s\", _base_name)\n", " except Exception as _e:\n", " logger.warning(\"Could not load chat template: %s\", _e)\n", "\n", "# Patch missing tensor_parallel shim (PEFT ≤ 0.12)\n", "if \"transformers.integrations.tensor_parallel\" not in sys.modules:\n", " sys.modules[\"transformers.integrations.tensor_parallel\"] = types.ModuleType(\"tensor_parallel\")\n", "\n", "load_kwargs = dict(\n", " torch_dtype=torch.bfloat16, low_cpu_mem_usage=True,\n", " device_map={\"\":device}, trust_remote_code=True, attn_implementation=attn_impl)\n", "\n", "model_path = Path(args.base_model)\n", "if (model_path / \"adapter_config.json\").exists():\n", " _meta_p = model_path / \"pipeline_meta.json\"\n", " _base_w = \"Qwen/Qwen2.5-Math-1.5B-Instruct\"\n", " if _meta_p.exists():\n", " _base_w = json.loads(_meta_p.read_text(encoding=\"utf-8\")).get(\"base_model\", _base_w)\n", " logger.info(\"PEFT adapter — loading base %s then merging %s\", _base_w, args.base_model)\n", " _base = AutoModelForCausalLM.from_pretrained(_base_w, **load_kwargs)\n", " model = PeftModel.from_pretrained(_base, args.base_model).merge_and_unload().to(device)\n", "else:\n", " model = AutoModelForCausalLM.from_pretrained(args.base_model, **load_kwargs)\n", "\n", "for p in model.parameters(): p.requires_grad_(True)\n", "\n", "# Flash-Attn 2 makes gradient checkpointing redundant (same O(T) memory)\n", "if attn_impl != \"flash_attention_2\":\n", " model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={\"use_reentrant\": False})\n", " if hasattr(model, \"config\"): model.config.use_cache = False\n", " logger.info(\"Gradient checkpointing enabled.\")\n", "\n", "n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n", "n_total = sum(p.numel() for p in model.parameters())\n", "logger.info(\"Parameters: %s / %s trainable (%.1f%%)\",\n", " f\"{n_trainable:,}\", f\"{n_total:,}\", 100*n_trainable/max(n_total,1))\n", "\n", "# ── Frozen reference policy for KL penalty ────────────────────────────────────\n", "ref_model: Optional[AutoModelForCausalLM] = None\n", "if args.kl_coef > 0.0:\n", " ref_model = copy.deepcopy(model)\n", " ref_model.requires_grad_(False).eval()\n", " logger.info(\"Reference policy ready (kl_coef=%.4f).\", args.kl_coef)" ] }, { "cell_type": "code", "execution_count": null, "id": "cell-env", "metadata": {}, "outputs": [], "source": [ "# ── Load training data ────────────────────────────────────────────────────────\n", "def _load_jsonl_qa(path: str) -> List[Dict[str, str]]:\n", " \"\"\"Load {question, gold_final} pairs from a JSONL file.\"\"\"\n", " pairs: List[Dict[str, str]] = []\n", " p = Path(path)\n", " if not p.exists():\n", " logger.warning(\"Data file not found: %s\", path); return pairs\n", " with p.open(encoding=\"utf-8\") as f:\n", " for line in f:\n", " line = line.strip()\n", " if not line: 
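\n",
"                continue\n",
"            # Two record shapes are accepted, mirroring the branches below:\n",
"            #   {\"question\": ..., \"answer\": ...} → GSM8K-style; the gold final\n",
"            #     answer is parsed from \"answer\" via parse_gsm8k_answer\n",
"            #   {\"messages\": [...]} → chat-SFT style; the question is the first\n",
"            #     user turn, the gold answer is re-extracted from the assistant turn\n",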
continue\n", " try: rec = json.loads(line)\n", " except json.JSONDecodeError: continue\n", " if \"question\" in rec and \"answer\" in rec:\n", " q = rec[\"question\"].strip()\n", " _, g = parse_gsm8k_answer(str(rec[\"answer\"]))\n", " elif \"messages\" in rec:\n", " q, asst = \"\", \"\"\n", " for msg in rec[\"messages\"]:\n", " if msg.get(\"role\") == \"user\" and not q: q = msg.get(\"content\",\"\").strip()\n", " if msg.get(\"role\") == \"assistant\" and not asst: asst = msg.get(\"content\",\"\")\n", " if \"Problem:\" in q: q = q.split(\"Problem:\",1)[1].strip()\n", " g = (extract_final_answer_numeric_str(asst) or \"\").strip()\n", " else:\n", " continue\n", " if q and g: pairs.append({\"question\": q, \"gold_final\": g})\n", " logger.info(\"Loaded %d QA pairs from %s\", len(pairs), path)\n", " return pairs\n", "\n", "def _load_math_dataset(\n", " local_path: Optional[str] = None,\n", " cache: str = \"data/math/math_numeric.jsonl\",\n", " max_diff: int = 3,\n", ") -> List[Dict[str, str]]:\n", " \"\"\"Load MATH competition dataset (numeric answers, difficulty ≤ max_diff).\"\"\"\n", " for src in filter(None, [local_path, cache]):\n", " p = Path(src)\n", " if p.exists():\n", " items = [json.loads(l) for l in p.read_text(encoding=\"utf-8\").splitlines() if l.strip()]\n", " if items: logger.info(\"Loaded %d MATH pairs from %s\", len(items), p); return items\n", " try:\n", " from datasets import load_dataset\n", " ds = load_dataset(\"qwedsacf/competition_math\", split=\"train\", trust_remote_code=True)\n", " except Exception as e:\n", " logger.warning(\"MATH download failed (%s) — GSM8K only.\", e); return []\n", " pairs, _box = [], re.compile(r\"\\\\\\\\boxed\\\\{([^}]*)\\\\}\")\n", " for item in ds:\n", " lvl = item.get(\"level\",\"Level 5\")\n", " try:\n", " if int(lvl.split()[-1]) > max_diff: continue\n", " except (ValueError, IndexError): continue\n", " m = _box.search(item.get(\"solution\",\"\"))\n", " if not m: continue\n", " raw = m.group(1).strip()\n", " try: num = str(int(raw))\n", " except ValueError:\n", " try: v=float(raw); num=str(int(v)) if v==int(v) else f\"{v:.4f}\"\n", " except ValueError: continue\n", " pairs.append({\"question\": item.get(\"problem\",\"\").strip(), \"gold_final\": num})\n", " if pairs:\n", " Path(cache).parent.mkdir(parents=True,exist_ok=True)\n", " Path(cache).write_text(\"\\n\".join(json.dumps(p) for p in pairs), encoding=\"utf-8\")\n", " logger.info(\"Cached %d MATH pairs → %s\", len(pairs), cache)\n", " return pairs\n", "\n", "gsm8k_pairs = _load_jsonl_qa(args.gsm8k_data)\n", "if not gsm8k_pairs:\n", " raise SystemExit(f\"No training data at {args.gsm8k_data}\")\n", "\n", "math_pairs: List[Dict[str, str]] = []\n", "if args.math_mix_ratio > 0:\n", " math_pairs = _load_math_dataset(args.math_data, max_diff=args.math_max_difficulty)\n", " if math_pairs:\n", " logger.info(\"MATH mix: %.0f%% MATH (%d) + %.0f%% GSM8K (%d)\",\n", " 100*args.math_mix_ratio, len(math_pairs),\n", " 100*(1-args.math_mix_ratio), len(gsm8k_pairs))\n", "\n", "# ── PRM scorer ────────────────────────────────────────────────────────────────\n", "prm: Optional[ProcessRewardScorer] = None\n", "if args.use_prm:\n", " try:\n", " prm = ProcessRewardScorer(model_name=args.prm_model, device=device, load_in_4bit=True)\n", " logger.info(\"PRM loaded: %s (4-bit)\", args.prm_model)\n", " except Exception as e:\n", " logger.warning(\"PRM load failed (%s) — no PRM scoring.\", e)\n", "\n", "# ── Unified accuracy calculator (step-chain scoring, Phase 2+) ────────────────\n", "_extractor = 
StepChainExtractor(model_name=args.extractor_model, device=str(device),\n", " cache_path=args.extraction_cache)\n", "_unified_calc = UnifiedAccuracyCalculator(extractor=_extractor, question_evaluator=None)\n", "logger.info(\"Warming up step-chain extractor ...\")\n", "_extractor.warmup()\n", "logger.info(\"Extractor ready.\")\n", "\n", "# ── CurriculumMathEnvironment (full model — generates + scores) ───────────────\n", "math_env = CurriculumMathEnvironment(\n", " policy_model=model,\n", " value_model=None,\n", " tokenizer=tokenizer,\n", " reference_questions=[p[\"question\"] for p in gsm8k_pairs],\n", " grounded_qa_pairs=gsm8k_pairs,\n", " prm_scorer=prm,\n", " max_solution_tokens=args.max_new_tokens,\n", " device=device,\n", " unified_accuracy_calc=_unified_calc,\n", ")\n", "_unified_calc.question_evaluator = math_env.question_evaluator\n", "\n", "# LLM-backed question classifier (uses the already-loaded policy)\n", "_llm_cls = LLMQuestionClassifier(model=model, tokenizer=tokenizer,\n", " device=device, cache_size=10_000)\n", "math_env.question_evaluator.classifier = _llm_cls\n", "\n", "# Bootstrap curriculum from structured dataset (NuminaMath / OpenMathInstruct)\n", "_raw = [json.loads(l) for l in Path(args.gsm8k_data).read_text(encoding=\"utf-8\").splitlines() if l.strip()]\n", "if any(\"skill_id\" in r for r in _raw[:20]):\n", " math_env.curriculum_manager.initialize_from_dataset(_raw)\n", " logger.info(\"Curriculum bootstrapped from skill_ids.\")\n", "else:\n", " logger.info(\"Plain dataset — keyword-classifier bootstrap.\")\n", "\n", "# ── RL Environment — wraps math_env with reset / step / state / close ─────────\n", "env = AxiomforgeaiEnvironment()\n", "env._math_env = math_env # inject the training-configured math_env\n", "logger.info(\"RL environment ready — reset / step / state / close.\")" ] }, { "cell_type": "code", "execution_count": null, "id": "cell-optim", "metadata": {}, "outputs": [], "source": [ "# ── Optimiser + LR schedule ───────────────────────────────────────────────────\n", "optimizer = torch.optim.AdamW(\n", " [p for p in model.parameters() if p.requires_grad],\n", " lr=args.learning_rate,\n", " fused=torch.cuda.is_available(),\n", ")\n", "\n", "from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR\n", "_nw = max(1, args.warmup_iters)\n", "_nt = max(1, args.num_iterations)\n", "_nd = max(1, _nt - _nw)\n", "scheduler = SequentialLR(\n", " optimizer,\n", " schedulers=[\n", " LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=_nw),\n", " CosineAnnealingLR(optimizer, T_max=_nd, eta_min=args.learning_rate * args.min_lr_ratio),\n", " ],\n", " milestones=[_nw],\n", ")\n", "logger.info(\"LR: %.1e warmup=%d cosine=%d min=%.1e\",\n", " args.learning_rate, _nw, _nd, args.learning_rate * args.min_lr_ratio)" ] }, { "cell_type": "code", "execution_count": null, "id": "cell-utils", "metadata": {}, "outputs": [], "source": [ "# ── GRPO utilities ────────────────────────────────────────────────────────────\n", "# These functions live in the notebook because they depend on live model\n", "# objects and are tightly coupled to the GRPO update step.\n", "\n", "def _stop_ids(tok: AutoTokenizer) -> Optional[List[int]]:\n", " ids = []\n", " if tok.eos_token_id is not None: ids.append(tok.eos_token_id)\n", " im = tok.convert_tokens_to_ids(\"<|im_end|>\")\n", " if isinstance(im, int) and im not in ids: ids.append(im)\n", " return ids or None\n", "\n", "\n", "@torch.no_grad()\n", "def generate_questions_batched(\n", " model, tokenizer, instruction: 
str, K_q: int,\n", " max_new_tokens: int, temperature: float, device,\n", ") -> Tuple[List[str], List, List, List]:\n", " \"\"\"Generate K_q candidate question strings from a curriculum instruction.\"\"\"\n", " messages = create_generator_messages(instruction)\n", " try: prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", " except: prompt = f\"{messages[0]['content']}\\n\\n{instruction}\\n\"\n", " pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id\n", " enc = tokenizer(prompt, return_tensors=\"pt\", truncation=True, max_length=512).to(device)\n", " plen = enc[\"input_ids\"].shape[1]\n", " out = model.generate(\n", " input_ids=enc[\"input_ids\"].expand(K_q,-1).contiguous(),\n", " attention_mask=enc[\"attention_mask\"].expand(K_q,-1).contiguous(),\n", " max_new_tokens=max_new_tokens, do_sample=True, temperature=temperature,\n", " top_p=0.95, pad_token_id=pad_id, eos_token_id=_stop_ids(tokenizer), use_cache=True)\n", " pad_t = torch.tensor(pad_id, device=device, dtype=out.dtype)\n", " questions, ids_list, masks_list, olps_list = [], [], [], []\n", " attn_lp = (out != pad_t); attn_lp[:,:plen] = True\n", " batch_logits = model(input_ids=out, attention_mask=attn_lp.long(),\n", " use_cache=False, return_dict=True).logits\n", " for i in range(K_q):\n", " full = out[i]; resp = full[plen:]\n", " mask = torch.zeros(full.shape[0], dtype=torch.bool, device=device)\n", " mask[plen:] = resp != pad_t\n", " questions.append(tokenizer.decode(resp, skip_special_tokens=True).strip())\n", " ids_list.append(full); masks_list.append(mask)\n", " sl = batch_logits[i,:-1]; lb = full[1:]; sm = mask[1:]\n", " lp = F.log_softmax(sl,dim=-1)[torch.arange(sl.size(0),device=device), lb]\n", " resp_lp = lp[sm]\n", " olps_list.append(resp_lp.sum().detach() if resp_lp.numel()>0 else torch.tensor(0.,device=device))\n", " return questions, ids_list, masks_list, olps_list\n", "\n", "\n", "def generate_solutions_batched(\n", " model, tokenizer, prompt: str, K: int,\n", " max_new_tokens: int, temperature: float, device,\n", ") -> Tuple[List[str], List, List, List]:\n", " \"\"\"Generate K solution strings and their per-sequence log-probs.\"\"\"\n", " pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id\n", " enc = tokenizer(prompt, return_tensors=\"pt\", padding=False,\n", " truncation=True, max_length=1024).to(device)\n", " plen = enc[\"input_ids\"].shape[1]\n", " model.eval()\n", " with torch.no_grad():\n", " out = model.generate(\n", " input_ids=enc[\"input_ids\"].expand(K,-1).contiguous(),\n", " attention_mask=enc[\"attention_mask\"].expand(K,-1).contiguous(),\n", " max_new_tokens=max_new_tokens, do_sample=True, temperature=temperature,\n", " top_p=0.9, pad_token_id=pad_id, eos_token_id=_stop_ids(tokenizer), use_cache=True)\n", " pad_t = torch.tensor(pad_id, device=device, dtype=out.dtype)\n", " solutions, ids_list, masks_list, olps_list = [], [], [], []\n", " with torch.no_grad():\n", " attn_lp = (out != pad_t); attn_lp[:,:plen] = True\n", " batch_logits = model(input_ids=out, attention_mask=attn_lp.long(),\n", " use_cache=False, return_dict=True).logits\n", " for i in range(K):\n", " full = out[i]; resp = full[plen:]\n", " mask = torch.zeros(full.shape[0], dtype=torch.bool, device=device)\n", " mask[plen:] = resp != pad_t\n", " solutions.append(tokenizer.decode(resp, skip_special_tokens=True))\n", " ids_list.append(full); masks_list.append(mask)\n", " sl = batch_logits[i,:-1]; lb = full[1:]; sm = mask[1:]\n", " lp = 
F.log_softmax(sl,dim=-1)[torch.arange(sl.size(0),device=device), lb]\n", " resp_lp = lp[sm]\n", " olps_list.append(resp_lp.sum().detach() if resp_lp.numel()>0 else torch.tensor(0.,device=device))\n", " return solutions, ids_list, masks_list, olps_list\n", "\n", "\n", "def compute_sequence_log_prob(model, input_ids, response_mask) -> torch.Tensor:\n", " \"\"\"Forward pass → sum of log-probs over the response tokens.\"\"\"\n", " logits = model(input_ids=input_ids.unsqueeze(0), use_cache=False, return_dict=True).logits[0]\n", " lp = F.log_softmax(logits[:-1], dim=-1)\n", " token_lp = lp[torch.arange(lp.size(0), device=lp.device), input_ids[1:]]\n", " resp = token_lp[response_mask[1:]]\n", " return resp.sum() if resp.numel() > 0 else torch.tensor(0., requires_grad=True, device=input_ids.device)\n", "\n", "\n", "def grpo_loss_for_group(\n", " model, ids_list, masks_list, rewards: List[float], old_lps,\n", " clip_eps: float = 0.2, kl_coef: float = 0.0, ref_model=None, eps: float = 1e-8,\n", ") -> Optional[torch.Tensor]:\n", " \"\"\"GRPO policy loss for one question group (K solutions).\"\"\"\n", " r = np.array(rewards, dtype=np.float32)\n", " if r.std() < eps: return None\n", " advantages = np.clip((r - r.mean()) / (r.std() + eps), -5., 5.)\n", " dev = next(model.parameters()).device\n", " loss = torch.tensor(0., device=dev); n = 0\n", " model.train()\n", " for ids, mask, adv, olp in zip(ids_list, masks_list, advantages, old_lps):\n", " n_resp = int(mask[1:].sum().item())\n", " if n_resp == 0: continue\n", " new_lp = compute_sequence_log_prob(model, ids, mask)\n", " adv_t = torch.tensor(adv, dtype=new_lp.dtype, device=dev)\n", " if clip_eps > 0:\n", " ratio = torch.exp(new_lp - olp.to(dev).detach())\n", " li = -torch.min(ratio * adv_t, torch.clamp(ratio,1-clip_eps,1+clip_eps) * adv_t) / n_resp\n", " else:\n", " li = -(adv_t * new_lp / n_resp)\n", " if kl_coef > 0 and ref_model is not None:\n", " with torch.no_grad(): ref_lp = compute_sequence_log_prob(ref_model, ids, mask)\n", " li = li + kl_coef * (new_lp - ref_lp.to(dev).detach()) / n_resp\n", " loss = loss + li; n += 1\n", " return loss / n if n > 0 else None\n", "\n", "\n", "def compute_self_play_reward(\n", " question: str, solution: str, topic: str, difficulty: float, math_env,\n", ") -> Tuple[float, float, float, Dict]:\n", " \"\"\"Self-play reward via math_env.compute_reward (no gold answer).\"\"\"\n", " result = math_env.compute_reward(question=question, solution=solution,\n", " target_topic=topic, target_difficulty=difficulty)\n", " combined = float(result[\"combined_score\"])\n", " sol_m = result.get(\"solution_metrics\") or {}\n", " s_rew = float(sol_m.get(\"overall_score\", 0.)) if isinstance(sol_m, dict) else 0.\n", " q_raw = result.get(\"question_metrics\") or {}\n", " q_rew = float(result.get(\"effective_question_reward\", q_raw.get(\"overall_score\", 0.)))\n", " q_met: Dict = {\n", " \"overall_score\": q_rew,\n", " \"topic_match\": float(q_raw.get(\"topic_match\", 0.)),\n", " \"difficulty_fit\": float(q_raw.get(\"difficulty_score\", 0.)),\n", " \"clarity\": float(q_raw.get(\"clarity\", 0.)),\n", " \"novelty\": float(q_raw.get(\"novelty_combined\", 0.)),\n", " \"solvability\": float(q_raw.get(\"solvability_score\",0.)),\n", " \"sp_chain_integrity_score\": result.get(\"sp_chain_integrity_score\"),\n", " }\n", " return combined, q_rew, s_rew, q_met\n", "\n", "\n", "def _verify_sp_answer(solutions: List[str], topic: str, difficulty: float) -> bool:\n", " \"\"\"Consensus check: majority of K solutions agree on a numeric 
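value.\n",
"\n",
"    As implemented below: geometry and difficulty ≥ 4 always return False;\n",
"    a 'final answer: X' line is parsed from each solution; X is evaluated via\n",
"    a restricted eval, falling back to sympy; values are rounded to 6 dp; the\n",
"    check passes when the modal value covers at least half of the K samples,\n",
"    a consensus proxy for a trustworthy\n",
"    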
answer.\"\"\"\n", " t = topic.lower().replace(\" \",\"_\")\n", " if t in {\"geometry\"} or difficulty >= 4.: return False\n", " answers: List[float] = []\n", " for sol in solutions:\n", " m = re.search(r\"final answer[:\\s]*([^\\n]+)\", sol, re.I)\n", " if not m: continue\n", " raw = m.group(1).strip()\n", " for fn in (lambda s: float(eval(s, {\"__builtins__\":{}}, {})),\n", " lambda s: float(__import__(\"sympy\").N(__import__(\"sympy\").sympify(s), 15))):\n", " try: v = fn(raw); answers.append(round(v, 6)); break\n", " except: pass\n", " if not answers: return False\n", " maj = max(set(answers), key=answers.count)\n", " return answers.count(maj) >= max(1, len(solutions)//2)\n", "\n", "\n", "def evaluate_policy(\n", " model, tokenizer, data_path: str, max_samples: int,\n", " max_new_tokens: int, math_env=None, pass_at_k: int = 4,\n", ") -> Dict[str, Any]:\n", " \"\"\"Run evaluation on held-out data; returns combined_score and related metrics.\"\"\"\n", " if not Path(data_path).exists(): return {\"accuracy\": 0., \"combined_score\": 0., \"total\": 0}\n", " model.eval()\n", " reward_fn = None\n", " if math_env is not None:\n", " import logging as _lm\n", " _ml = _lm.getLogger(\"src.rl.math_environment_curriculum\")\n", " _pl = _lm.getLogger(\"src.rl.prm_scorer\")\n", " def reward_fn(q, s, g):\n", " _ml.setLevel(_lm.WARNING); _pl.setLevel(_lm.WARNING)\n", " try: return math_env.compute_grounded_reward(q, s, g)\n", " finally: _ml.setLevel(_lm.INFO); _pl.setLevel(_lm.INFO)\n", " stem = Path(data_path).stem.lower()\n", " ds_name = \"AQuA-RAT\" if \"aqua\" in stem else \"MATH\" if \"math\" in stem else \"GSM8K\"\n", " results = evaluate_gsm8k(model=model, tokenizer=tokenizer, data_path=data_path,\n", " max_samples=max_samples, max_new_tokens=max_new_tokens,\n", " reward_fn=reward_fn, pass_at_k=pass_at_k, dataset_name=ds_name)\n", " model.train()\n", " return results\n", "\n", "\n", "# ── Difficulty-adaptive question sampling ─────────────────────────────────────\n", "_q_wins: Dict[str, int] = defaultdict(int)\n", "_q_attempts: Dict[str, int] = defaultdict(int)\n", "\n", "def _qkey(q: str) -> str:\n", " return hashlib.md5(q.encode(), usedforsecurity=False).hexdigest()\n", "\n", "def _sample_by_difficulty(pool: List[Dict], n: int, alpha: float) -> List[Dict]:\n", " \"\"\"Weight questions by how informative they are (win-rate close to 50%).\"\"\"\n", " if alpha <= 0: return random.sample(pool, min(n, len(pool)))\n", " weights = []\n", " for qa in pool:\n", " att = _q_attempts[_qkey(qa[\"question\"])]\n", " w = 0.75 if att == 0 else max(\n", " (1. - abs(_q_wins[_qkey(qa[\"question\"])]/att - 0.5)*2.) 
** alpha, 0.05)\n", " weights.append(w)\n", " tw = sum(weights)\n", " probs = [w/tw for w in weights]\n", " return [pool[i] for i in np.random.choice(len(pool), size=min(n,len(pool)), replace=False, p=probs)]" ] }, { "cell_type": "code", "execution_count": null, "id": "cell-init-eval", "metadata": {}, "outputs": [], "source": [ "# ── Optional initial evaluation (Iteration 0 baseline) ───────────────────────\n", "metrics_log: List[Dict] = []\n", "best_combined = best_prm_mean = best_accuracy = 0.\n", "\n", "if not args.skip_initial_eval:\n", " logger.info(\"=\" * 70)\n", " logger.info(\"INITIAL EVALUATION (Iteration 0)\")\n", " logger.info(\"=\" * 70)\n", " _init = evaluate_policy(model, tokenizer, args.eval_data_path,\n", " args.eval_max_samples, args.eval_max_new_tokens,\n", " math_env=math_env, pass_at_k=args.eval_pass_at_k)\n", " best_combined = best_accuracy = float(_init.get(\"combined_score\", 0.))\n", " best_prm_mean = float(_init.get(\"prm_mean\", 0.))\n", " logger.info(\"Baseline combined_score=%.4f correct=%.1f%%\",\n", " best_combined, 100*float(_init.get(\"correct_rate\", 0.)))\n", " metrics_log.append({\"iteration\": 0, **_init})" ] }, { "cell_type": "code", "execution_count": null, "id": "cell-train", "metadata": {}, "outputs": [], "source": [ "# ══════════════════════════════════════════════════════════════════════════════\n", "# GRPO Training Loop — reset → step → state → close\n", "# ══════════════════════════════════════════════════════════════════════════════\n", "\n", "# ── Phase curriculum state ────────────────────────────────────────────────────\n", "class _Phase(Enum):\n", " GROUNDED_ONLY = _auto() # grounded only until model is ready\n", " SELFPLAY_RAMP = _auto() # ramp self-play ratio up from 0\n", " CONTINUOUS = _auto() # steady-state mixed training\n", "\n", "_phase: _Phase = _Phase.GROUNDED_ONLY\n", "_sp_iters: int = 0\n", "_sp_suspended: bool = False\n", "_eff_sp: float = 0.\n", "_use_chain: bool = False\n", "_chain_corr: float = 0.\n", "_extract_rate: float = 0.\n", "_chain_buf: List[float] = []\n", "_prm_buf: List[float] = []\n", "_succ_buf: List[int] = []\n", "_CWIN, _CMAX, _SHDW = 50, 200, 4\n", "_shadow_ctr: int = 0\n", "\n", "for iteration in range(1, args.num_iterations + 1):\n", " iter_start = time.perf_counter()\n", " logger.info(\"=\" * 70)\n", " logger.info(\"GRPO ITERATION %d / %d [phase=%s]\", iteration, args.num_iterations, _phase.name)\n", " logger.info(\"=\" * 70)\n", "\n", " # ── Dataset batch (with MATH ramp) ────────────────────────────────────────\n", " _eff_math = args.math_mix_ratio\n", " if args.math_mix_ratio_late and iteration > args.math_ramp_start:\n", " _r = min(1., (iteration - args.math_ramp_start) / 10.)\n", " _eff_math = args.math_mix_ratio + _r * (args.math_mix_ratio_late - args.math_mix_ratio)\n", " if math_pairs and _eff_math > 0:\n", " nm = max(1, round(args.questions_per_iter * _eff_math))\n", " ng = max(1, args.questions_per_iter - nm)\n", " batch = (_sample_by_difficulty(math_pairs, nm, args.difficulty_alpha) +\n", " _sample_by_difficulty(gsm8k_pairs, ng, args.difficulty_alpha))\n", " random.shuffle(batch)\n", " else:\n", " batch = _sample_by_difficulty(gsm8k_pairs, args.questions_per_iter, args.difficulty_alpha)\n", "\n", " # Temperature annealing: 0.8 → 0.4 over the full run\n", " _ann = min(1., (iteration-1) / max(1, args.num_iterations-1))\n", " _temp = args.temperature * (1. 
- 0.5 * _ann)\n",
"\n",
"    # ── Effective self-play ratio (phase-dependent) ────────────────────────────\n",
"    if _phase == _Phase.GROUNDED_ONLY or _sp_suspended: _eff_sp = 0.\n",
"    elif _phase == _Phase.SELFPLAY_RAMP:\n",
"        # ramps 0 → self_play_ratio; the 0.30 floor inside max() caps it at 0.70\n",
"        _eff_sp = 1. - max(0.30, 1. - _sp_iters / max(1, args.selfplay_ramp_iters))\n",
"    else: _eff_sp = args.self_play_ratio\n",
"\n",
"    _sp_idx = set(random.sample(range(len(batch)), int(round(len(batch)*_eff_sp))))\n",
"\n",
"    # ── Per-iteration metric accumulators ─────────────────────────────────────\n",
"    all_r, all_qr = [], []\n",
"    gr_r, sp_r = [], []\n",
"    gr_step, gr_lccp, gr_gt = [], [], []\n",
"    ch_arith, ch_dep, ch_int, sp_ch = [], [], [], []\n",
"    qc = dict(topic=[], diff=[], clarity=[], novelty=[], solvab=[])\n",
"    skipped = n_grps = n_sp = q_att = q_val = q_good = 0\n",
"    skip0var = 0; total_loss = 0.\n",
"\n",
"    optimizer.zero_grad()\n",
"\n",
"    pbar = tqdm(batch, desc=f\"Iter {iteration}\", unit=\"q\")\n",
"    for _idx, qa in enumerate(pbar):\n",
"        is_sp = _idx in _sp_idx\n",
"\n",
"        # ════════════════════════════════════════════════════════════════════\n",
"        # RESET — start a new episode\n",
"        # ════════════════════════════════════════════════════════════════════\n",
"        if is_sp:\n",
"            # Self-play: curriculum provides the instruction\n",
"            instr, topic, diff = env._math_env.sample_instruction()\n",
"            if diff >= 4.: skipped += 1; continue\n",
"            q_att += 1\n",
"\n",
"            if args.q_group_size > 1:\n",
"                # Two-phase SP: generate K_q questions, then K solutions per question\n",
"                _qt = min(0.90, _temp + 0.05)\n",
"                qcands, qids, qmasks, qolps = generate_questions_batched(\n",
"                    model, tokenizer, instr, args.q_group_size, 128, _qt, device)\n",
"                vq = [(q,i,m,o) for q,i,m,o in zip(qcands,qids,qmasks,qolps) if len(q.strip())>=10]\n",
"                if not vq: skipped += 1; continue\n",
"                q_val += 1; n_sp += 1\n",
"                qagg: List[float] = []  # per-candidate mean question reward\n",
"                for _qt2, _qi, _qm, _qo in vq:\n",
"                    sols, sids, smasks, solps = generate_solutions_batched(\n",
"                        model, tokenizer, math_env.format_solution_prompt(_qt2),\n",
"                        args.group_size, args.max_new_tokens, _temp, device)\n",
"                    if args.overlong_filter:\n",
"                        vf = [(s,i,m,o) for s,i,m,o in zip(sols,sids,smasks,solps)\n",
"                              if int(m.sum()) < args.max_new_tokens]\n",
"                        if not vf: qagg.append(0.); continue\n",
"                        sols,sids,smasks,solps = map(list, zip(*vf))\n",
"                    # Score this candidate's K solutions (no gold answer)\n",
"                    _rs, _cand_qr = [], []\n",
"                    for _s in sols:\n",
"                        _r, _qr, _, _ = compute_self_play_reward(_qt2, _s, topic, diff, math_env)\n",
"                        _rs.append(_r); _cand_qr.append(_qr); all_qr.append(_qr)\n",
"                    all_r.extend(_rs); sp_r.extend(_rs)\n",
"                    qagg.append(float(np.mean(_cand_qr)) if _cand_qr else 0.)\n",
"                    # Solution-level GRPO update for this candidate's group\n",
"                    _gl = grpo_loss_for_group(model, sids, smasks, _rs, solps,\n",
"                                              args.clip_eps, args.kl_coef, ref_model)\n",
"                    if _gl is not None:\n",
"                        _gl.backward(); total_loss += _gl.item(); n_grps += 1\n",
"                # Question-level GRPO over the K_q candidates, rewarded by their\n",
"                # mean question reward\n",
"                if len(vq) > 1:\n",
"                    _ql = grpo_loss_for_group(model, [v[1] for v in vq], [v[2] for v in vq],\n",
"                                              qagg, [v[3] for v in vq],\n",
"                                              args.clip_eps, args.kl_coef, ref_model)\n",
"                    if _ql is not None:\n",
"                        _ql.backward(); total_loss += _ql.item(); n_grps += 1\n",
"                if any(r>0.5 for r in qagg): q_good+=1\n",
"                pbar.set_postfix(loss=f\"{total_loss/max(1,n_grps):.4f}\",\n",
"                                 sp_r=f\"{np.mean(sp_r or [0]):.3f}\",skip=skipped)\n",
"                continue\n",
"\n",
"            # Single-question self-play\n",
"            _msgs = create_generator_messages(instr)\n",
"            try: _pr = tokenizer.apply_chat_template(_msgs,tokenize=False,add_generation_prompt=True)\n",
"            except Exception: _pr = f\"{_msgs[0]['content']}\\n\\n{instr}\\n\"\n",
"            _enc = tokenizer(_pr,return_tensors=\"pt\",truncation=True,max_length=512).to(device)\n",
"            _plen = _enc[\"input_ids\"].shape[1]\n",
"            with torch.no_grad():\n",
"                _out = model.generate(\n",
"                    **_enc, max_new_tokens=128, do_sample=True,\n",
"                    temperature=min(0.90,_temp+0.05), top_p=0.95,\n",
"                    pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,\n",
"                    eos_token_id=_stop_ids(tokenizer), use_cache=True)\n",
"            question = tokenizer.decode(_out[0][_plen:], skip_special_tokens=True).strip()\n",
"            if len(question) < 10: skipped+=1; continue\n",
"            q_val+=1; n_sp+=1\n",
"        else:\n",
"            # ── RESET (grounded): inject difficulty-sampled QA pair ────────────\n",
"            obs = env.reset(qa=qa)   # start episode with a dataset question\n",
"            question = obs.question  # the grounded math question\n",
"            topic, diff = \"grounded\", 0.5\n",
"\n",
"        # 
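NOTE: only single-question self-play and grounded episodes reach the shared\n",
"        # GENERATE/STEP path below; the two-phase branch above generates, scores,\n",
"        # and backpropagates its groups itself before continuing.\n",
"        # 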
════════════════════════════════════════════════════════════════════\n", " # GENERATE — policy produces K candidate solutions\n", " # ════════════════════════════════════════════════════════════════════\n", " solutions, ids_list, masks_list, lps_list = generate_solutions_batched(\n", " model, tokenizer, math_env.format_solution_prompt(question),\n", " args.group_size, args.max_new_tokens, _temp, device)\n", " if args.overlong_filter:\n", " vf = [(s,i,m,o) for s,i,m,o in zip(solutions,ids_list,masks_list,lps_list)\n", " if int(m.sum()) < args.max_new_tokens]\n", " if vf: solutions,ids_list,masks_list,lps_list = map(list, zip(*vf))\n", " else: skipped+=1; continue\n", "\n", " # ════════════════════════════════════════════════════════════════════\n", " # STEP — score each solution with the RL environment\n", " # ════════════════════════════════════════════════════════════════════\n", " rewards: List[float] = []\n", " _sp_qr: List[float] = []\n", " for sol in solutions:\n", " if is_sp:\n", " # Self-play: env.compute_reward (no gold answer)\n", " r, qr, _, qm = compute_self_play_reward(question, sol, topic, diff, math_env)\n", " _sp_qr.append(qr); all_qr.append(qr)\n", " qc[\"topic\"].append(qm[\"topic_match\"]); qc[\"diff\"].append(qm[\"difficulty_fit\"])\n", " qc[\"clarity\"].append(qm[\"clarity\"]); qc[\"novelty\"].append(qm[\"novelty\"])\n", " qc[\"solvab\"].append(qm[\"solvability\"])\n", " _sc = qm.get(\"sp_chain_integrity_score\")\n", " if _sc is not None: sp_ch.append(float(_sc))\n", " else:\n", " # Grounded: env.step → compute_grounded_reward internally\n", " step_obs = env.step(AxiomforgeaiAction(solution=sol))\n", " r = step_obs.reward\n", " m = step_obs.metadata or {}\n", " gr_step.append(float(m.get(\"step_accuracy\", 0.)))\n", " gr_lccp.append(float(m.get(\"lccp\", 0.)))\n", " gr_gt.append(bool(m.get(\"gt_match\", False)))\n", " if m.get(\"chain_arith_score\") is not None: ch_arith.append(float(m[\"chain_arith_score\"]))\n", " if m.get(\"chain_dep_score\") is not None: ch_dep.append(float(m[\"chain_dep_score\"]))\n", " if m.get(\"chain_integrity_score\") is not None: ch_int.append(float(m[\"chain_integrity_score\"]))\n", " # Shadow chain extraction for Phase 2 calibration\n", " _shadow_ctr += 1\n", " if (_phase == _Phase.SELFPLAY_RAMP and not _use_chain\n", " and _unified_calc and _shadow_ctr % _SHDW == 0):\n", " _pps = 0.60*m.get(\"prm_final_score\",0.) 
+ 0.40*m.get(\"prm_mean_score\",0.)\n", " try:\n", " _sh = _unified_calc.compute(solution=sol,gold_answer=qa[\"gold_final\"],\n", " question=question,topic=\"grounded\",phase=\"grounded\")\n", " _chain_buf.append(_sh.chain_integrity_score)\n", " _prm_buf.append(_pps)\n", " _succ_buf.append(1 if _sh.extraction_succeeded else 0)\n", " except Exception: _succ_buf.append(0)\n", " rewards.append(r)\n", "\n", " all_r.extend(rewards)\n", " if is_sp: sp_r.extend(rewards)\n", " else: gr_r.extend(rewards)\n", "\n", " if is_sp:\n", " if _sp_qr and np.mean(_sp_qr) > 0.5: q_good += 1\n", " if not _verify_sp_answer(solutions, topic, diff): skipped+=1; continue\n", " else:\n", " k = _qkey(question)\n", " _q_attempts[k] += len(solutions)\n", " _q_wins[k] += sum(1 for r in rewards if r > float(np.median(rewards)))\n", "\n", " # Zero-variance guard\n", " if np.std(rewards) < 0.02:\n", " skipped+=1; skip0var+=1\n", " pbar.set_postfix(mean_r=f\"{np.mean(rewards):.3f}\",skip=skipped,loss=\"0var\"); continue\n", "\n", " # GRPO loss\n", " g_loss = grpo_loss_for_group(model, ids_list, masks_list, rewards, lps_list,\n", " args.clip_eps, args.kl_coef, ref_model)\n", " if g_loss is None:\n", " skipped+=1; skip0var+=1\n", " pbar.set_postfix(mean_r=f\"{np.mean(rewards):.3f}\",skip=skipped,loss=\"skip\"); continue\n", "\n", " g_loss.backward()\n", " total_loss += g_loss.item(); n_grps += 1\n", " pbar.set_postfix(mean_r=f\"{np.mean(rewards):.3f}\",\n", " loss=f\"{g_loss.item():.4f}\", skip=skipped)\n", "\n", " # ── Optimiser step ────────────────────────────────────────────────────────\n", " if n_grps > 0:\n", " if n_grps > 1:\n", " for p in model.parameters():\n", " if p.grad is not None: p.grad.div_(n_grps)\n", " torch.nn.utils.clip_grad_norm_(\n", " [p for p in model.parameters() if p.requires_grad], args.max_grad_norm)\n", " optimizer.step()\n", " loss_val = total_loss / n_grps\n", " else:\n", " loss_val = 0.\n", " scheduler.step()\n", "\n", " # ════════════════════════════════════════════════════════════════════════\n", " # STATE — collect iteration metrics + phase transitions\n", " # ════════════════════════════════════════════════════════════════════════\n", " _epi_state = env.state # episode_id + step_count for the last episode\n", " iter_time = time.perf_counter() - iter_start\n", " mean_r = float(np.mean(all_r)) if all_r else 0.\n", " std_r = float(np.std(all_r)) if all_r else 0.\n", " acc_r = float(np.mean([r>0.5 for r in all_r])) if all_r else 0.\n", " gr_acc = float(np.mean([r>0.5 for r in gr_r])) if gr_r else 0.\n", " step_a = float(np.mean(gr_step)) if gr_step else 0.\n", " lccp_a = float(np.mean(gr_lccp)) if gr_lccp else 0.\n", " mean_qr = float(np.mean(all_qr)) if all_qr else 0.\n", " gt_rate = float(sum(gr_gt)/len(gr_gt)) if gr_gt else 0.\n", " cur_lr = optimizer.param_groups[0][\"lr\"]\n", "\n", " # Phase transition logic\n", " if _phase == _Phase.GROUNDED_ONLY:\n", " if (gt_rate >= args.selfplay_gt_thresh\n", " and gr_acc >= args.selfplay_grounded_thresh\n", " and step_a >= args.selfplay_step_thresh\n", " and iteration >= args.min_warmup):\n", " _phase = _Phase.SELFPLAY_RAMP\n", " logger.info(\"PHASE → SELFPLAY_RAMP at iter %d (gt=%.2f acc=%.2f step=%.2f)\",\n", " iteration, gt_rate, gr_acc, step_a)\n", " elif _phase in (_Phase.SELFPLAY_RAMP, _Phase.CONTINUOUS):\n", " _sp_iters += 1\n", " if _phase == _Phase.SELFPLAY_RAMP and _sp_iters >= args.selfplay_ramp_iters:\n", " _phase = _Phase.CONTINUOUS\n", " logger.info(\"PHASE → CONTINUOUS at iter %d\", iteration)\n", " # Chain scoring calibration\n", " 
if len(_chain_buf) > _CMAX:\n", " _chain_buf[:] = _chain_buf[-_CMAX:]\n", " _prm_buf[:] = _prm_buf[-_CMAX:]\n", " _succ_buf[:] = _succ_buf[-_CMAX:]\n", " if not _use_chain and len(_chain_buf) >= _CWIN:\n", " try:\n", " from scipy.stats import pearsonr\n", " _r2, _ = pearsonr(_chain_buf[-_CWIN:], _prm_buf[-_CWIN:])\n", " _chain_corr = float(_r2)\n", " except Exception: _chain_corr = 0.\n", " _n = len(_succ_buf[-_CWIN:])\n", " _extract_rate = sum(_succ_buf[-_CWIN:])/_n if _n else 0.\n", " if _chain_corr >= 0.70 and _extract_rate >= 0.80:\n", " _use_chain = True; math_env.use_chain_scoring = True\n", " logger.info(\"CHAIN PRIMARY activated iter %d: corr=%.2f rate=%.2f\",\n", " iteration, _chain_corr, _extract_rate)\n", " _prev_susp = _sp_suspended\n", " _sp_suspended = bool(gr_gt) and gt_rate < args.grounded_floor\n", " if _sp_suspended and not _prev_susp:\n", " logger.warning(\"GROUNDED FLOOR: self-play suspended (gt=%.2f)\", gt_rate)\n", " elif not _sp_suspended and _prev_susp:\n", " logger.info(\"GROUNDED FLOOR: self-play resumed (gt=%.2f)\", gt_rate)\n", "\n", " # ── Logging ───────────────────────────────────────────────────────────────\n", " logger.info(\n", " \"Iter %d | loss=%.4f | r=%.3f±%.3f | gt=%.1f%% | gr_acc=%.1f%% | \"\n", " \"step=%.1f%% | lccp=%.1f%% | phase=%s sp=%.0f%% | \"\n", " \"grps=%d skip=%d | lr=%.2e | %.1fs\",\n", " iteration, loss_val, mean_r, std_r,\n", " 100*gt_rate, 100*gr_acc, 100*step_a, 100*lccp_a,\n", " _phase.name, 100*_eff_sp, n_grps, skipped, cur_lr, iter_time)\n", " if (n_grps+skipped) > 0 and skip0var/(n_grps+skipped) > 0.30:\n", " logger.warning(\"STARVATION: %.0f%% zero-var groups — curriculum %s\",\n", " 100*skip0var/(n_grps+skipped),\n", " \"too easy\" if gr_acc>0.75 else \"too hard\")\n", "\n", " # ── Evaluation (every eval_every iterations) ───────────────────────────────\n", " iter_metrics: Dict[str, Any] = {\n", " \"iteration\": iteration, \"loss\": loss_val, \"mean_reward\": mean_r,\n", " \"std_reward\": std_r, \"batch_accuracy\": acc_r, \"grounded_accuracy\": gr_acc,\n", " \"gt_match_rate\": round(gt_rate,4), \"step_accuracy\": step_a, \"lccp\": lccp_a,\n", " \"n_groups\": n_grps, \"skipped_groups\": skipped, \"learning_rate\": cur_lr,\n", " \"iter_time_s\": iter_time, \"training_phase\": _phase.name,\n", " \"effective_sp_ratio\": round(_eff_sp,3), \"selfplay_suspended\": int(_sp_suspended),\n", " \"chain_prm_corr\": round(_chain_corr,3), \"chain_scoring_active\": int(_use_chain),\n", " \"n_sp_groups\": n_sp, \"mean_q_reward\": round(mean_qr,4),\n", " \"q_gen_valid_rate\": round(q_val/q_att if q_att>0 else 0,4),\n", " \"episode_id\": _epi_state.episode_id, # from env.state\n", " \"episode_steps\": _epi_state.step_count, # from env.state\n", " }\n", "\n", " if iteration % args.eval_every == 0:\n", " logger.info(\"Evaluating (%d samples) ...\", args.eval_max_samples)\n", " eval_res = evaluate_policy(model, tokenizer, args.eval_data_path,\n", " args.eval_max_samples, args.eval_max_new_tokens,\n", " math_env=math_env, pass_at_k=args.eval_pass_at_k)\n", " cur_comb = float(eval_res.get(\"combined_score\", best_combined))\n", " logger.info(\"Eval combined=%.4f correct=%.1f%% best=%.4f\",\n", " cur_comb, 100*float(eval_res.get(\"correct_rate\",0.)), best_combined)\n", " if cur_comb > best_combined + 1e-4:\n", " best_combined = cur_comb\n", " best_prm_mean = max(best_prm_mean, float(eval_res.get(\"prm_mean\",0.)))\n", " model.save_pretrained(str(out_dir/\"best_policy\"))\n", " tokenizer.save_pretrained(str(out_dir/\"best_policy\"))\n", " 
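# best_policy holds the full merged weights plus the tokenizer, so it\n",
"            # reloads standalone via AutoModelForCausalLM.from_pretrained\n",
"            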
logger.info(\"New best → %s\", out_dir/\"best_policy\")\n", " iter_metrics.update(eval_res)\n", "\n", " # ── Checkpoint ────────────────────────────────────────────────────────────\n", " if iteration == args.num_iterations or (args.save_every>0 and iteration%args.save_every==0):\n", " ck = out_dir / f\"iter_{iteration:04d}\"\n", " ck.mkdir(exist_ok=True)\n", " model.save_pretrained(str(ck)); tokenizer.save_pretrained(str(ck))\n", " if args.keep_last and args.keep_last > 0:\n", " old = sorted(p for p in out_dir.iterdir() if p.is_dir() and p.name.startswith(\"iter_\"))\n", " for o in old[:-args.keep_last]:\n", " try: shutil.rmtree(o); logger.info(\"Pruned: %s\", o.name)\n", " except OSError as e: logger.warning(\"Could not prune %s: %s\", o.name, e)\n", "\n", " # ── Write metrics ─────────────────────────────────────────────────────────\n", " metrics_log.append(iter_metrics)\n", " (out_dir/\"metrics.jsonl\").write_text(\n", " \"\\n\".join(json.dumps(m) for m in metrics_log), encoding=\"utf-8\")\n", " _append_metrics_csv({\n", " \"iteration\": iter_metrics[\"iteration\"],\n", " \"timestamp\": datetime.now().isoformat(timespec=\"seconds\"),\n", " \"loss\": iter_metrics.get(\"loss\",0.),\n", " \"mean_reward\": iter_metrics.get(\"mean_reward\",0.),\n", " \"batch_acc\": iter_metrics.get(\"batch_accuracy\",0.),\n", " \"grounded_acc\": iter_metrics.get(\"grounded_accuracy\",0.),\n", " \"gt_match\": iter_metrics.get(\"gt_match_rate\",0.),\n", " \"step_acc\": iter_metrics.get(\"step_accuracy\",0.),\n", " \"lccp\": iter_metrics.get(\"lccp\",0.),\n", " \"n_groups\": iter_metrics.get(\"n_groups\",0),\n", " \"skipped\": iter_metrics.get(\"skipped_groups\",0),\n", " \"sp_ratio\": iter_metrics.get(\"effective_sp_ratio\",0.),\n", " \"phase\": iter_metrics.get(\"training_phase\",\"\"),\n", " \"lr\": iter_metrics.get(\"learning_rate\",0.),\n", " \"iter_s\": iter_metrics.get(\"iter_time_s\",0.),\n", " \"eval_combined\":iter_metrics.get(\"combined_score\",\"\") if \"combined_score\" in iter_metrics else \"\",\n", " \"eval_correct\": iter_metrics.get(\"correct_rate\",\"\") if \"combined_score\" in iter_metrics else \"\",\n", " \"eval_prm\": iter_metrics.get(\"prm_mean\",\"\") if \"combined_score\" in iter_metrics else \"\",\n", " }, step=iter_metrics[\"iteration\"])\n" ] }, { "cell_type": "code", "execution_count": null, "id": "cell-close", "metadata": {}, "outputs": [], "source": [ "# ════════════════════════════════════════════════════════════════════════════\n", "# CLOSE — persist curriculum state and finalise run\n", "# ════════════════════════════════════════════════════════════════════════════\n", "env.close() # saves CurriculumManager state to checkpoints/curriculum/\n", "_teardown() # restore stdout/stderr, flush CSV and log files\n", "\n", "logger.info(\"=\" * 70)\n", "logger.info(\"GRPO training complete.\")\n", "logger.info(\"Best combined score : %.4f\", best_combined)\n", "logger.info(\"Best PRM mean : %.3f\", best_prm_mean)\n", "logger.info(\"Checkpoints : %s\", out_dir)\n", "logger.info(\"Logs : %s\", log_dir)\n", "logger.info(\"=\" * 70)\n", "\n", "summary = {\n", " \"run_name\": run_name,\n", " \"best_combined\": best_combined,\n", " \"best_prm_mean\": best_prm_mean,\n", " \"total_iters\": args.num_iterations,\n", " \"checkpoints\": str(out_dir),\n", " \"log_dir\": str(log_dir),\n", " \"metrics_csv\": str(_csv_logger.metrics_file),\n", " \"metrics_json\": str(_csv_logger.log_path / \"detailed_metrics\"),\n", "}\n", "_csv_logger.save_summary(summary)\n", "logger.info(\"Summary → %s\", 
_csv_logger.log_path / \"summary.json\")\n", "\n", "# Auto-generate training plots if matplotlib is available\n", "_jsonl = out_dir / \"metrics.jsonl\"\n", "if _jsonl.exists():\n", " try:\n", " from scripts.plot_grpo_run import generate_plots\n", " _pdir = generate_plots(_jsonl)\n", " logger.info(\"Plots → %s\", _pdir)\n", " except Exception as _pe:\n", " logger.warning(\"Plot generation skipped (%s). Run manually: \"\n", " \"python scripts/plot_grpo_run.py %s\", _pe, _jsonl)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 5 }