| """ | |
| Curriculum-aware math environment with dual reward signals. | |
| This file is deliberately minimal: a single ``collect_rollouts`` method is all | |
| the training loop needs. Rollouts and PPO updates run in the same process on | |
| a single GPU — no subprocesses, no RPC, no vLLM colocation. | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import random | |
| import re | |
| from dataclasses import asdict, dataclass | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import torch | |
| from sympy import simplify | |
| from sympy.parsing.sympy_parser import parse_expr | |
| from tqdm.auto import tqdm | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from src.config.prompts import create_generator_messages, create_solver_messages | |
| from src.rl.curriculum_manager import CurriculumManager | |
| from src.rl.expert_panel import SimulatedExpertPanel | |
| from src.rl.mdp_components import Action, State, Trajectory, Transition | |
| from src.rl.prm_scorer import ProcessRewardScorer | |
| from src.rl.quality_filter import QualityFilter | |
| from src.rl.question_quality_evaluator import QuestionQualityEvaluator | |
| from src.rl.replay_buffer import GenerationalReplayBuffer | |
| from src.rl.value_network import ValueHead | |
| from src.sft.solution_format import extract_final_answer_numeric_str | |
| from src.sft.sympy_normalize import normalize_for_parse_expr | |
| logger = logging.getLogger(__name__) | |


@dataclass
class TrajectoryMetadata:
    # ``asdict(metadata)`` in ``_build_trajectory_from_question`` requires this
    # to be a real dataclass, so the decorator above is load-bearing.
    curriculum_iteration: int
    target_topic: str
    target_difficulty: float
    instruction: str
    generated_question: str
    generated_solution: str
    question_length: int
    solution_length: int
    detected_topic: str
    detected_secondary_topics: List[str]
    topic_match_score: float
    estimated_difficulty: float
    clarity_score: float
    novelty_scores: Dict[str, float]
    consensus_achieved: bool
    consensus_strength: float
    answer_diversity: int
    majority_answer: Optional[float]
    primary_matches_majority: bool
    sympy_verified: bool
    steps_total: int
    steps_verified_ok: int
    steps_failed: int
    final_answer_ok: bool
    question_reward: float
    solution_reward: float
    pre_expert_reward: float
    expert_reward_modifier: float
    expert_phase: str
    expert_feedback: str
    replay_candidate: bool
    replay_novelty: float
    replay_added: bool
    combined_reward: float
    reward_breakdown: Dict[str, object]
    topics_in_sweet_spot: List[str]
    current_focus_topics: List[str]
    curriculum_state_snapshot: Dict[str, object]


class CurriculumMathEnvironment:
    """Standalone curriculum environment with PRM-based rewards and GRPO training support."""

    def __init__(
        self,
        policy_model: AutoModelForCausalLM,
        value_model: Optional[ValueHead],
        tokenizer: AutoTokenizer,
        reference_questions: Optional[List[str]] = None,
        grounded_qa_pairs: Optional[List[Dict[str, str]]] = None,
        prm_scorer: Optional[ProcessRewardScorer] = None,
        curriculum_checkpoint_dir: str = "checkpoints/curriculum",
        max_question_tokens: int = 200,
        max_solution_tokens: int = 500,
        temperature: float = 0.7,
        top_p: float = 0.9,
        consensus_temperature: float = 0.7,
        device: Optional[torch.device] = None,
        unified_accuracy_calc: Optional[Any] = None,
    ):
        # ── Core model attributes (used by generation helpers) ───────────
        self.policy = policy_model
        self.value = value_model
        self.tokenizer = tokenizer
        self.max_question_tokens = max_question_tokens
        self.max_solution_tokens = max_solution_tokens
        self.temperature = temperature
        self.top_p = top_p
        if device is not None:
            self.device = torch.device(device)
        else:
            try:
                self.device = next(policy_model.parameters()).device
            except StopIteration:
                self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.reference_questions = reference_questions or []
        self.grounded_qa_pairs: List[Dict[str, str]] = [
            qa for qa in (grounded_qa_pairs or [])
            if qa.get("question") and qa.get("gold_final")
        ]
        self.consensus_temperature = consensus_temperature
        self.curriculum_manager = CurriculumManager(checkpoint_dir=curriculum_checkpoint_dir)
        self.curriculum_manager.initialize(bootstrap_questions=self.reference_questions)
        self.curriculum_manager.load_checkpoint_safe()
        self.question_evaluator = QuestionQualityEvaluator(
            reference_questions=self.reference_questions
        )
        # PRM is the sole process-quality signal. Passing prm_scorer=None
        # will cause compute_reward/compute_grounded_reward to raise at
        # call time — GRPO training always supplies the PRM.
        self.prm_scorer = prm_scorer
        # Unified accuracy calculator — activated on Phase 2+ transition.
        # When use_chain_scoring is True, chain_integrity_score from this
        # calculator replaces PRM-based process_score in both grounded and
        # self-play reward paths.
        self.unified_accuracy_calc: Optional[Any] = unified_accuracy_calc
        self.use_chain_scoring: bool = False
        self.expert_panel = SimulatedExpertPanel()
        self.replay_buffer = GenerationalReplayBuffer(max_size=500)
        self.quality_filter = QualityFilter(novelty_threshold=0.5)
        self.last_replay_ratio: float = 0.0
        self.last_rollout_mix: Dict[str, int] = {
            "fresh": 0,
            "replay": 0,
            "grounded": 0,
        }
        # Running counts for the most recent grounded batch, so the training
        # script can log grounded accuracy per iteration without re-parsing
        # trajectory metadata.
        self.last_grounded_stats: Dict[str, float] = {
            "count": 0,
            "correct": 0,
            "accuracy": 0.0,
            "mean_reward": 0.0,
        }

    def sample_instruction(self) -> Tuple[str, str, float]:
        topic, difficulty = self.curriculum_manager.select_topic_and_difficulty()
        instruction = self.curriculum_manager.generate_instruction(
            topic=topic, target_difficulty=difficulty
        )
        return instruction, topic, difficulty

    def format_solution_prompt(self, question: str) -> str:
        """Format a question into a chat-templated solver prompt."""
        messages = create_solver_messages(question)
        return self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

    def format_question_generation_prompt(self, instruction: str) -> str:
        """Format a curriculum instruction into a chat-templated generator prompt."""
        messages = create_generator_messages(instruction)
        return self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

    def generate_with_logging(
        self,
        initial_prompt: str,
        max_tokens: int,
        phase: str,
    ) -> Tuple[str, List[Transition]]:
        """
        Generate text with per-step PPO-grade transition logging.

        Used by the PPO-compatible rollout methods (``collect_rollouts``,
        ``rollout_trajectory``, ``rollout_grounded_trajectory``). The GRPO
        training loop uses ``generate_solutions_batched`` instead.
        """
        import torch.nn.functional as F  # local import to keep top-level clean

        prompt_ids = self.tokenizer.encode(
            initial_prompt, return_tensors="pt"
        ).to(self.device)
        prompt_length = prompt_ids.shape[1]
        prompt_attn = torch.ones_like(prompt_ids)
        temperature = float(self.temperature)
        do_sample = temperature > 1e-4
        eos_id = self.tokenizer.eos_token_id
        pad_id = self.tokenizer.pad_token_id or eos_id
        gen_kwargs: Dict[str, Any] = dict(
            input_ids=prompt_ids,
            attention_mask=prompt_attn,
            max_new_tokens=max_tokens,
            do_sample=do_sample,
            use_cache=True,
            output_logits=True,
            return_dict_in_generate=True,
            pad_token_id=pad_id,
            eos_token_id=eos_id,
        )
        if do_sample:
            gen_kwargs["temperature"] = max(temperature, 1e-6)
            gen_kwargs["top_p"] = float(self.top_p)
        with torch.no_grad():
            gen_out = self.policy.generate(**gen_kwargs)
        full_ids = gen_out.sequences  # [1, P + T]
        T_gen = int(full_ids.shape[1] - prompt_length)
        if T_gen <= 0:
            return "", []
        raw_logits = torch.stack([lg[0] for lg in gen_out.logits], dim=0).float()
        raw_log_probs = F.log_softmax(raw_logits, dim=-1)
        sampled_tokens = full_ids[0, prompt_length:]
        chosen_log_probs = raw_log_probs.gather(
            1, sampled_tokens.unsqueeze(1)
        ).squeeze(1)
        entropies = -(raw_log_probs.exp() * raw_log_probs).sum(dim=-1)
        positions = torch.arange(
            prompt_length - 1, prompt_length + T_gen - 1, device=self.device
        )
        full_attn = torch.ones_like(full_ids)
        if self.value is not None:
            values = self.value.values_at_positions(
                input_ids=full_ids, positions=positions, attention_mask=full_attn
            )
        else:
            values = torch.zeros(T_gen, device=self.device)
        piece_by_piece: List[str] = self.tokenizer.batch_decode(
            [[tok.item()] for tok in sampled_tokens], skip_special_tokens=False
        )
        transitions: List[Transition] = []
        running_text = initial_prompt
        for t in range(T_gen):
            state_input_ids = full_ids[0, : prompt_length + t]
            current_state = State(
                text=running_text,
                input_ids=state_input_ids,
                attention_mask=torch.ones_like(state_input_ids),
                phase=phase,
            )
            action_token = int(sampled_tokens[t].item())
            action = Action(
                token_id=action_token,
                log_prob=float(chosen_log_probs[t].item()),
                entropy=float(entropies[t].item()),
            )
            next_text = running_text + piece_by_piece[t]
            next_input_ids = full_ids[0, : prompt_length + t + 1]
            next_state = State(
                text=next_text,
                input_ids=next_input_ids,
                attention_mask=torch.ones_like(next_input_ids),
                phase=phase,
            )
            is_done = eos_id is not None and action_token == eos_id
            transitions.append(
                Transition(
                    state=current_state,
                    action=action,
                    reward=0.0,
                    next_state=next_state,
                    value=float(values[t].item()),
                    done=is_done,
                )
            )
            running_text = next_text
            if is_done:
                break
        generated_ids = full_ids[0, prompt_length : prompt_length + len(transitions)]
        generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        return generated_text, transitions
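
    # Shape sketch for the bookkeeping above (illustrative, assuming a prompt of
    # P tokens and T generated tokens): ``gen_out.logits`` is a length-T tuple of
    # [1, vocab] tensors, so ``raw_log_probs`` is [T, vocab]; ``sampled_tokens``,
    # ``chosen_log_probs`` and ``entropies`` are all [T]; ``positions`` asks the
    # value head for V(s_t) at the token *preceding* each action, i.e. indices
    # P-1 .. P+T-2 of the full sequence.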

    def _compute_format_score(self, solution: str) -> float:
        """
        Structural format score based purely on text patterns — no SymPy.

        Checks:
          - Presence of 'Step N:' lines (multi-step structure)
          - Presence of 'Final Answer:' line (correct termination)
          - Length: ≥2 step lines scores highest

        Returns a score in [0, 1].
        """
        lines = solution.splitlines()
        step_lines = [l for l in lines if re.match(r"^\s*Step\s+\d+\s*:", l)]
        has_final = any(re.match(r"^\s*Final Answer\s*:", l, re.IGNORECASE) for l in lines)
        n_steps = len(step_lines)
        if n_steps >= 2:
            length_bonus = 1.0
        elif n_steps == 1:
            length_bonus = 0.5
        else:
            length_bonus = 0.0
        final_ok = 1.0 if has_final else 0.0
        # 0.7 × step-structure + 0.3 × final-answer presence
        return max(0.0, min(1.0, 0.7 * length_bonus + 0.3 * final_ok))
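
    # Worked example of the score above (illustrative numbers only): a solution
    # with "Step 1: ...", "Step 2: ..." and a "Final Answer: 42" line has
    # length_bonus=1.0 and final_ok=1.0 → 0.7*1.0 + 0.3*1.0 = 1.00; a single
    # step with no Final Answer line scores 0.7*0.5 + 0.3*0.0 = 0.35.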

    def compute_reward(
        self,
        question: str,
        solution: str,
        target_topic: str,
        target_difficulty: float,
    ) -> Dict[str, object]:
        # With a PRM scorer plugged in we skip the expensive (and noisy)
        # TripleVerifier consensus step. PRM gives per-step correctness
        # against the actual question semantics, which is strictly better
        # than "do 3 independent samples agree?"
        if self.prm_scorer is not None:
            return self._compute_reward_with_prm(
                question=question,
                solution=solution,
                target_topic=target_topic,
                target_difficulty=target_difficulty,
            )
        raise RuntimeError(
            "compute_reward called without a PRM scorer. "
            "CurriculumMathEnvironment requires prm_scorer to be set. "
            "Pass prm_scorer=ProcessRewardScorer(...) at construction time."
        )

    def _compute_reward_with_prm(
        self,
        question: str,
        solution: str,
        target_topic: str,
        target_difficulty: float,
    ) -> Dict[str, object]:
        """
        Self-play reward using Qwen2.5-Math-PRM as the semantic-correctness
        signal. PRM gives per-step probabilities that each reasoning step
        is correct *given the question* — exactly the signal consensus
        voting was supposed to approximate but couldn't (three samples
        from the same policy agree on wrong answers).

        Solution reward (PRM path):
            R_sol = 0.45·prm_final + 0.35·prm_mean + 0.20·lccp
            R     = 0.4·R_q + 0.6·R_sol   (then expert-panel modifier)

        * ``prm_final`` (final step score) is the strongest predictor of
          overall answer correctness.
        * ``prm_mean`` provides a smooth gradient over all steps.
        * ``lccp`` (Longest Correct Consecutive Prefix) rewards chain
          integrity — consecutive correct steps before the first failure.
        * The 0.4/0.6 Q/Sol split boosts gradient to question-generation
          without starving the solution-correctness signal.
        """
        assert self.prm_scorer is not None, "caller must check self.prm_scorer"
        prm_result = self.prm_scorer.score_solution(
            question=question, solution=solution
        )
        format_score = self._compute_format_score(solution)
        prm_mean = float(prm_result.get("mean_score", 0.0))
        prm_min = float(prm_result.get("min_score", 0.0))
        prm_final = float(prm_result.get("final_score", 0.0))
        prm_num_steps = int(prm_result.get("num_steps", 0))
        prm_degraded = bool(prm_result.get("degraded", False))
        # If the PRM degraded (empty solution, tokeniser mismatch, truncation),
        # the output is effectively unparseable. Prior behavior was to fall
        # back on SymPy+format, but the upstream ``base_combined_score`` also
        # blends in the question reward — so the policy got a positive signal
        # for producing a broken solution as long as the *question* looked
        # fine. We now treat a degraded PRM as a hard zero on the solution
        # reward; the question reward is gated below so the full combined
        # score also collapses.
        if prm_degraded or prm_num_steps == 0:
            solution_reward = 0.0
            _sp_lccp = 0.0
            sol_valid = False
            _sp_chain_integrity: Optional[float] = None
            logger.info(
                "PRM degraded (%s); sol_reward set to 0.0 (format=%.2f).",
                prm_result.get("degraded_reason", "unknown"),
                format_score,
            )
        else:
            # LCCP for self-play: same chain-integrity measure as grounded path
            _sp_step_scores = prm_result.get("step_scores", []) or []
            if _sp_step_scores:
                _first_fail = next(
                    (i for i, s in enumerate(_sp_step_scores) if s <= 0.5),
                    len(_sp_step_scores),
                )
                _sp_lccp = _first_fail / len(_sp_step_scores)
            else:
                _sp_lccp = 0.0
            # Self-play solution: PRM-only reward blending mean, final & chain integrity.
            # LCCP anchors the grade to *consecutive* correctness, not just bag-of-steps.
            solution_reward = (
                0.45 * prm_final
                + 0.35 * prm_mean
                + 0.20 * _sp_lccp
            )
            # Phase 2+ chain scoring: replace PRM solution blend with unified
            # chain integrity + dependency consistency. This also populates the
            # question_score from the unified calculator so the Q/Sol weighting
            # below uses chain-verified signals instead of PRM proxies.
            _sp_chain_integrity = None
            if self.use_chain_scoring and self.unified_accuracy_calc is not None:
                try:
                    _sp_report = self.unified_accuracy_calc.compute(
                        solution=solution,
                        gold_answer=None,
                        question=question,
                        topic=target_topic,
                        phase="selfplay",
                    )
                    solution_reward = _sp_report.composite_accuracy
                    _sp_chain_integrity = _sp_report.chain_integrity_score
                except Exception as _sp_exc:
                    logger.debug("Unified accuracy calc (self-play) failed: %s", _sp_exc)
            sol_valid = True
        solution_reward = max(0.0, min(1.0, solution_reward))
        question_result = self.question_evaluator.evaluate(
            question=question,
            solution=solution,
            # Synthesize a "consensus-equivalent" dict so the question
            # evaluator keeps working unchanged. PRM mean score stands
            # in for consensus strength since both are correctness proxies.
            consensus_result={
                "has_majority": prm_mean >= 0.5,
                "consensus_strength": prm_mean,
                "primary_matches_majority": prm_mean >= 0.5,
                "answer_diversity": 0,
                "majority_answer": None,
                "primary_answer": None,
            },
            target_topic=target_topic,
            target_difficulty=target_difficulty,
        )
        question_reward = float(question_result["overall_score"])
        # Gate the question-quality bonus on having a parseable solution.
        # A great-looking question with a broken solution is not progress
        # toward self-improvement — it's the policy gaming whichever
        # signal is easier to produce.
        effective_question_reward = question_reward if sol_valid else 0.0
        # Q/Sol = 0.4/0.6 — see the weighting note in this method's docstring.
        base_combined_score = (
            0.4 * effective_question_reward + 0.6 * solution_reward
        )
        # Format floor: if the solution structure is broken (<0.5 format),
        # cap the overall reward at 0.3 regardless of how much the PRM
        # likes the prose. Previously we saw combined=0.83 with
        # Format=0.30, i.e. the PRM "approved" an output that didn't have
        # parseable Step/Final Answer lines — pure reward hacking.
        format_floor_active = format_score < 0.5
        format_cap = 0.3 if format_floor_active else 1.0
        base_combined_score = min(base_combined_score, format_cap)
        # Novelty gate: prevent template-copying reward hacking.
        # If the model just generates "John has X apples..." with different numbers,
        # n-gram similarity to the reference corpus is high → dataset_novelty is LOW.
        # We cap the reward to discourage this without penalising genuinely novel questions.
        #   < 0.20:        near-copy of a training question (template + new variables) → cap 0.35
        #   > 0.85:        completely off-domain (not a real math problem style)       → cap 0.55
        #   [0.20, 0.85]:  Goldilocks zone → full reward (novelty_cap = 1.0)
        _dataset_novelty = float(
            question_result.get("novelty", {}).get("dataset_novelty", 0.5)
            if isinstance(question_result.get("novelty"), dict)
            else 0.5
        )
        if _dataset_novelty < 0.20:
            _novelty_cap = 0.35
        elif _dataset_novelty > 0.85:
            _novelty_cap = 0.55
        else:
            _novelty_cap = 1.0
        if _novelty_cap < 1.0:
            # Keep the pre-cap value so the debug line reports the actual
            # before/after pair (dividing by the cap does not recover it).
            _pre_cap_score = base_combined_score
            base_combined_score = min(base_combined_score, _novelty_cap)
            logger.debug(
                "Novelty gate: dataset_novelty=%.2f → cap=%.2f (was %.3f → now %.3f)",
                _dataset_novelty,
                _novelty_cap,
                _pre_cap_score,
                base_combined_score,
            )
        expert_adjustment = self.expert_panel.apply_expert_preferences(
            base_reward=base_combined_score,
            question_metrics=question_result,
            solution_metrics={
                # Only format_compliance still influences shaping — the
                # PRM/correctness signal lives inside ``solution_reward``
                # already and must not be double-counted here.
                "format_compliance": format_score,
            },
            iteration=self.curriculum_manager.current_iteration,
        )
        combined_score = float(expert_adjustment["adjusted_reward"])
        # Re-clip after additive shaping + respect the format cap one more
        # time so the shaping can't lift a badly-formatted solution back
        # above the cap.
        combined_score = max(0.0, min(format_cap, combined_score))
        # Curriculum mastery: consider self-play solution "successful" when
        # both the chain mean AND the final concluding step are above threshold.
        # Using prm_final as a required condition prevents a solution that gets
        # most steps right but fails the conclusion from being marked "mastered".
        solution_success = (
            (not prm_degraded)
            and (prm_mean >= 0.65)
            and (prm_final >= 0.50)
        )
        self.curriculum_manager.update_from_trajectory(
            topic=target_topic,
            question_reward=question_reward,
            solution_success=solution_success,
            combined_reward=combined_score,
            measured_difficulty=float(question_result["measured_difficulty"]),
        )
        modifier_val = float(expert_adjustment.get("reward_modifier", 0.0))
        floor_tag = " FLOOR" if format_floor_active else ""
        valid_tag = "" if sol_valid else " [SOL_INVALID]"
        logger.info(
            "PRM reward%s: combined=%.3f = clip(base=%.3f + mod=%+.3f, cap=%.2f)%s "
            "| Q=%.2f sol=%.3f novelty=%.2f | "
            "sol=0.45*prm_final(%.2f)+0.35*prm_mean(%.2f)+0.20*lccp(%.2f) "
            "| steps=%d",
            valid_tag,
            combined_score,
            base_combined_score,
            modifier_val,
            format_cap,
            floor_tag,
            effective_question_reward,
            solution_reward,
            _dataset_novelty,
            prm_final,
            prm_mean,
            _sp_lccp if sol_valid else 0.0,
            prm_num_steps,
        )
        # Shape a consensus-style verification_details dict so downstream
        # aggregation (which reads these keys) keeps working unchanged.
        verification_details = {
            "consensus": {
                "has_majority": prm_mean >= 0.5,
                "consensus_strength": prm_mean,
                "primary_matches_majority": prm_mean >= 0.5,
                "answer_diversity": 0,
                "majority_answer": None,
                "primary_answer": extract_final_answer_numeric_str(solution) or None,
                "prm_mean_score": prm_mean,
                "prm_min_score": prm_min,
                "prm_final_score": prm_final,
                "prm_step_scores": prm_result.get("step_scores", []),
                "prm_num_steps": prm_num_steps,
                "prm_degraded": prm_degraded,
            },
        }
        return {
            "combined_score": combined_score,
            "base_combined_score": base_combined_score,
            "effective_question_reward": effective_question_reward,  # gated (0 when sol invalid)
            "question_metrics": question_result,
            "solution_metrics": {
                "overall_score": solution_reward,
                "correctness": prm_mean,
                "format_compliance": format_score,
                "efficiency": prm_mean,  # legacy slot
                "consensus_score": prm_mean,  # legacy slot
                "prm_mean_score": prm_mean,
                "prm_min_score": prm_min,
                "prm_final_score": prm_final,
                "prm_step_scores": prm_result.get("step_scores", []),
                "prm_num_steps": prm_num_steps,
                "prm_degraded": prm_degraded,
                "verification_details": verification_details,
            },
            "curriculum_metrics": {
                "target_topic": target_topic,
                "target_difficulty": target_difficulty,
                "detected_topic": question_result["detected_topic"],
                "measured_difficulty": question_result["measured_difficulty"],
            },
            "expert_metrics": expert_adjustment,
            # Chain scoring metrics (Phase 2+; None when use_chain_scoring=False)
            "sp_chain_integrity_score": _sp_chain_integrity,
        }

    # ------------------------------------------------------------------
    # Grounded (GSM8K-anchored) rollouts
    # ------------------------------------------------------------------
    #
    # Why this exists: self-play rewards were originally dominated by consensus
    # voting between 3 same-model samples, which correlates poorly with GSM8K
    # accuracy (all three samples can be wrong in the same way). For the
    # grounded path we solve a known GSM8K problem and score the solution
    # directly against the gold final answer, which is the only signal
    # guaranteed to move the benchmark we actually evaluate on.
    #
    # The reward:  R = 0.50·gt_match + 0.40·process(PRM) + 0.10·format
    #
    # * gt_match = 1.0 iff the model's Final Answer is mathematically
    #   equivalent to the GSM8K gold final (via sympy.simplify on the
    #   extracted numeric string).
    # * process  = 0.60·prm_final + 0.40·prm_mean  (PRM step-level quality)
    # * format   rewards Step N: lines and a Final Answer: line.
    #
    # No TripleVerifier call on this path — ground truth obviates consensus.

    @staticmethod
    def _norm_expr_for_match(s: str) -> str:
        s = (s or "").strip()
        s = s.replace("^", "**")
        s = re.sub(r"[,$€£\s]+", "", s)
        return s

    @classmethod
    def _answers_equivalent(cls, pred: str, gold: str) -> bool:
        """Return True iff ``pred`` and ``gold`` parse to the same number."""
        if not pred or not gold:
            return False
        p = cls._norm_expr_for_match(pred)
        g = cls._norm_expr_for_match(gold)
        if p == g:
            return True
        try:
            diff = simplify(
                parse_expr(normalize_for_parse_expr(p))
                - parse_expr(normalize_for_parse_expr(g))
            )
            return bool(diff == 0)
        except Exception:
            return False
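
    # Illustrative behaviour of the equivalence check above, assuming
    # ``normalize_for_parse_expr`` leaves plain numerics unchanged:
    #   _answers_equivalent("1,000", "1000") → True   (comma stripped, string match)
    #   _answers_equivalent("3/2", "1.5")    → True   (sympy: 3/2 - 1.5 simplifies to 0)
    #   _answers_equivalent("7", "8")        → False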

    def compute_grounded_reward(
        self,
        question: str,
        solution: str,
        gold_final: str,
    ) -> Dict[str, object]:
        """
        Compute a ground-truth-anchored reward for a solution to a known
        GSM8K problem. No TripleVerifier call — the gold final answer
        replaces consensus voting as the semantic check.
        """
        format_score = self._compute_format_score(solution)
        pred_final = extract_final_answer_numeric_str(solution) or ""
        gt_match_bool = self._answers_equivalent(pred_final, gold_final)
        if gt_match_bool:
            gt_match = 1.0
        else:
            # Soft numeric proximity: reward near-misses rather than cliffing at 0.
            # Gives partial credit proportional to how close the numeric answer is.
            # Capped at 0.85 so an exact match (1.0) is always strictly better.
            # Non-numeric wrong answers still get 0.0.
            try:
                _p = float(pred_final.replace(",", "").strip())
                _g = float(gold_final.replace(",", "").strip())
                _denom = max(abs(_g), 1.0)
                gt_match = min(0.85, 1.0 / (1.0 + 2.0 * abs(_p - _g) / _denom))
            except (ValueError, TypeError, AttributeError):
                gt_match = 0.0
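        # Worked example of the proximity credit above (illustrative numbers):
        # pred=48, gold=50 → 1/(1 + 2*2/50) ≈ 0.926, capped to 0.85;
        # pred=30, gold=50 → 1/(1 + 2*20/50) ≈ 0.556; a non-numeric answer → 0.0.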
        # Optional PRM step-level quality on grounded rollouts.
        # prm_final (last step score) is the strongest single predictor of
        # answer correctness. step_accuracy = fraction of steps the PRM
        # considers correct — the direct measure of reasoning process quality.
        prm_mean = 0.0
        prm_final = 0.0
        prm_step_scores: List[float] = []
        prm_num_steps = 0
        prm_degraded = True
        if self.prm_scorer is not None:
            prm_result = self.prm_scorer.score_solution(
                question=question, solution=solution
            )
            prm_degraded = bool(prm_result.get("degraded", False))
            if not prm_degraded:
                prm_mean = float(prm_result.get("mean_score", 0.0))
                prm_final = float(prm_result.get("final_score", 0.0))
                prm_step_scores = list(prm_result.get("step_scores", []))
                prm_num_steps = int(prm_result.get("num_steps", 0))
        # Step accuracy: fraction of individual steps rated correct by PRM.
        step_accuracy = (
            sum(1.0 for s in prm_step_scores if s > 0.5) / len(prm_step_scores)
            if prm_step_scores else 0.0
        )
        # Longest Correct Consecutive Prefix (LCCP): fraction of steps from
        # the start that are ALL rated correct before the first failure.
        # This captures chain integrity — a broken step 3 makes steps 4+ invalid
        # regardless of their individual PRM scores.
        # LCCP=1.0 means every step was correct (necessary condition for right answer).
        # LCCP=0.0 means step 1 itself was wrong (model never had a valid chain).
        if prm_step_scores:
            first_fail = next(
                (i for i, s in enumerate(prm_step_scores) if s <= 0.5), len(prm_step_scores)
            )
            lccp = first_fail / len(prm_step_scores)
        else:
            lccp = 0.0
        if self.prm_scorer is not None and not prm_degraded:
            # process_score: weight prm_final (conclusion step) more than mean
            # — the final step is the most critical and most predictive.
            process_score = 0.60 * prm_final + 0.40 * prm_mean
            combined = (
                0.50 * gt_match
                + 0.40 * process_score
                + 0.10 * format_score
            )
            _gt_tag = "exact" if gt_match_bool else f"prox={gt_match:.2f}"
            components_str = (
                f"0.50×{gt_match:.2f}({_gt_tag}) + 0.40×proc({process_score:.3f}"
                f"[fin={prm_final:.2f},mean={prm_mean:.2f}]) + "
                f"0.10×fmt({format_score:.3f})"
            )
        else:
            combined = 0.85 * gt_match + 0.15 * format_score
            components_str = (
                f"0.85×{gt_match:.2f} + 0.15×fmt({format_score:.3f})"
            )
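        # Worked example of the grounded blend above (illustrative numbers only):
        # gt_match=1.0, prm_final=0.90, prm_mean=0.80, format=1.0
        #   process = 0.60*0.90 + 0.40*0.80 = 0.86
        #   R       = 0.50*1.0 + 0.40*0.86 + 0.10*1.0 = 0.944
        # With no usable PRM the fallback is R = 0.85*gt_match + 0.15*format.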
        # Phase 2+ chain scoring: override process_score, step_accuracy, lccp,
        # and combined with formally-verified chain integrity metrics.
        # PRM is still called above so its scores remain logged for comparison.
        _chain_report = None
        if self.use_chain_scoring and self.unified_accuracy_calc is not None:
            try:
                _chain_report = self.unified_accuracy_calc.compute(
                    solution=solution,
                    gold_answer=gold_final,
                    topic="grounded",
                    phase="grounded",
                )
                process_score = _chain_report.chain_integrity_score
                step_accuracy = _chain_report.step_arithmetic_score
                lccp = _chain_report.lccp_score
                combined = max(0.0, min(1.0,
                    0.50 * gt_match + 0.30 * process_score + 0.20 * lccp
                ))
                components_str = (
                    f"0.50×{gt_match:.2f} + 0.30×chain({process_score:.3f}"
                    f"[arith={_chain_report.step_arithmetic_score:.2f},"
                    f"dep={_chain_report.step_dependency_score:.2f}]) + "
                    f"0.20×lccp({lccp:.3f})"
                )
            except Exception as _chain_exc:
                logger.debug("Unified accuracy calc failed, keeping PRM scores: %s", _chain_exc)
        else:
            combined = max(0.0, min(1.0, combined))
        # Hard negative mining: wrong-answer solutions still get a partial signal
        # proportional to how far they got before the first error (LCCP).
        # This prevents gradient starvation on hard problems where no solution in
        # the group is fully correct — the model still learns "longer correct prefix
        # is better" rather than receiving zero reward for all K samples.
        if gt_match < 0.5 and lccp > 0.0 and self.prm_scorer is not None:
            # Bonus = 0.15 × LCCP, capped so that a wrong answer (combined ≈ 0.40)
            # can never exceed 0.55 — always well below a correct answer (≈ 0.90+).
            _hnm_bonus = 0.15 * lccp
            combined = min(combined + _hnm_bonus, 0.55)
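        # Worked example of the hard-negative bonus above (illustrative numbers):
        # a wrong answer with combined=0.40 and lccp=0.60 gets 0.40 + 0.15*0.60
        # = 0.49, and the min(…, 0.55) cap keeps it below any correct answer.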
        _chain_depth = first_fail if prm_step_scores else 0
        logger.info(
            "Grounded reward: combined=%.3f = %s | pred=%r gold=%r | "
            "step_acc=%.0f%% lccp=%.0f%% (chain=%d/%d ok_count=%d) n_steps=%d",
            combined,
            components_str,
            pred_final,
            gold_final,
            100 * step_accuracy,
            100 * lccp,
            _chain_depth,
            len(prm_step_scores),
            sum(1 for s in prm_step_scores if s > 0.5),
            prm_num_steps,
        )
        return {
            "combined_score": combined,
            "gt_match": gt_match_bool,
            # process metrics
            "step_accuracy": step_accuracy,
            "lccp": lccp,  # longest correct consecutive prefix ratio
            "prm_mean_score": prm_mean,
            "prm_final_score": prm_final,
            "prm_step_scores": prm_step_scores,
            "prm_num_steps": prm_num_steps,
            "prm_degraded": prm_degraded,
            # format / answer
            "format_score": format_score,
            "pred_final": pred_final,
            "gold_final": gold_final,
            # chain scoring metrics (populated in Phase 2+, None otherwise)
            "chain_arith_score": _chain_report.step_arithmetic_score if _chain_report else None,
            "chain_dep_score": _chain_report.step_dependency_score if _chain_report else None,
            "chain_integrity_score": _chain_report.chain_integrity_score if _chain_report else None,
            "first_failure_step": _chain_report.first_failure_step if _chain_report else None,
            "final_consistent": _chain_report.final_answer_consistent if _chain_report else None,
        }

    def rollout_grounded_trajectory(self, qa_pair: Dict[str, str]) -> Trajectory:
        """
        Run a rollout on a known GSM8K (question, gold_final) pair.

        The policy generates a solution to the real question; reward is
        dominated by whether the model's final number matches the gold
        final (ground-truth-anchored).
        """
        question = str(qa_pair["question"]).strip()
        gold_final = str(qa_pair["gold_final"]).strip()
        solution_prompt = self.format_solution_prompt(question)
        generated_solution, solution_transitions = self.generate_with_logging(
            initial_prompt=solution_prompt,
            max_tokens=self.max_solution_tokens,
            phase="grounded_solution",
        )
        reward_result = self.compute_grounded_reward(
            question=question,
            solution=generated_solution,
            gold_final=gold_final,
        )
        terminal_reward = float(reward_result["combined_score"])
        trajectory = Trajectory()
        for idx, transition in enumerate(solution_transitions):
            transition.reward = (
                terminal_reward if idx == len(solution_transitions) - 1 else 0.0
            )
            trajectory.add(transition)
        metadata = {
            "rollout_source": "grounded",
            "curriculum_iteration": self.curriculum_manager.current_iteration,
            "target_topic": "grounded_gsm8k",
            "target_difficulty": 0.5,
            "instruction": "",
            "generated_question": question,
            "generated_solution": generated_solution,
            "question_length": 0,
            "solution_length": len(solution_transitions),
            "detected_topic": "grounded_gsm8k",
            "detected_secondary_topics": [],
            "topic_match_score": 1.0,
            "estimated_difficulty": 0.5,
            "clarity_score": 1.0,
            "novelty_scores": {"combined": 0.0},
            "consensus_achieved": bool(reward_result["gt_match"]),
            "consensus_strength": 1.0 if reward_result["gt_match"] else 0.0,
            "answer_diversity": 0,
            "majority_answer": None,
            "primary_matches_majority": bool(reward_result["gt_match"]),
            "question_reward": 0.0,
            "solution_reward": terminal_reward,
            "pre_expert_reward": terminal_reward,
            "expert_reward_modifier": 0.0,
            "expert_phase": "grounded",
            "expert_feedback": "ground-truth anchored",
            "replay_candidate": False,
            "replay_novelty": 0.0,
            "replay_added": False,
            "combined_reward": terminal_reward,
            "reward_breakdown": {
                "grounded": True,
                "gt_match": bool(reward_result["gt_match"]),
                "format_score": float(reward_result["format_score"]),
                "pred_final": reward_result["pred_final"],
                "gold_final": reward_result["gold_final"],
                "prm_mean_score": float(reward_result.get("prm_mean_score", 0.0)),
                "prm_num_steps": int(reward_result.get("prm_num_steps", 0)),
                "prm_step_scores": list(reward_result.get("prm_step_scores", [])),
                "prm_degraded": bool(reward_result.get("prm_degraded", True)),
            },
            "topics_in_sweet_spot": self.curriculum_manager.get_sweet_spot_topics(),
            "current_focus_topics": self.curriculum_manager.get_current_focus(),
            "curriculum_state_snapshot": self.curriculum_manager.get_curriculum_stats(),
            "grounded_gt_match": bool(reward_result["gt_match"]),
            "grounded_pred_final": reward_result["pred_final"],
            "grounded_gold_final": reward_result["gold_final"],
        }
        trajectory.metadata = metadata
        return trajectory

    def rollout_trajectory(self) -> Trajectory:
        instruction, target_topic, target_difficulty = self.sample_instruction()
        question_prompt = self.format_question_generation_prompt(instruction)
        generated_question, question_transitions = self.generate_with_logging(
            initial_prompt=question_prompt,
            max_tokens=self.max_question_tokens,
            phase="question_generation",
        )
        return self._build_trajectory_from_question(
            instruction=instruction,
            target_topic=target_topic,
            target_difficulty=target_difficulty,
            generated_question=generated_question,
            question_transitions=question_transitions,
        )

    def _build_trajectory_from_question(
        self,
        instruction: str,
        target_topic: str,
        target_difficulty: float,
        generated_question: str,
        question_transitions: Optional[List] = None,
    ) -> Trajectory:
        trajectory = Trajectory()
        question_transitions = question_transitions or []
        solution_prompt = self.format_solution_prompt(generated_question)
        generated_solution, solution_transitions = self.generate_with_logging(
            initial_prompt=solution_prompt,
            max_tokens=self.max_solution_tokens,
            phase="solution",
        )
        reward_result = self.compute_reward(
            question=generated_question,
            solution=generated_solution,
            target_topic=target_topic,
            target_difficulty=target_difficulty,
        )
        terminal_reward = float(reward_result["combined_score"])
        all_transitions = question_transitions + solution_transitions
        # Terminal-only reward — gae_lambda=1.0 makes A_t = R - V(s_t) for all t.
        for idx, transition in enumerate(all_transitions):
            transition.reward = (
                terminal_reward if idx == len(all_transitions) - 1 else 0.0
            )
            trajectory.add(transition)
        verification = reward_result["solution_metrics"]["verification_details"]
        consensus = verification["consensus"]
        question_metrics = reward_result["question_metrics"]
        metadata = TrajectoryMetadata(
            curriculum_iteration=self.curriculum_manager.current_iteration,
            target_topic=target_topic,
            target_difficulty=target_difficulty,
            instruction=instruction,
            generated_question=generated_question,
            generated_solution=generated_solution,
            question_length=len(question_transitions),
            solution_length=len(solution_transitions),
            detected_topic=str(question_metrics["detected_topic"]["primary_topic"]),
            detected_secondary_topics=[
                str(x) for x in question_metrics["detected_topic"]["secondary_topics"]
            ],
            topic_match_score=float(question_metrics["topic_match"]),
            estimated_difficulty=float(question_metrics["measured_difficulty"]),
            clarity_score=float(question_metrics["clarity"]),
            novelty_scores=dict(question_metrics["novelty"]),
            consensus_achieved=bool(consensus["has_majority"]),
            consensus_strength=float(consensus["consensus_strength"]),
            answer_diversity=int(consensus["answer_diversity"]),
            majority_answer=consensus.get("majority_answer"),
            primary_matches_majority=bool(consensus["primary_matches_majority"]),
            sympy_verified=True,
            steps_total=int(consensus.get("prm_num_steps", 0)),
            steps_verified_ok=int(consensus.get("prm_num_steps", 0)),
            steps_failed=0,
            final_answer_ok=bool(consensus.get("primary_matches_majority", False)),
            question_reward=float(question_metrics["overall_score"]),
            solution_reward=float(reward_result["solution_metrics"]["overall_score"]),
            pre_expert_reward=float(reward_result["base_combined_score"]),
            expert_reward_modifier=float(
                reward_result["expert_metrics"]["reward_modifier"]
            ),
            expert_phase=str(reward_result["expert_metrics"]["phase"]),
            expert_feedback=str(reward_result["expert_metrics"]["feedback"]),
            replay_candidate=False,
            replay_novelty=0.0,
            replay_added=False,
            combined_reward=terminal_reward,
            reward_breakdown=reward_result,
            topics_in_sweet_spot=self.curriculum_manager.get_sweet_spot_topics(),
            current_focus_topics=self.curriculum_manager.get_current_focus(),
            curriculum_state_snapshot=self.curriculum_manager.get_curriculum_stats(),
        )
        metadata_dict = asdict(metadata)
        trajectory.metadata = metadata_dict
        # Replay admission: requires trajectory.metadata to already exist
        # because check_novelty reads metadata["generated_question"].
        is_candidate, reason = self.quality_filter.meets_replay_criteria(metadata_dict)
        metadata_dict["replay_candidate"] = is_candidate
        if is_candidate:
            novelty_score = self.quality_filter.check_novelty(
                trajectory, self.replay_buffer.buffer
            )
            metadata_dict["replay_novelty"] = float(novelty_score)
            if self.quality_filter.is_novel_enough(novelty_score):
                quality_score = self.quality_filter.compute_quality_score(metadata_dict)
                self.replay_buffer.add_trajectory(
                    trajectory=trajectory,
                    metadata=metadata_dict,
                    iteration=self.curriculum_manager.current_iteration,
                    quality_score=quality_score,
                )
                metadata_dict["replay_added"] = True
            else:
                metadata_dict["replay_added"] = False
        else:
            metadata_dict["replay_added"] = False
            metadata_dict["replay_reject_reason"] = reason
        trajectory.metadata = metadata_dict
        return trajectory

    def _get_adaptive_replay_ratio(self) -> float:
        iteration = self.curriculum_manager.current_iteration
        if iteration < 3:
            return 0.0
        if iteration < 5:
            return 0.15
        buffer_stats = self.replay_buffer.get_buffer_stats(current_iteration=iteration)
        buffer_health = float(buffer_stats.get("buffer_health", 0.0))
        if buffer_health >= 0.75:
            return 0.3
        if buffer_health >= 0.6:
            return 0.25
        return 0.2

    def collect_rollouts(
        self,
        num_trajectories: int,
        verbose: bool = True,
        grounded_ratio: float = 0.0,
    ) -> List[Trajectory]:
        """
        Generate ``num_trajectories`` episodes in-process on the current
        device.

        Mix:
          * ``grounded_ratio`` of rollouts are GSM8K-anchored (real question,
            reward scored against gold final answer). These give the policy
            a clean gradient toward benchmark correctness and are also ~3x
            faster than self-play rollouts (no question-generation pass or
            question-quality evaluation).
          * an adaptive fraction is drawn from the replay buffer when buffer
            health is good (self-play only).
          * the remainder are fresh self-play rollouts.
        """
        if num_trajectories <= 0:
            return []
        # Defensive .eval() on both policy and value before any generation.
        # The first iteration runs rollouts right after model load (HF default
        # is .train()). Qwen2.5 has zero dropout so this is currently cosmetic,
        # but cheap insurance against any future model swap with stochastic layers.
        if self.policy is not None:
            self.policy.eval()
        if self.value is not None:
            self.value.eval()
        # Grounded rollouts: only if we actually have QA pairs loaded.
        if grounded_ratio > 0.0 and self.grounded_qa_pairs:
            num_grounded = int(round(num_trajectories * grounded_ratio))
            num_grounded = min(num_grounded, num_trajectories)
        else:
            num_grounded = 0
        num_selfplay = num_trajectories - num_grounded
        # Within the self-play share, the existing replay-buffer mix applies.
        replay_ratio = self._get_adaptive_replay_ratio()
        num_replay = int(num_selfplay * replay_ratio)
        num_replay = min(num_replay, len(self.replay_buffer))
        num_fresh = max(0, num_selfplay - num_replay)
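        # Worked example of the mix arithmetic above (illustrative numbers only):
        # num_trajectories=32, grounded_ratio=0.5 → num_grounded=16, num_selfplay=16;
        # with replay_ratio=0.25 and a sufficiently full buffer, num_replay=4 and
        # num_fresh=12, i.e. a 16/12/4 grounded/fresh/replay split.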
        # ---- Grounded rollouts (GSM8K-anchored) --------------------------
        grounded_trajectories: List[Trajectory] = []
        grounded_correct = 0
        grounded_reward_sum = 0.0
        if num_grounded > 0:
            qa_sample = random.sample(
                self.grounded_qa_pairs,
                k=min(num_grounded, len(self.grounded_qa_pairs)),
            )
            # If we asked for more grounded rollouts than we have distinct
            # pairs, pad by re-sampling with replacement.
            while len(qa_sample) < num_grounded:
                qa_sample.append(random.choice(self.grounded_qa_pairs))
            pbar = tqdm(
                qa_sample,
                desc="Grounded rollouts",
                unit="ep",
                dynamic_ncols=True,
                leave=False,
                disable=not verbose,
            )
            for qa in pbar:
                trajectory = self.rollout_grounded_trajectory(qa)
                grounded_trajectories.append(trajectory)
                r = float(trajectory.metadata.get("combined_reward", 0.0))
                grounded_reward_sum += r
                if bool(trajectory.metadata.get("grounded_gt_match", False)):
                    grounded_correct += 1
                done = len(grounded_trajectories)
                pbar.set_postfix(
                    acc=f"{grounded_correct / done:.1%}",
                    reward=f"{grounded_reward_sum / done:+.3f}",
                    refresh=False,
                )
        # ---- Fresh self-play rollouts ------------------------------------
        fresh_trajectories: List[Trajectory] = []
        pbar = tqdm(
            range(num_fresh),
            desc="Self-play rollouts",
            unit="ep",
            dynamic_ncols=True,
            leave=False,
            disable=not verbose,
        )
        running_reward = 0.0
        running_ok = 0
        for _ in pbar:
            trajectory = self.rollout_trajectory()
            trajectory.metadata["rollout_source"] = "fresh"
            fresh_trajectories.append(trajectory)
            running_reward += float(trajectory.metadata.get("combined_reward", 0.0))
            if trajectory.metadata.get("final_answer_ok", False):
                running_ok += 1
            done = len(fresh_trajectories)
            pbar.set_postfix(
                reward=f"{running_reward / done:+.3f}",
                ok=f"{running_ok}/{done}",
                refresh=False,
            )
        # ---- Replay buffer draws -----------------------------------------
        replay_trajectories = self.replay_buffer.sample_replay_batch(
            num_replay, diversity_sample=True
        )
        for trajectory in replay_trajectories:
            trajectory.metadata["rollout_source"] = "replay"
        trajectories = (
            grounded_trajectories + fresh_trajectories + replay_trajectories
        )
        random.shuffle(trajectories)
        self.last_replay_ratio = replay_ratio
        self.last_rollout_mix = {
            "fresh": len(fresh_trajectories),
            "replay": len(replay_trajectories),
            "grounded": len(grounded_trajectories),
        }
        grounded_count = len(grounded_trajectories)
        self.last_grounded_stats = {
            "count": grounded_count,
            "correct": grounded_correct,
            "accuracy": (
                grounded_correct / grounded_count if grounded_count > 0 else 0.0
            ),
            "mean_reward": (
                grounded_reward_sum / grounded_count if grounded_count > 0 else 0.0
            ),
        }
        if verbose:
            buffer_stats = self.replay_buffer.get_buffer_stats(
                current_iteration=self.curriculum_manager.current_iteration
            )
            logger.info(
                "Rollout mix: %d grounded + %d fresh + %d replay "
                "(grounded_ratio=%.2f, replay_ratio=%.2f, buffer_size=%d, health=%.3f)",
                len(grounded_trajectories),
                len(fresh_trajectories),
                len(replay_trajectories),
                grounded_ratio,
                replay_ratio,
                len(self.replay_buffer),
                float(buffer_stats.get("buffer_health", 0.0)),
            )
            if grounded_count > 0:
                logger.info(
                    "Grounded accuracy this iter: %d/%d = %.1f%% (mean reward %.3f)",
                    grounded_correct,
                    grounded_count,
                    100.0 * grounded_correct / grounded_count,
                    grounded_reward_sum / grounded_count,
                )
        self.curriculum_manager.increment_iteration()
        self.curriculum_manager.save_state(
            iteration=self.curriculum_manager.current_iteration, rollout=None
        )
        return trajectories
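
# Illustrative usage sketch (assumed wiring; the real entry point lives in the
# training script, not in this module, and the constructor arguments shown are
# placeholders):
#
#   env = CurriculumMathEnvironment(
#       policy_model=policy_model,
#       value_model=value_head,
#       tokenizer=tokenizer,
#       grounded_qa_pairs=gsm8k_pairs,          # [{"question": ..., "gold_final": ...}, ...]
#       prm_scorer=ProcessRewardScorer(...),    # required by compute_reward
#   )
#   trajectories = env.collect_rollouts(num_trajectories=32, grounded_ratio=0.5)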