Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| AxiomForgeAI Math RL Environment. | |
| Wraps CurriculumMathEnvironment from src/rl/math_environment_curriculum.py | |
| to expose an OpenEnv-compatible interface (reset / step / state). | |
| Episode semantics | |
| ----------------- | |
| * reset() β Samples a new question from the adaptive curriculum (or a | |
| grounded QA pair when a dataset is configured). Returns the | |
| question in the observation; reward is 0.0. | |
| * step(action) β Scores the agent's submitted solution with the full reward | |
| pipeline (PRM + SymPy + format) and returns reward + feedback. | |
| done=True always: one question per episode. | |
| Environment variables | |
| --------------------- | |
| AXIOMFORGE_DATA_PATH Path to a JSONL file with {"question", "gold_final"} | |
| records (e.g. data/sft/gsm8k_sft.jsonl). When set, | |
| the environment uses grounded QA pairs for questions | |
| and ground-truth answer verification. | |
| AXIOMFORGE_PRM_PATH HuggingFace model ID or local path for the Process | |
| Reward Model (default: Qwen/Qwen2.5-Math-PRM-7B). | |
| Set to "" to disable PRM scoring (uses SymPy only). | |
| AXIOMFORGE_CURRICULUM_DIR | |
| Directory where the CurriculumManager persists its | |
| state between runs. Defaults to | |
| "checkpoints/curriculum". | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| import os | |
| import random | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional | |
| from uuid import uuid4 | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import State | |
| try: | |
| from ..models import AxiomforgeaiAction, AxiomforgeaiObservation | |
| except ImportError: | |
| from models import AxiomforgeaiAction, AxiomforgeaiObservation | |
| # ββ Heavy RL imports β fail gracefully so openenv validate passes even when | |
| # the ML stack is not installed (e.g. lightweight CI / schema validation). | |
| try: | |
| import torch | |
| from src.rl.math_environment_curriculum import CurriculumMathEnvironment | |
| from src.rl.prm_scorer import ProcessRewardScorer | |
| from src.sft.solution_format import extract_final_answer_numeric_str | |
| _RL_AVAILABLE = True | |
| except Exception as _rl_import_err: # pragma: no cover | |
| torch = None # type: ignore[assignment] | |
| _RL_AVAILABLE = False | |
| CurriculumMathEnvironment = None # type: ignore[assignment,misc] | |
| ProcessRewardScorer = None # type: ignore[assignment,misc] | |
| extract_final_answer_numeric_str = None # type: ignore[assignment] | |
# Module-level logger; handlers/level are configured by the hosting server.
logger = logging.getLogger(__name__)

# Fallback question used during validation / when no dataset is configured.
_VALIDATION_QUESTION: str = (
    "A store sells apples for $2 each and oranges for $3 each. "
    "If Sarah buys 4 apples and 3 oranges, how much does she spend in total?"
)
# Expected final answer for the fallback question (4*2 + 3*3 = 17).
_VALIDATION_GOLD: str = "17"
# Topic / difficulty reported alongside the fallback question.
_VALIDATION_TOPIC: str = "basic_arithmetic"
_VALIDATION_DIFFICULTY: float = 0.1
| def _load_qa_pairs(data_path: str) -> List[Dict[str, str]]: | |
| """Load {"question", "gold_final"} records from a JSONL file.""" | |
| pairs: List[Dict[str, str]] = [] | |
| p = Path(data_path) | |
| if not p.exists(): | |
| logger.warning("AXIOMFORGE_DATA_PATH not found: %s", data_path) | |
| return pairs | |
| with p.open(encoding="utf-8") as f: | |
| for line in f: | |
| line = line.strip() | |
| if not line: | |
| continue | |
| try: | |
| rec = json.loads(line) | |
| except json.JSONDecodeError: | |
| continue | |
| q = rec.get("question", "").strip() | |
| g = rec.get("gold_final", "").strip() | |
| if q and g: | |
| pairs.append({"question": q, "gold_final": g}) | |
| logger.info("Loaded %d QA pairs from %s", len(pairs), data_path) | |
| return pairs | |
class AxiomforgeaiEnvironment(Environment):
    """
    AxiomForgeAI math RL environment for OpenEnv.

    Uses CurriculumMathEnvironment from src/rl/ for adaptive question
    selection and reward computation. When the ML stack is unavailable
    (e.g. during schema validation), falls back to a lightweight mode
    that uses only the installed openenv-core dependencies.

    Supports concurrent WebSocket sessions -- each client gets its own
    instance with independent episode state.
    """

    # Each session is served by its own instance, so the per-episode
    # attributes set in __init__/reset need no locking.
    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self) -> None:
        """
        Initialise the environment.

        All heavyweight setup (PRM scorer, CurriculumMathEnvironment) is
        best-effort: any failure is logged and the instance degrades to a
        lightweight validation mode instead of raising, so the server can
        always start.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        # Per-episode state: written by reset(), read by step().
        self._current_question: str = ""
        self._gold_final: str = ""
        self._current_topic: str = ""
        self._current_difficulty: float = 0.5
        self._math_env: Optional[Any] = None  # CurriculumMathEnvironment or None
        if torch is not None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            # Plain-string placeholder; never consumed because we return
            # below when the RL stack is unavailable (torch is None).
            device = "cpu"
        if not _RL_AVAILABLE:
            logger.warning(
                "RL stack (torch/transformers/sympy) not available β "
                "running in schema-validation mode with fixed fallback responses."
            )
            return
        # -- Load grounded QA pairs (optional) -----------------------------
        grounded_qa_pairs: List[Dict[str, str]] = []
        data_path = os.environ.get("AXIOMFORGE_DATA_PATH", "")
        if data_path:
            grounded_qa_pairs = _load_qa_pairs(data_path)
        # -- Load PRM scorer (optional) ------------------------------------
        prm: Optional[Any] = None  # ProcessRewardScorer or None
        prm_path = os.environ.get("AXIOMFORGE_PRM_PATH", "")
        if prm_path:
            try:
                prm = ProcessRewardScorer(
                    model_name=prm_path,
                    device=device,
                    load_in_4bit=True,
                )
                logger.info("PRM loaded: %s", prm_path)
            except Exception as exc:
                # Best-effort: a missing/broken PRM downgrades scoring
                # rather than failing environment construction.
                logger.warning("PRM load failed (%s) β scoring uses SymPy only.", exc)
        # -- Create CurriculumMathEnvironment in scoring-only mode ---------
        # policy_model=None + tokenizer=None is safe when only reward-computation
        # methods are called (compute_grounded_reward, sample_instruction).
        # Generation methods (generate_with_logging, format_solution_prompt)
        # are NOT called from the server step path: the agent supplies solutions.
        curriculum_dir = os.environ.get(
            "AXIOMFORGE_CURRICULUM_DIR", "checkpoints/curriculum"
        )
        try:
            self._math_env = CurriculumMathEnvironment(
                policy_model=None,
                value_model=None,
                tokenizer=None,
                reference_questions=[qa["question"] for qa in grounded_qa_pairs],
                grounded_qa_pairs=grounded_qa_pairs,
                prm_scorer=prm,
                curriculum_checkpoint_dir=curriculum_dir,
                device=device,
            )
            logger.info(
                "CurriculumMathEnvironment ready (scoring-only, %d QA pairs, PRM=%s)",
                len(grounded_qa_pairs),
                "yes" if prm else "no",
            )
        except Exception as exc:
            logger.warning(
                "CurriculumMathEnvironment init failed (%s) β "
                "falling back to validation mode.",
                exc,
            )
            self._math_env = None

    # ------------------------------------------------------------------
    # OpenEnv interface
    # ------------------------------------------------------------------
    def reset(
        self,
        qa: Optional[Dict[str, str]] = None,
    ) -> AxiomforgeaiObservation:
        """
        Reset the environment and begin a new episode.

        Args:
            qa: Optional ``{"question": str, "gold_final": str}`` dict.
                When supplied the environment is seeded with this specific
                question and gold answer -- used by the training loop for
                difficulty-sampled grounded episodes. When omitted the
                environment draws from its internal grounded QA pool (if
                configured) or falls back to the curriculum instruction.

        Returns:
            AxiomforgeaiObservation with the question populated; reward=0.0.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        if qa is not None:
            # Caller-supplied episode: honour it exactly.
            self._current_question = qa.get("question", "").strip()
            self._gold_final = qa.get("gold_final", "").strip()
            self._current_topic = qa.get("topic", "grounded")
            self._current_difficulty = float(qa.get("difficulty", 0.5))
        elif self._math_env is not None:
            try:
                instruction, topic, difficulty = self._math_env.sample_instruction()
                self._current_topic = topic
                self._current_difficulty = float(difficulty)
                if self._math_env.grounded_qa_pairs:
                    # Prefer grounded questions (verifiable gold answers)
                    # over the free-form curriculum instruction.
                    _qa = random.choice(self._math_env.grounded_qa_pairs)
                    self._current_question = _qa["question"]
                    self._gold_final = _qa["gold_final"]
                else:
                    self._current_question = instruction
                    self._gold_final = ""  # no ground truth for generated questions
            except Exception as exc:
                logger.warning("sample_instruction failed, using fallback: %s", exc)
                self._current_question = _VALIDATION_QUESTION
                self._gold_final = _VALIDATION_GOLD
                self._current_topic = _VALIDATION_TOPIC
                self._current_difficulty = _VALIDATION_DIFFICULTY
        else:
            # Validation mode: serve the fixed fallback question.
            self._current_question = _VALIDATION_QUESTION
            self._gold_final = _VALIDATION_GOLD
            self._current_topic = _VALIDATION_TOPIC
            self._current_difficulty = _VALIDATION_DIFFICULTY
        return AxiomforgeaiObservation(
            question=self._current_question,
            topic=self._current_topic,
            difficulty=self._current_difficulty,
            feedback="",
            done=False,
            reward=0.0,
        )

    def step(self, action: AxiomforgeaiAction) -> AxiomforgeaiObservation:  # type: ignore[override]
        """
        Score the agent's submitted solution.

        Uses compute_grounded_reward from CurriculumMathEnvironment when
        available (PRM + SymPy + format scoring). Falls back to numeric
        answer extraction when the full RL stack is not loaded.

        Args:
            action: AxiomforgeaiAction containing the solution text.

        Returns:
            AxiomforgeaiObservation with reward, feedback, and metadata.
            done=True -- one question per episode.
        """
        self._state.step_count += 1
        solution = action.solution
        reward: float = 0.0
        feedback: str = ""
        metadata: Dict[str, Any] = {}
        if self._math_env is not None and self._current_question:
            try:
                reward_result = self._math_env.compute_grounded_reward(
                    question=self._current_question,
                    solution=solution,
                    gold_final=self._gold_final,
                )
                reward = float(reward_result.get("combined_score", 0.0))
                gt = reward_result.get("gt_match", False)
                step_acc = reward_result.get("step_accuracy", 0.0)
                lccp = reward_result.get("lccp", 0.0)
                pred = reward_result.get("pred_final", "")
                feedback = (
                    f"gt_match={gt} pred={pred!r} gold={self._gold_final!r} "
                    f"step_acc={step_acc:.2f} lccp={lccp:.2f}"
                )
                # Copy the reward breakdown into metadata; list-valued
                # entries (presumably per-step details) are dropped to keep
                # the payload serialisable -- TODO confirm against
                # compute_grounded_reward's return schema.
                metadata = {
                    k: v
                    for k, v in reward_result.items()
                    if not isinstance(v, list)
                }
            except Exception as exc:
                logger.warning("compute_grounded_reward failed: %s", exc)
                reward, feedback, metadata = self._fallback_score(solution)
        else:
            reward, feedback, metadata = self._fallback_score(solution)
        return AxiomforgeaiObservation(
            question=self._current_question,
            topic=self._current_topic,
            difficulty=self._current_difficulty,
            feedback=feedback,
            done=True,
            reward=reward,
            metadata=metadata,
        )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _fallback_score(
        self, solution: str
    ) -> tuple[float, str, Dict[str, Any]]:
        """Lightweight scoring used when the full RL stack is unavailable.

        Extracts the final numeric answer from *solution* (when the helper
        is importable) and awards a binary reward for an exact string match
        against the stored gold answer.
        """
        pred: str = ""
        if extract_final_answer_numeric_str is not None:
            pred = extract_final_answer_numeric_str(solution) or ""
        # Binary reward: 1.0 only on a non-empty exact match.
        reward = 1.0 if pred and pred == self._gold_final else 0.0
        feedback = f"pred={pred!r} gold={self._gold_final!r}"
        return reward, feedback, {"pred_final": pred, "gold_final": self._gold_final}

    def close(self) -> None:
        """
        Persist curriculum state and release resources.

        Call once at the end of a training run so the CurriculumManager's
        per-topic statistics are saved to disk and can be resumed on the
        next run. Safe to call multiple times.
        """
        if self._math_env is not None:
            try:
                self._math_env.curriculum_manager.save_state(
                    iteration=self._math_env.curriculum_manager.current_iteration,
                    rollout=None,
                )
                logger.info(
                    "Curriculum state saved (iteration %d).",
                    self._math_env.curriculum_manager.current_iteration,
                )
            except Exception as exc:
                # Persistence is best-effort; never let shutdown raise.
                logger.warning("close(): curriculum save failed β %s", exc)

    def state(self) -> State:
        """Return the current episode state (episode_id + step_count)."""
        return self._state