""" Curriculum-aware math environment with dual reward signals. This file is deliberately minimal: a single ``collect_rollouts`` method is all the training loop needs. Rollouts and PPO updates run in the same process on a single GPU — no subprocesses, no RPC, no vLLM colocation. """ from __future__ import annotations import logging import random import re from dataclasses import asdict, dataclass from typing import Any, Dict, List, Optional, Tuple import torch from sympy import simplify from sympy.parsing.sympy_parser import parse_expr from tqdm.auto import tqdm from transformers import AutoModelForCausalLM, AutoTokenizer from src.config.prompts import create_generator_messages, create_solver_messages from src.rl.curriculum_manager import CurriculumManager from src.rl.expert_panel import SimulatedExpertPanel from src.rl.mdp_components import Action, State, Trajectory, Transition from src.rl.prm_scorer import ProcessRewardScorer from src.rl.quality_filter import QualityFilter from src.rl.question_quality_evaluator import QuestionQualityEvaluator from src.rl.replay_buffer import GenerationalReplayBuffer from src.rl.value_network import ValueHead from src.sft.solution_format import extract_final_answer_numeric_str from src.sft.sympy_normalize import normalize_for_parse_expr logger = logging.getLogger(__name__) @dataclass class TrajectoryMetadata: curriculum_iteration: int target_topic: str target_difficulty: float instruction: str generated_question: str generated_solution: str question_length: int solution_length: int detected_topic: str detected_secondary_topics: List[str] topic_match_score: float estimated_difficulty: float clarity_score: float novelty_scores: Dict[str, float] consensus_achieved: bool consensus_strength: float answer_diversity: int majority_answer: Optional[float] primary_matches_majority: bool sympy_verified: bool steps_total: int steps_verified_ok: int steps_failed: int final_answer_ok: bool question_reward: float solution_reward: float pre_expert_reward: float expert_reward_modifier: float expert_phase: str expert_feedback: str replay_candidate: bool replay_novelty: float replay_added: bool combined_reward: float reward_breakdown: Dict[str, object] topics_in_sweet_spot: List[str] current_focus_topics: List[str] curriculum_state_snapshot: Dict[str, object] class CurriculumMathEnvironment: """Standalone curriculum environment with PRM-based rewards and GRPO training support.""" def __init__( self, policy_model: AutoModelForCausalLM, value_model: Optional[ValueHead], tokenizer: AutoTokenizer, reference_questions: Optional[List[str]] = None, grounded_qa_pairs: Optional[List[Dict[str, str]]] = None, prm_scorer: Optional[ProcessRewardScorer] = None, curriculum_checkpoint_dir: str = "checkpoints/curriculum", max_question_tokens: int = 200, max_solution_tokens: int = 500, temperature: float = 0.7, top_p: float = 0.9, consensus_temperature: float = 0.7, device: Optional[torch.device] = None, unified_accuracy_calc: Optional[Any] = None, ): # ── Core model attributes (used by generation helpers) ─────────── self.policy = policy_model self.value = value_model self.tokenizer = tokenizer self.max_question_tokens = max_question_tokens self.max_solution_tokens = max_solution_tokens self.temperature = temperature self.top_p = top_p if device is not None: self.device = torch.device(device) else: try: self.device = next(policy_model.parameters()).device except StopIteration: self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.reference_questions = reference_questions or [] self.grounded_qa_pairs: List[Dict[str, str]] = [ qa for qa in (grounded_qa_pairs or []) if qa.get("question") and qa.get("gold_final") ] self.consensus_temperature = consensus_temperature self.curriculum_manager = CurriculumManager(checkpoint_dir=curriculum_checkpoint_dir) self.curriculum_manager.initialize(bootstrap_questions=self.reference_questions) self.curriculum_manager.load_checkpoint_safe() self.question_evaluator = QuestionQualityEvaluator( reference_questions=self.reference_questions ) # PRM is the sole process-quality signal. Passing prm_scorer=None # will cause compute_reward/compute_grounded_reward to raise at # call time — GRPO training always supplies the PRM. self.prm_scorer = prm_scorer # Unified accuracy calculator — activated on Phase 2+ transition. # When use_chain_scoring is True, chain_integrity_score from this # calculator replaces PRM-based process_score in both grounded and # self-play reward paths. self.unified_accuracy_calc: Optional[Any] = unified_accuracy_calc self.use_chain_scoring: bool = False self.expert_panel = SimulatedExpertPanel() self.replay_buffer = GenerationalReplayBuffer(max_size=500) self.quality_filter = QualityFilter(novelty_threshold=0.5) self.last_replay_ratio: float = 0.0 self.last_rollout_mix: Dict[str, int] = { "fresh": 0, "replay": 0, "grounded": 0, } # Running counts for the most recent grounded batch, so the training # script can log grounded accuracy per iteration without re-parsing # trajectory metadata. self.last_grounded_stats: Dict[str, float] = { "count": 0, "correct": 0, "accuracy": 0.0, "mean_reward": 0.0, } def sample_instruction(self) -> Tuple[str, str, float]: topic, difficulty = self.curriculum_manager.select_topic_and_difficulty() instruction = self.curriculum_manager.generate_instruction( topic=topic, target_difficulty=difficulty ) return instruction, topic, difficulty def format_solution_prompt(self, question: str) -> str: """Format a question into a chat-templated solver prompt.""" messages = create_solver_messages(question) return self.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) def format_question_generation_prompt(self, instruction: str) -> str: """Format a curriculum instruction into a chat-templated generator prompt.""" messages = create_generator_messages(instruction) return self.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) def generate_with_logging( self, initial_prompt: str, max_tokens: int, phase: str, ) -> Tuple[str, List[Transition]]: """ Generate text with per-step PPO-grade transition logging. Used by the PPO-compatible rollout methods (``collect_rollouts``, ``rollout_trajectory``, ``rollout_grounded_trajectory``). The GRPO training loop uses ``generate_solutions_batched`` instead. """ import torch.nn.functional as F # local import to keep top-level clean prompt_ids = self.tokenizer.encode( initial_prompt, return_tensors="pt" ).to(self.device) prompt_length = prompt_ids.shape[1] prompt_attn = torch.ones_like(prompt_ids) temperature = float(self.temperature) do_sample = temperature > 1e-4 eos_id = self.tokenizer.eos_token_id pad_id = self.tokenizer.pad_token_id or eos_id gen_kwargs: Dict[str, Any] = dict( input_ids=prompt_ids, attention_mask=prompt_attn, max_new_tokens=max_tokens, do_sample=do_sample, use_cache=True, output_logits=True, return_dict_in_generate=True, pad_token_id=pad_id, eos_token_id=eos_id, ) if do_sample: gen_kwargs["temperature"] = max(temperature, 1e-6) gen_kwargs["top_p"] = float(self.top_p) with torch.no_grad(): gen_out = self.policy.generate(**gen_kwargs) full_ids = gen_out.sequences # [1, P + T] T_gen = int(full_ids.shape[1] - prompt_length) if T_gen <= 0: return "", [] raw_logits = torch.stack([lg[0] for lg in gen_out.logits], dim=0).float() raw_log_probs = F.log_softmax(raw_logits, dim=-1) sampled_tokens = full_ids[0, prompt_length:] chosen_log_probs = raw_log_probs.gather( 1, sampled_tokens.unsqueeze(1) ).squeeze(1) entropies = -(raw_log_probs.exp() * raw_log_probs).sum(dim=-1) positions = torch.arange( prompt_length - 1, prompt_length + T_gen - 1, device=self.device ) full_attn = torch.ones_like(full_ids) if self.value is not None: values = self.value.values_at_positions( input_ids=full_ids, positions=positions, attention_mask=full_attn ) else: values = torch.zeros(T_gen, device=self.device) piece_by_piece: List[str] = self.tokenizer.batch_decode( [[tok.item()] for tok in sampled_tokens], skip_special_tokens=False ) transitions: List[Transition] = [] running_text = initial_prompt for t in range(T_gen): state_input_ids = full_ids[0, : prompt_length + t] current_state = State( text=running_text, input_ids=state_input_ids, attention_mask=torch.ones_like(state_input_ids), phase=phase, ) action_token = int(sampled_tokens[t].item()) action = Action( token_id=action_token, log_prob=float(chosen_log_probs[t].item()), entropy=float(entropies[t].item()), ) next_text = running_text + piece_by_piece[t] next_input_ids = full_ids[0, : prompt_length + t + 1] next_state = State( text=next_text, input_ids=next_input_ids, attention_mask=torch.ones_like(next_input_ids), phase=phase, ) is_done = eos_id is not None and action_token == eos_id transitions.append( Transition( state=current_state, action=action, reward=0.0, next_state=next_state, value=float(values[t].item()), done=is_done, ) ) running_text = next_text if is_done: break generated_ids = full_ids[0, prompt_length : prompt_length + len(transitions)] generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip() return generated_text, transitions def _compute_format_score(self, solution: str) -> float: """ Structural format score based purely on text patterns — no SymPy. Checks: - Presence of 'Step N:' lines (multi-step structure) - Presence of 'Final Answer:' line (correct termination) - Length: ≥2 step lines scores highest Returns a score in [0, 1]. """ lines = solution.splitlines() step_lines = [l for l in lines if re.match(r"^\s*Step\s+\d+\s*:", l)] has_final = any(re.match(r"^\s*Final Answer\s*:", l, re.IGNORECASE) for l in lines) n_steps = len(step_lines) if n_steps >= 2: length_bonus = 1.0 elif n_steps == 1: length_bonus = 0.5 else: length_bonus = 0.0 final_ok = 1.0 if has_final else 0.0 # 0.7 × step-structure + 0.3 × final-answer presence return max(0.0, min(1.0, 0.7 * length_bonus + 0.3 * final_ok)) def compute_reward( self, question: str, solution: str, target_topic: str, target_difficulty: float, ) -> Dict[str, object]: # With a PRM scorer plugged in we skip the expensive (and noisy) # TripleVerifier consensus step. PRM gives per-step correctness # against the actual question semantics, which is strictly better # than "do 3 independent samples agree?" if self.prm_scorer is not None: return self._compute_reward_with_prm( question=question, solution=solution, target_topic=target_topic, target_difficulty=target_difficulty, ) raise RuntimeError( "compute_reward called without a PRM scorer. " "CurriculumMathEnvironment requires prm_scorer to be set. " "Pass prm_scorer=ProcessRewardScorer(...) at construction time." ) def _compute_reward_with_prm( self, question: str, solution: str, target_topic: str, target_difficulty: float, ) -> Dict[str, object]: """ Self-play reward using Qwen2.5-Math-PRM as the semantic-correctness signal. PRM gives per-step probabilities that each reasoning step is correct *given the question* — exactly the signal consensus voting was supposed to approximate but couldn't (three samples from the same policy agree on wrong answers). Solution reward (PRM path): R_sol = 0.45·prm_final + 0.35·prm_mean + 0.20·lccp R = 0.4·R_q + 0.6·R_sol (then expert-panel modifier) * ``prm_final`` (final step score) is the strongest predictor of overall answer correctness. * ``prm_mean`` provides a smooth gradient over all steps. * ``lccp`` (Longest Correct Consecutive Prefix) rewards chain integrity — consecutive correct steps before the first failure. * The 0.4/0.6 Q/Sol split boosts gradient to question-generation without starving the solution-correctness signal. """ assert self.prm_scorer is not None, "caller must check self.prm_scorer" prm_result = self.prm_scorer.score_solution( question=question, solution=solution ) format_score = self._compute_format_score(solution) prm_mean = float(prm_result.get("mean_score", 0.0)) prm_min = float(prm_result.get("min_score", 0.0)) prm_final = float(prm_result.get("final_score", 0.0)) prm_num_steps = int(prm_result.get("num_steps", 0)) prm_degraded = bool(prm_result.get("degraded", False)) # If the PRM degraded (empty solution, tokeniser mismatch, truncation), # the output is effectively unparseable. Prior behavior was to fall # back on SymPy+format, but the upstream ``base_combined_score`` also # blends in the question reward — so the policy got a positive signal # for producing a broken solution as long as the *question* looked # fine. We now treat a degraded PRM as a hard zero on the solution # reward; the question reward is gated below so the full combined # score also collapses. if prm_degraded or prm_num_steps == 0: solution_reward = 0.0 _sp_lccp = 0.0 sol_valid = False _sp_chain_integrity: Optional[float] = None logger.info( "PRM degraded (%s); sol_reward set to 0.0 (format=%.2f).", prm_result.get("degraded_reason", "unknown"), format_score, ) else: # LCCP for self-play: same chain-integrity measure as grounded path _sp_step_scores = prm_result.get("step_scores", []) or [] if _sp_step_scores: _first_fail = next( (i for i, s in enumerate(_sp_step_scores) if s <= 0.5), len(_sp_step_scores), ) _sp_lccp = _first_fail / len(_sp_step_scores) else: _sp_lccp = 0.0 # Self-play solution: PRM-only reward blending mean, final & chain integrity. # LCCP anchors the grade to *consecutive* correctness, not just bag-of-steps. solution_reward = ( 0.45 * prm_final + 0.35 * prm_mean + 0.20 * _sp_lccp ) # Phase 2+ chain scoring: replace PRM solution blend with unified # chain integrity + dependency consistency. This also populates the # question_score from the unified calculator so the Q/Sol weighting # below uses chain-verified signals instead of PRM proxies. _sp_chain_integrity = None if self.use_chain_scoring and self.unified_accuracy_calc is not None: try: _sp_report = self.unified_accuracy_calc.compute( solution=solution, gold_answer=None, question=question, topic=target_topic, phase="selfplay", ) solution_reward = _sp_report.composite_accuracy _sp_chain_integrity = _sp_report.chain_integrity_score except Exception as _sp_exc: logger.debug("Unified accuracy calc (self-play) failed: %s", _sp_exc) sol_valid = True solution_reward = max(0.0, min(1.0, solution_reward)) question_result = self.question_evaluator.evaluate( question=question, solution=solution, # Synthesize a "consensus-equivalent" dict so the question # evaluator keeps working unchanged. PRM mean score stands # in for consensus strength since both are correctness proxies. consensus_result={ "has_majority": prm_mean >= 0.5, "consensus_strength": prm_mean, "primary_matches_majority": prm_mean >= 0.5, "answer_diversity": 0, "majority_answer": None, "primary_answer": None, }, target_topic=target_topic, target_difficulty=target_difficulty, ) question_reward = float(question_result["overall_score"]) # Gate the question-quality bonus on having a parseable solution. # A great-looking question with a broken solution is not progress # toward self-improvement — it's the policy gaming whichever # signal is easier to produce. effective_question_reward = question_reward if sol_valid else 0.0 # Q/Sol = 0.4/0.6 — see note in compute_reward (non-PRM path). base_combined_score = ( 0.4 * effective_question_reward + 0.6 * solution_reward ) # Format floor: if the solution structure is broken (<0.5 format), # cap the overall reward at 0.3 regardless of how much the PRM # likes the prose. Previously we saw combined=0.83 with # Format=0.30, i.e. the PRM "approved" an output that didn't have # parseable Step/Final Answer lines — pure reward hacking. format_floor_active = format_score < 0.5 format_cap = 0.3 if format_floor_active else 1.0 base_combined_score = min(base_combined_score, format_cap) # Novelty gate: prevent template-copying reward hacking. # If the model just generates "John has X apples..." with different numbers, # n-gram similarity to the reference corpus is high → dataset_novelty is LOW. # We cap the reward to discourage this without penalising genuinely novel questions. # < 0.20: near-copy of a training question (template + new variables) → cap 0.35 # > 0.85: completely off-domain (not a real math problem style) → cap 0.55 # [0.20, 0.85]: Goldilocks zone → full reward (novelty_cap = 1.0) _dataset_novelty = float( question_result.get("novelty", {}).get("dataset_novelty", 0.5) if isinstance(question_result.get("novelty"), dict) else 0.5 ) if _dataset_novelty < 0.20: _novelty_cap = 0.35 elif _dataset_novelty > 0.85: _novelty_cap = 0.55 else: _novelty_cap = 1.0 if _novelty_cap < 1.0: base_combined_score = min(base_combined_score, _novelty_cap) logger.debug( "Novelty gate: dataset_novelty=%.2f → cap=%.2f (was %.3f → now %.3f)", _dataset_novelty, _novelty_cap, base_combined_score / _novelty_cap if _novelty_cap > 0 else 0, base_combined_score, ) expert_adjustment = self.expert_panel.apply_expert_preferences( base_reward=base_combined_score, question_metrics=question_result, solution_metrics={ # Only format_compliance still influences shaping — the # PRM/correctness signal lives inside ``solution_reward`` # already and must not be double-counted here. "format_compliance": format_score, }, iteration=self.curriculum_manager.current_iteration, ) combined_score = float(expert_adjustment["adjusted_reward"]) # Re-clip after additive shaping + respect the format cap one more # time so the shaping can't lift a badly-formatted solution back # above the cap. combined_score = max(0.0, min(format_cap, combined_score)) # Curriculum mastery: consider self-play solution "successful" when # both the chain mean AND the final concluding step are above threshold. # Using prm_final as a required condition prevents a solution that gets # most steps right but fails the conclusion from being marked "mastered". solution_success = ( (not prm_degraded) and (prm_mean >= 0.65) and (prm_final >= 0.50) ) self.curriculum_manager.update_from_trajectory( topic=target_topic, question_reward=question_reward, solution_success=solution_success, combined_reward=combined_score, measured_difficulty=float(question_result["measured_difficulty"]), ) modifier_val = float(expert_adjustment.get("reward_modifier", 0.0)) floor_tag = " FLOOR" if format_floor_active else "" valid_tag = "" if sol_valid else " [SOL_INVALID]" logger.info( "PRM reward%s: combined=%.3f = clip(base=%.3f + mod=%+.3f, cap=%.2f)%s " "| Q=%.2f sol=%.3f novelty=%.2f | " "sol=0.45*prm_final(%.2f)+0.35*prm_mean(%.2f)+0.20*lccp(%.2f) " "| steps=%d", valid_tag, combined_score, base_combined_score, modifier_val, format_cap, floor_tag, effective_question_reward, solution_reward, _dataset_novelty, prm_final, prm_mean, _sp_lccp if sol_valid else 0.0, prm_num_steps, ) # Shape a consensus-style verification_details dict so downstream # aggregation (which reads these keys) keeps working unchanged. verification_details = { "consensus": { "has_majority": prm_mean >= 0.5, "consensus_strength": prm_mean, "primary_matches_majority": prm_mean >= 0.5, "answer_diversity": 0, "majority_answer": None, "primary_answer": extract_final_answer_numeric_str(solution) or None, "prm_mean_score": prm_mean, "prm_min_score": prm_min, "prm_final_score": prm_final, "prm_step_scores": prm_result.get("step_scores", []), "prm_num_steps": prm_num_steps, "prm_degraded": prm_degraded, }, } return { "combined_score": combined_score, "base_combined_score": base_combined_score, "effective_question_reward": effective_question_reward, # gated (0 when sol invalid) "question_metrics": question_result, "solution_metrics": { "overall_score": solution_reward, "correctness": prm_mean, "format_compliance": format_score, "efficiency": prm_mean, # legacy slot "consensus_score": prm_mean, # legacy slot "prm_mean_score": prm_mean, "prm_min_score": prm_min, "prm_final_score": prm_final, "prm_step_scores": prm_result.get("step_scores", []), "prm_num_steps": prm_num_steps, "prm_degraded": prm_degraded, "verification_details": verification_details, }, "curriculum_metrics": { "target_topic": target_topic, "target_difficulty": target_difficulty, "detected_topic": question_result["detected_topic"], "measured_difficulty": question_result["measured_difficulty"], }, "expert_metrics": expert_adjustment, # Chain scoring metrics (Phase 2+; None when use_chain_scoring=False) "sp_chain_integrity_score": _sp_chain_integrity, } # ------------------------------------------------------------------ # Grounded (GSM8K-anchored) rollouts # ------------------------------------------------------------------ # # Why this exists: self-play rewards are dominated by consensus voting # between 3 same-model samples, which correlates poorly with GSM8K # accuracy (all three samples can be wrong in the same way). For the # grounded path we solve a known GSM8K problem and score the solution # directly against the gold final answer, which is the only signal # guaranteed to move the benchmark we actually evaluate on. # # The reward: R = 0.50·gt_match + 0.40·process(PRM) + 0.10·format # # * gt_match = 1.0 iff the model's Final Answer is mathematically # equivalent to the GSM8K gold final (via sympy.simplify on the # extracted numeric string). # * process = 0.60·prm_final + 0.40·prm_mean (PRM step-level quality) # * format rewards Step N: lines and a Final Answer: line. # # No TripleVerifier call on this path — ground truth obviates consensus. @staticmethod def _norm_expr_for_match(s: str) -> str: s = (s or "").strip() s = s.replace("^", "**") s = re.sub(r"[,$€£\s]+", "", s) return s @classmethod def _answers_equivalent(cls, pred: str, gold: str) -> bool: """Return True iff ``pred`` and ``gold`` parse to the same number.""" if not pred or not gold: return False p = cls._norm_expr_for_match(pred) g = cls._norm_expr_for_match(gold) if p == g: return True try: diff = simplify( parse_expr(normalize_for_parse_expr(p)) - parse_expr(normalize_for_parse_expr(g)) ) return bool(diff == 0) except Exception: return False def compute_grounded_reward( self, question: str, solution: str, gold_final: str, ) -> Dict[str, object]: """ Compute a ground-truth-anchored reward for a solution to a known GSM8K problem. No TripleVerifier call — the gold final answer replaces consensus voting as the semantic check. """ format_score = self._compute_format_score(solution) pred_final = extract_final_answer_numeric_str(solution) or "" gt_match_bool = self._answers_equivalent(pred_final, gold_final) if gt_match_bool: gt_match = 1.0 else: # Soft numeric proximity: reward near-misses rather than cliffing at 0. # Gives partial credit proportional to how close the numeric answer is. # Capped at 0.85 so an exact match (1.0) is always strictly better. # Non-numeric wrong answers still get 0.0. try: _p = float(pred_final.replace(",", "").strip()) _g = float(gold_final.replace(",", "").strip()) _denom = max(abs(_g), 1.0) gt_match = min(0.85, 1.0 / (1.0 + 2.0 * abs(_p - _g) / _denom)) except (ValueError, TypeError, AttributeError): gt_match = 0.0 # Optional PRM step-level quality on grounded rollouts. # prm_final (last step score) is the strongest single predictor of # answer correctness. step_accuracy = fraction of steps the PRM # considers correct — the direct measure of reasoning process quality. prm_mean = 0.0 prm_final = 0.0 prm_step_scores: List[float] = [] prm_num_steps = 0 prm_degraded = True if self.prm_scorer is not None: prm_result = self.prm_scorer.score_solution( question=question, solution=solution ) prm_degraded = bool(prm_result.get("degraded", False)) if not prm_degraded: prm_mean = float(prm_result.get("mean_score", 0.0)) prm_final = float(prm_result.get("final_score", 0.0)) prm_step_scores = list(prm_result.get("step_scores", [])) prm_num_steps = int(prm_result.get("num_steps", 0)) # Step accuracy: fraction of individual steps rated correct by PRM. step_accuracy = ( sum(1.0 for s in prm_step_scores if s > 0.5) / len(prm_step_scores) if prm_step_scores else 0.0 ) # Longest Correct Consecutive Prefix (LCCP): fraction of steps from # the start that are ALL rated correct before the first failure. # This captures chain integrity — a broken step 3 makes steps 4+ invalid # regardless of their individual PRM scores. # LCCP=1.0 means every step was correct (necessary condition for right answer). # LCCP=0.0 means step 1 itself was wrong (model never had a valid chain). if prm_step_scores: first_fail = next( (i for i, s in enumerate(prm_step_scores) if s <= 0.5), len(prm_step_scores) ) lccp = first_fail / len(prm_step_scores) else: lccp = 0.0 if self.prm_scorer is not None and not prm_degraded: # process_score: weight prm_final (conclusion step) more than mean # — the final step is the most critical and most predictive. process_score = 0.60 * prm_final + 0.40 * prm_mean combined = ( 0.50 * gt_match + 0.40 * process_score + 0.10 * format_score ) _gt_tag = "exact" if gt_match_bool else f"prox={gt_match:.2f}" components_str = ( f"0.50×{gt_match:.2f}({_gt_tag}) + 0.40×proc({process_score:.3f}" f"[fin={prm_final:.2f},mean={prm_mean:.2f}]) + " f"0.10×fmt({format_score:.3f})" ) else: combined = 0.85 * gt_match + 0.15 * format_score components_str = ( f"0.85×{gt_match:.2f} + 0.15×fmt({format_score:.3f})" ) # Phase 2+ chain scoring: override process_score, step_accuracy, lccp, # and combined with formally-verified chain integrity metrics. # PRM is still called above so its scores remain logged for comparison. _chain_report = None if self.use_chain_scoring and self.unified_accuracy_calc is not None: try: _chain_report = self.unified_accuracy_calc.compute( solution=solution, gold_answer=gold_final, topic="grounded", phase="grounded", ) process_score = _chain_report.chain_integrity_score step_accuracy = _chain_report.step_arithmetic_score lccp = _chain_report.lccp_score combined = max(0.0, min(1.0, 0.50 * gt_match + 0.30 * process_score + 0.20 * lccp )) components_str = ( f"0.50×{gt_match:.2f} + 0.30×chain({process_score:.3f}" f"[arith={_chain_report.step_arithmetic_score:.2f}," f"dep={_chain_report.step_dependency_score:.2f}]) + " f"0.20×lccp({lccp:.3f})" ) except Exception as _chain_exc: logger.debug("Unified accuracy calc failed, keeping PRM scores: %s", _chain_exc) else: combined = max(0.0, min(1.0, combined)) # Hard negative mining: wrong-answer solutions still get a partial signal # proportional to how far they got before the first error (LCCP). # This prevents gradient starvation on hard problems where no solution in # the group is fully correct — the model still learns "longer correct prefix # is better" rather than receiving zero reward for all K samples. if gt_match < 0.5 and lccp > 0.0 and self.prm_scorer is not None: # Bonus = 0.15 × LCCP, capped so that a wrong answer (combined ≈ 0.40) # can never exceed 0.55 — always well below a correct answer (≈ 0.90+). _hnm_bonus = 0.15 * lccp combined = min(combined + _hnm_bonus, 0.55) _chain_depth = first_fail if prm_step_scores else 0 logger.info( "Grounded reward: combined=%.3f = %s | pred=%r gold=%r | " "step_acc=%.0f%% lccp=%.0f%% (chain=%d/%d ok_count=%d) n_steps=%d", combined, components_str, pred_final, gold_final, 100 * step_accuracy, 100 * lccp, _chain_depth, len(prm_step_scores), sum(1 for s in prm_step_scores if s > 0.5), prm_num_steps, ) return { "combined_score": combined, "gt_match": gt_match_bool, # process metrics "step_accuracy": step_accuracy, "lccp": lccp, # longest correct consecutive prefix ratio "prm_mean_score": prm_mean, "prm_final_score": prm_final, "prm_step_scores": prm_step_scores, "prm_num_steps": prm_num_steps, "prm_degraded": prm_degraded, # format / answer "format_score": format_score, "pred_final": pred_final, "gold_final": gold_final, # chain scoring metrics (populated in Phase 2+, None otherwise) "chain_arith_score": _chain_report.step_arithmetic_score if _chain_report else None, "chain_dep_score": _chain_report.step_dependency_score if _chain_report else None, "chain_integrity_score": _chain_report.chain_integrity_score if _chain_report else None, "first_failure_step": _chain_report.first_failure_step if _chain_report else None, "final_consistent": _chain_report.final_answer_consistent if _chain_report else None, } def rollout_grounded_trajectory(self, qa_pair: Dict[str, str]) -> Trajectory: """ Run a rollout on a known GSM8K (question, gold_final) pair. The policy generates a solution to the real question; reward is dominated by whether the model's final number matches the gold final (ground-truth-anchored). """ question = str(qa_pair["question"]).strip() gold_final = str(qa_pair["gold_final"]).strip() solution_prompt = self.format_solution_prompt(question) generated_solution, solution_transitions = self.generate_with_logging( initial_prompt=solution_prompt, max_tokens=self.max_solution_tokens, phase="grounded_solution", ) reward_result = self.compute_grounded_reward( question=question, solution=generated_solution, gold_final=gold_final, ) terminal_reward = float(reward_result["combined_score"]) trajectory = Trajectory() for idx, transition in enumerate(solution_transitions): transition.reward = ( terminal_reward if idx == len(solution_transitions) - 1 else 0.0 ) trajectory.add(transition) metadata = { "rollout_source": "grounded", "curriculum_iteration": self.curriculum_manager.current_iteration, "target_topic": "grounded_gsm8k", "target_difficulty": 0.5, "instruction": "", "generated_question": question, "generated_solution": generated_solution, "question_length": 0, "solution_length": len(solution_transitions), "detected_topic": "grounded_gsm8k", "detected_secondary_topics": [], "topic_match_score": 1.0, "estimated_difficulty": 0.5, "clarity_score": 1.0, "novelty_scores": {"combined": 0.0}, "consensus_achieved": bool(reward_result["gt_match"]), "consensus_strength": 1.0 if reward_result["gt_match"] else 0.0, "answer_diversity": 0, "majority_answer": None, "primary_matches_majority": bool(reward_result["gt_match"]), "question_reward": 0.0, "solution_reward": terminal_reward, "pre_expert_reward": terminal_reward, "expert_reward_modifier": 0.0, "expert_phase": "grounded", "expert_feedback": "ground-truth anchored", "replay_candidate": False, "replay_novelty": 0.0, "replay_added": False, "combined_reward": terminal_reward, "reward_breakdown": { "grounded": True, "gt_match": bool(reward_result["gt_match"]), "format_score": float(reward_result["format_score"]), "pred_final": reward_result["pred_final"], "gold_final": reward_result["gold_final"], "prm_mean_score": float(reward_result.get("prm_mean_score", 0.0)), "prm_num_steps": int(reward_result.get("prm_num_steps", 0)), "prm_step_scores": list(reward_result.get("prm_step_scores", [])), "prm_degraded": bool(reward_result.get("prm_degraded", True)), }, "topics_in_sweet_spot": self.curriculum_manager.get_sweet_spot_topics(), "current_focus_topics": self.curriculum_manager.get_current_focus(), "curriculum_state_snapshot": self.curriculum_manager.get_curriculum_stats(), "grounded_gt_match": bool(reward_result["gt_match"]), "grounded_pred_final": reward_result["pred_final"], "grounded_gold_final": reward_result["gold_final"], } trajectory.metadata = metadata return trajectory def rollout_trajectory(self) -> Trajectory: instruction, target_topic, target_difficulty = self.sample_instruction() question_prompt = self.format_question_generation_prompt(instruction) generated_question, question_transitions = self.generate_with_logging( initial_prompt=question_prompt, max_tokens=self.max_question_tokens, phase="question_generation", ) return self._build_trajectory_from_question( instruction=instruction, target_topic=target_topic, target_difficulty=target_difficulty, generated_question=generated_question, question_transitions=question_transitions, ) def _build_trajectory_from_question( self, instruction: str, target_topic: str, target_difficulty: float, generated_question: str, question_transitions: Optional[List] = None, ) -> Trajectory: trajectory = Trajectory() question_transitions = question_transitions or [] solution_prompt = self.format_solution_prompt(generated_question) generated_solution, solution_transitions = self.generate_with_logging( initial_prompt=solution_prompt, max_tokens=self.max_solution_tokens, phase="solution", ) reward_result = self.compute_reward( question=generated_question, solution=generated_solution, target_topic=target_topic, target_difficulty=target_difficulty, ) terminal_reward = float(reward_result["combined_score"]) all_transitions = question_transitions + solution_transitions # Terminal-only reward — gae_lambda=1.0 makes A_t = R - V(s_t) for all t. for idx, transition in enumerate(all_transitions): transition.reward = ( terminal_reward if idx == len(all_transitions) - 1 else 0.0 ) trajectory.add(transition) verification = reward_result["solution_metrics"]["verification_details"] consensus = verification["consensus"] question_metrics = reward_result["question_metrics"] metadata = TrajectoryMetadata( curriculum_iteration=self.curriculum_manager.current_iteration, target_topic=target_topic, target_difficulty=target_difficulty, instruction=instruction, generated_question=generated_question, generated_solution=generated_solution, question_length=len(question_transitions), solution_length=len(solution_transitions), detected_topic=str(question_metrics["detected_topic"]["primary_topic"]), detected_secondary_topics=[ str(x) for x in question_metrics["detected_topic"]["secondary_topics"] ], topic_match_score=float(question_metrics["topic_match"]), estimated_difficulty=float(question_metrics["measured_difficulty"]), clarity_score=float(question_metrics["clarity"]), novelty_scores=dict(question_metrics["novelty"]), consensus_achieved=bool(consensus["has_majority"]), consensus_strength=float(consensus["consensus_strength"]), answer_diversity=int(consensus["answer_diversity"]), majority_answer=consensus.get("majority_answer"), primary_matches_majority=bool(consensus["primary_matches_majority"]), sympy_verified=True, steps_total=int(consensus.get("prm_num_steps", 0)), steps_verified_ok=int(consensus.get("prm_num_steps", 0)), steps_failed=0, final_answer_ok=bool(consensus.get("primary_matches_majority", False)), question_reward=float(question_metrics["overall_score"]), solution_reward=float(reward_result["solution_metrics"]["overall_score"]), pre_expert_reward=float(reward_result["base_combined_score"]), expert_reward_modifier=float( reward_result["expert_metrics"]["reward_modifier"] ), expert_phase=str(reward_result["expert_metrics"]["phase"]), expert_feedback=str(reward_result["expert_metrics"]["feedback"]), replay_candidate=False, replay_novelty=0.0, replay_added=False, combined_reward=terminal_reward, reward_breakdown=reward_result, topics_in_sweet_spot=self.curriculum_manager.get_sweet_spot_topics(), current_focus_topics=self.curriculum_manager.get_current_focus(), curriculum_state_snapshot=self.curriculum_manager.get_curriculum_stats(), ) metadata_dict = asdict(metadata) trajectory.metadata = metadata_dict # Replay admission: requires trajectory.metadata to already exist # because check_novelty reads metadata["generated_question"]. is_candidate, reason = self.quality_filter.meets_replay_criteria(metadata_dict) metadata_dict["replay_candidate"] = is_candidate if is_candidate: novelty_score = self.quality_filter.check_novelty( trajectory, self.replay_buffer.buffer ) metadata_dict["replay_novelty"] = float(novelty_score) if self.quality_filter.is_novel_enough(novelty_score): quality_score = self.quality_filter.compute_quality_score(metadata_dict) self.replay_buffer.add_trajectory( trajectory=trajectory, metadata=metadata_dict, iteration=self.curriculum_manager.current_iteration, quality_score=quality_score, ) metadata_dict["replay_added"] = True else: metadata_dict["replay_added"] = False else: metadata_dict["replay_added"] = False metadata_dict["replay_reject_reason"] = reason trajectory.metadata = metadata_dict return trajectory def _get_adaptive_replay_ratio(self) -> float: iteration = self.curriculum_manager.current_iteration if iteration < 3: return 0.0 if iteration < 5: return 0.15 buffer_stats = self.replay_buffer.get_buffer_stats(current_iteration=iteration) buffer_health = float(buffer_stats.get("buffer_health", 0.0)) if buffer_health >= 0.75: return 0.3 if buffer_health >= 0.6: return 0.25 return 0.2 def collect_rollouts( self, num_trajectories: int, verbose: bool = True, grounded_ratio: float = 0.0, ) -> List[Trajectory]: """ Generate ``num_trajectories`` episodes in-process on the current device. Mix: * ``grounded_ratio`` of rollouts are GSM8K-anchored (real question, reward scored against gold final answer). These give the policy a clean gradient toward benchmark correctness and are also ~3x faster than self-play rollouts (no TripleVerifier call). * an adaptive fraction is drawn from the replay buffer when buffer health is good (self-play only). * the remainder are fresh self-play rollouts. """ if num_trajectories <= 0: return [] # Defensive .eval() on both policy and value before any generation. # The first iteration runs rollouts right after model load (HF default # is .train()). Qwen2.5 has zero dropout so this is currently cosmetic, # but cheap insurance against any future model swap with stochastic layers. if self.policy is not None: self.policy.eval() if self.value is not None: self.value.eval() # Grounded rollouts: only if we actually have QA pairs loaded. if grounded_ratio > 0.0 and self.grounded_qa_pairs: num_grounded = int(round(num_trajectories * grounded_ratio)) num_grounded = min(num_grounded, num_trajectories) else: num_grounded = 0 num_selfplay = num_trajectories - num_grounded # Within the self-play half, the existing replay-buffer mix applies. replay_ratio = self._get_adaptive_replay_ratio() num_replay = int(num_selfplay * replay_ratio) num_replay = min(num_replay, len(self.replay_buffer)) num_fresh = max(0, num_selfplay - num_replay) # ---- Grounded rollouts (GSM8K-anchored) -------------------------- grounded_trajectories: List[Trajectory] = [] grounded_correct = 0 grounded_reward_sum = 0.0 if num_grounded > 0: qa_sample = random.sample( self.grounded_qa_pairs, k=min(num_grounded, len(self.grounded_qa_pairs)), ) # If we asked for more grounded rollouts than we have distinct # pairs, pad by re-sampling with replacement. while len(qa_sample) < num_grounded: qa_sample.append(random.choice(self.grounded_qa_pairs)) pbar = tqdm( qa_sample, desc="Grounded rollouts", unit="ep", dynamic_ncols=True, leave=False, disable=not verbose, ) for qa in pbar: trajectory = self.rollout_grounded_trajectory(qa) grounded_trajectories.append(trajectory) r = float(trajectory.metadata.get("combined_reward", 0.0)) grounded_reward_sum += r if bool(trajectory.metadata.get("grounded_gt_match", False)): grounded_correct += 1 done = len(grounded_trajectories) pbar.set_postfix( acc=f"{grounded_correct / done:.1%}", reward=f"{grounded_reward_sum / done:+.3f}", refresh=False, ) # ---- Fresh self-play rollouts ------------------------------------ fresh_trajectories: List[Trajectory] = [] pbar = tqdm( range(num_fresh), desc="Self-play rollouts", unit="ep", dynamic_ncols=True, leave=False, disable=not verbose, ) running_reward = 0.0 running_ok = 0 for _ in pbar: trajectory = self.rollout_trajectory() trajectory.metadata["rollout_source"] = "fresh" fresh_trajectories.append(trajectory) running_reward += float(trajectory.metadata.get("combined_reward", 0.0)) if trajectory.metadata.get("final_answer_ok", False): running_ok += 1 done = len(fresh_trajectories) pbar.set_postfix( reward=f"{running_reward / done:+.3f}", ok=f"{running_ok}/{done}", refresh=False, ) # ---- Replay buffer draws ----------------------------------------- replay_trajectories = self.replay_buffer.sample_replay_batch( num_replay, diversity_sample=True ) for trajectory in replay_trajectories: trajectory.metadata["rollout_source"] = "replay" trajectories = ( grounded_trajectories + fresh_trajectories + replay_trajectories ) random.shuffle(trajectories) self.last_replay_ratio = replay_ratio self.last_rollout_mix = { "fresh": len(fresh_trajectories), "replay": len(replay_trajectories), "grounded": len(grounded_trajectories), } grounded_count = len(grounded_trajectories) self.last_grounded_stats = { "count": grounded_count, "correct": grounded_correct, "accuracy": ( grounded_correct / grounded_count if grounded_count > 0 else 0.0 ), "mean_reward": ( grounded_reward_sum / grounded_count if grounded_count > 0 else 0.0 ), } if verbose: buffer_stats = self.replay_buffer.get_buffer_stats( current_iteration=self.curriculum_manager.current_iteration ) logger.info( "Rollout mix: %d grounded + %d fresh + %d replay " "(grounded_ratio=%.2f, replay_ratio=%.2f, buffer_size=%d, health=%.3f)", len(grounded_trajectories), len(fresh_trajectories), len(replay_trajectories), grounded_ratio, replay_ratio, len(self.replay_buffer), float(buffer_stats.get("buffer_health", 0.0)), ) if grounded_count > 0: logger.info( "Grounded accuracy this iter: %d/%d = %.1f%% (mean reward %.3f)", grounded_correct, grounded_count, 100.0 * grounded_correct / grounded_count, grounded_reward_sum / grounded_count, ) self.curriculum_manager.increment_iteration() self.curriculum_manager.save_state( iteration=self.curriculum_manager.current_iteration, rollout=None ) return trajectories