# AxiomForgeAI / server / AxiomForgeAI_environment.py
# (File-view header residue removed; commit 74e7a0b:
#  "Make torch import optional for CPU-only Space deployment".)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
AxiomForgeAI Math RL Environment.
Wraps CurriculumMathEnvironment from src/rl/math_environment_curriculum.py
to expose an OpenEnv-compatible interface (reset / step / state).
Episode semantics
-----------------
* reset() β€” Samples a new question from the adaptive curriculum (or a
grounded QA pair when a dataset is configured). Returns the
question in the observation; reward is 0.0.
* step(action) β€” Scores the agent's submitted solution with the full reward
pipeline (PRM + SymPy + format) and returns reward + feedback.
done=True always: one question per episode.
Environment variables
---------------------
AXIOMFORGE_DATA_PATH Path to a JSONL file with {"question", "gold_final"}
records (e.g. data/sft/gsm8k_sft.jsonl). When set,
the environment uses grounded QA pairs for questions
and ground-truth answer verification.
AXIOMFORGE_PRM_PATH HuggingFace model ID or local path for the Process
Reward Model (e.g. Qwen/Qwen2.5-Math-PRM-7B). When unset
or empty, PRM scoring is disabled (SymPy only).
AXIOMFORGE_CURRICULUM_DIR
Directory where the CurriculumManager persists its
state between runs. Defaults to
"checkpoints/curriculum".
"""
from __future__ import annotations
import json
import logging
import os
import random
from pathlib import Path
from typing import Any, Dict, List, Optional
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
try:
from ..models import AxiomforgeaiAction, AxiomforgeaiObservation
except ImportError:
from models import AxiomforgeaiAction, AxiomforgeaiObservation
# ── Heavy RL imports β€” fail gracefully so openenv validate passes even when
# the ML stack is not installed (e.g. lightweight CI / schema validation).
try:
import torch
from src.rl.math_environment_curriculum import CurriculumMathEnvironment
from src.rl.prm_scorer import ProcessRewardScorer
from src.sft.solution_format import extract_final_answer_numeric_str
_RL_AVAILABLE = True
except Exception as _rl_import_err: # pragma: no cover
torch = None # type: ignore[assignment]
_RL_AVAILABLE = False
CurriculumMathEnvironment = None # type: ignore[assignment,misc]
ProcessRewardScorer = None # type: ignore[assignment,misc]
extract_final_answer_numeric_str = None # type: ignore[assignment]
logger = logging.getLogger(__name__)
# Fallback question used during validation / when no dataset is configured.
_VALIDATION_QUESTION = (
"A store sells apples for $2 each and oranges for $3 each. "
"If Sarah buys 4 apples and 3 oranges, how much does she spend in total?"
)
_VALIDATION_GOLD = "17"
_VALIDATION_TOPIC = "basic_arithmetic"
_VALIDATION_DIFFICULTY = 0.1
def _load_qa_pairs(data_path: str) -> List[Dict[str, str]]:
"""Load {"question", "gold_final"} records from a JSONL file."""
pairs: List[Dict[str, str]] = []
p = Path(data_path)
if not p.exists():
logger.warning("AXIOMFORGE_DATA_PATH not found: %s", data_path)
return pairs
with p.open(encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
rec = json.loads(line)
except json.JSONDecodeError:
continue
q = rec.get("question", "").strip()
g = rec.get("gold_final", "").strip()
if q and g:
pairs.append({"question": q, "gold_final": g})
logger.info("Loaded %d QA pairs from %s", len(pairs), data_path)
return pairs
class AxiomforgeaiEnvironment(Environment):
    """
    AxiomForgeAI math RL environment for OpenEnv.
    Uses CurriculumMathEnvironment from src/rl/ for adaptive question
    selection and reward computation. When the ML stack is unavailable
    (e.g. during schema validation), falls back to a lightweight mode
    that uses only the installed openenv-core dependencies.
    Supports concurrent WebSocket sessions β€” each client gets its own
    instance with independent episode state.
    """
    # The server may create one instance per WebSocket client; all episode
    # state below is per-instance, so no locking is required.
    SUPPORTS_CONCURRENT_SESSIONS: bool = True
    def __init__(self) -> None:
        # Episode bookkeeping; replaced with a fresh State on every reset().
        self._state = State(episode_id=str(uuid4()), step_count=0)
        # Per-episode state
        self._current_question: str = ""
        self._gold_final: str = ""  # "" means no ground-truth answer available
        self._current_topic: str = ""
        self._current_difficulty: float = 0.5
        self._math_env: Optional[Any] = None  # CurriculumMathEnvironment or None
        # torch is None when the optional heavy-import block at module top
        # failed; a plain "cpu" string then stands in for a torch.device.
        if torch is not None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            device = "cpu"
        if not _RL_AVAILABLE:
            # Schema-validation mode: reset()/step() will use the fixed
            # fallback question and _fallback_score() only.
            logger.warning(
                "RL stack (torch/transformers/sympy) not available β€” "
                "running in schema-validation mode with fixed fallback responses."
            )
            return
        # ── Load grounded QA pairs (optional) ─────────────────────────────
        grounded_qa_pairs: List[Dict[str, str]] = []
        data_path = os.environ.get("AXIOMFORGE_DATA_PATH", "")
        if data_path:
            grounded_qa_pairs = _load_qa_pairs(data_path)
        # ── Load PRM scorer (optional) ────────────────────────────────────
        # NOTE(review): the module docstring advertises a default PRM of
        # Qwen/Qwen2.5-Math-PRM-7B, but the code default here is "" (PRM
        # disabled unless AXIOMFORGE_PRM_PATH is set) β€” confirm intent.
        prm: Optional[Any] = None  # ProcessRewardScorer or None
        prm_path = os.environ.get("AXIOMFORGE_PRM_PATH", "")
        if prm_path:
            try:
                prm = ProcessRewardScorer(
                    model_name=prm_path,
                    device=device,
                    load_in_4bit=True,
                )
                logger.info("PRM loaded: %s", prm_path)
            except Exception as exc:
                # Best-effort: a missing/broken PRM degrades scoring, it
                # does not break the environment.
                logger.warning("PRM load failed (%s) β€” scoring uses SymPy only.", exc)
        # ── Create CurriculumMathEnvironment in scoring-only mode ─────────
        # policy_model=None + tokenizer=None is safe when only reward-computation
        # methods are called (compute_grounded_reward, sample_instruction).
        # Generation methods (generate_with_logging, format_solution_prompt)
        # are NOT called from the server step path β€” the agent supplies solutions.
        curriculum_dir = os.environ.get(
            "AXIOMFORGE_CURRICULUM_DIR", "checkpoints/curriculum"
        )
        try:
            self._math_env = CurriculumMathEnvironment(
                policy_model=None,
                value_model=None,
                tokenizer=None,
                reference_questions=[qa["question"] for qa in grounded_qa_pairs],
                grounded_qa_pairs=grounded_qa_pairs,
                prm_scorer=prm,
                curriculum_checkpoint_dir=curriculum_dir,
                device=device,
            )
            logger.info(
                "CurriculumMathEnvironment ready (scoring-only, %d QA pairs, PRM=%s)",
                len(grounded_qa_pairs),
                "yes" if prm else "no",
            )
        except Exception as exc:
            # Any init failure (bad checkpoint dir, missing deps, ...) drops
            # the instance into validation mode instead of crashing the server.
            logger.warning(
                "CurriculumMathEnvironment init failed (%s) β€” "
                "falling back to validation mode.",
                exc,
            )
            self._math_env = None
    # ------------------------------------------------------------------
    # OpenEnv interface
    # ------------------------------------------------------------------
    def reset(
        self,
        qa: Optional[Dict[str, str]] = None,
    ) -> AxiomforgeaiObservation:
        """
        Reset the environment and begin a new episode.
        Args:
            qa: Optional ``{"question": str, "gold_final": str}`` dict.
                When supplied the environment is seeded with this specific
                question and gold answer β€” used by the training loop for
                difficulty-sampled grounded episodes. When omitted the
                environment draws from its internal grounded QA pool (if
                configured) or falls back to the curriculum instruction.
        Returns:
            AxiomforgeaiObservation with the question populated; reward=0.0.
        """
        # New episode identity; step() increments step_count on this State.
        self._state = State(episode_id=str(uuid4()), step_count=0)
        if qa is not None:
            # Caller-supplied episode β€” honour it exactly.
            self._current_question = qa.get("question", "").strip()
            self._gold_final = qa.get("gold_final", "").strip()
            self._current_topic = qa.get("topic", "grounded")
            self._current_difficulty = float(qa.get("difficulty", 0.5))
        elif self._math_env is not None:
            try:
                instruction, topic, difficulty = self._math_env.sample_instruction()
                self._current_topic = topic
                self._current_difficulty = float(difficulty)
                if self._math_env.grounded_qa_pairs:
                    # NOTE(review): the question is drawn uniformly from the
                    # grounded pool while topic/difficulty come from
                    # sample_instruction(), so they may not describe this
                    # particular question β€” confirm this is intentional.
                    _qa = random.choice(self._math_env.grounded_qa_pairs)
                    self._current_question = _qa["question"]
                    self._gold_final = _qa["gold_final"]
                else:
                    # Ungrounded curriculum instruction: no gold answer, so
                    # step() scoring relies on PRM/format signals only.
                    self._current_question = instruction
                    self._gold_final = ""
            except Exception as exc:
                # Sampling failure degrades to the fixed validation episode.
                logger.warning("sample_instruction failed, using fallback: %s", exc)
                self._current_question = _VALIDATION_QUESTION
                self._gold_final = _VALIDATION_GOLD
                self._current_topic = _VALIDATION_TOPIC
                self._current_difficulty = _VALIDATION_DIFFICULTY
        else:
            # Schema-validation mode (no RL stack / env init failed).
            self._current_question = _VALIDATION_QUESTION
            self._gold_final = _VALIDATION_GOLD
            self._current_topic = _VALIDATION_TOPIC
            self._current_difficulty = _VALIDATION_DIFFICULTY
        return AxiomforgeaiObservation(
            question=self._current_question,
            topic=self._current_topic,
            difficulty=self._current_difficulty,
            feedback="",
            done=False,
            reward=0.0,
        )
    def step(self, action: AxiomforgeaiAction) -> AxiomforgeaiObservation:  # type: ignore[override]
        """
        Score the agent's submitted solution.
        Uses compute_grounded_reward from CurriculumMathEnvironment when
        available (PRM + SymPy + format scoring). Falls back to numeric
        answer extraction when the full RL stack is not loaded.
        Args:
            action: AxiomforgeaiAction containing the solution text.
        Returns:
            AxiomforgeaiObservation with reward, feedback, and metadata.
            done=True β€” one question per episode.
        """
        self._state.step_count += 1
        solution = action.solution
        reward: float = 0.0
        feedback: str = ""
        metadata: Dict[str, Any] = {}
        if self._math_env is not None and self._current_question:
            try:
                reward_result = self._math_env.compute_grounded_reward(
                    question=self._current_question,
                    solution=solution,
                    gold_final=self._gold_final,
                )
                reward = float(reward_result.get("combined_score", 0.0))
                gt = reward_result.get("gt_match", False)
                step_acc = reward_result.get("step_accuracy", 0.0)
                lccp = reward_result.get("lccp", 0.0)
                pred = reward_result.get("pred_final", "")
                # Compact human-readable summary of the reward breakdown.
                feedback = (
                    f"gt_match={gt} pred={pred!r} gold={self._gold_final!r} "
                    f"step_acc={step_acc:.2f} lccp={lccp:.2f}"
                )
                # Serialise reward breakdown into metadata; skip list-valued
                # entries (e.g. per-step score arrays) to keep the payload
                # small and JSON-friendly.
                metadata = {
                    k: v
                    for k, v in reward_result.items()
                    if not isinstance(v, list)
                }
            except Exception as exc:
                # Scoring failure degrades to the lightweight fallback
                # rather than surfacing an error to the client.
                logger.warning("compute_grounded_reward failed: %s", exc)
                reward, feedback, metadata = self._fallback_score(solution)
        else:
            reward, feedback, metadata = self._fallback_score(solution)
        return AxiomforgeaiObservation(
            question=self._current_question,
            topic=self._current_topic,
            difficulty=self._current_difficulty,
            feedback=feedback,
            done=True,
            reward=reward,
            metadata=metadata,
        )
    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _fallback_score(
        self, solution: str
    ) -> tuple[float, str, Dict[str, Any]]:
        """Lightweight scoring used when the full RL stack is unavailable.

        Extracts a numeric final answer from the solution text (when the
        extractor import succeeded) and awards a binary 1.0/0.0 reward for
        an exact string match against the stored gold answer.
        """
        pred: str = ""
        # extract_final_answer_numeric_str is None when the heavy-import
        # block at module top failed; pred then stays "" and reward is 0.0.
        if extract_final_answer_numeric_str is not None:
            pred = extract_final_answer_numeric_str(solution) or ""
        reward = 1.0 if pred and pred == self._gold_final else 0.0
        feedback = f"pred={pred!r} gold={self._gold_final!r}"
        return reward, feedback, {"pred_final": pred, "gold_final": self._gold_final}
    def close(self) -> None:
        """
        Persist curriculum state and release resources.
        Call once at the end of a training run so the CurriculumManager's
        per-topic statistics are saved to disk and can be resumed on the
        next run. Safe to call multiple times.
        """
        # No-op in validation mode (no curriculum to persist).
        if self._math_env is not None:
            try:
                self._math_env.curriculum_manager.save_state(
                    iteration=self._math_env.curriculum_manager.current_iteration,
                    rollout=None,
                )
                logger.info(
                    "Curriculum state saved (iteration %d).",
                    self._math_env.curriculum_manager.current_iteration,
                )
            except Exception as exc:
                # Persistence is best-effort; never raise from close().
                logger.warning("close(): curriculum save failed β€” %s", exc)
    @property
    def state(self) -> State:
        """Return the current episode state (episode_id + step_count)."""
        return self._state