File size: 14,713 Bytes
ec4ae03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74e7a0b
ec4ae03
 
 
 
 
 
74e7a0b
ec4ae03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74e7a0b
 
 
 
ec4ae03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
AxiomForgeAI Math RL Environment.

Wraps CurriculumMathEnvironment from src/rl/math_environment_curriculum.py
to expose an OpenEnv-compatible interface (reset / step / state).

Episode semantics
-----------------
* reset()  — Samples a new question from the adaptive curriculum (or a
             grounded QA pair when a dataset is configured).  Returns the
             question in the observation; reward is 0.0.
* step(action) — Scores the agent's submitted solution with the full reward
             pipeline (PRM + SymPy + format) and returns reward + feedback.
             done=True always: one question per episode.

Environment variables
---------------------
AXIOMFORGE_DATA_PATH   Path to a JSONL file with {"question", "gold_final"}
                       records (e.g. data/sft/gsm8k_sft.jsonl).  When set,
                       the environment uses grounded QA pairs for questions
                       and ground-truth answer verification.

AXIOMFORGE_PRM_PATH    HuggingFace model ID or local path for the Process
                       Reward Model (e.g. Qwen/Qwen2.5-Math-PRM-7B).  When
                       unset or empty, PRM scoring is disabled (SymPy only).

AXIOMFORGE_CURRICULUM_DIR
                       Directory where the CurriculumManager persists its
                       state between runs.  Defaults to
                       "checkpoints/curriculum".
"""

from __future__ import annotations

import json
import logging
import math
import os
import random
import re
from pathlib import Path
from typing import Any, Dict, List, Optional
from uuid import uuid4

from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State

try:
    from ..models import AxiomforgeaiAction, AxiomforgeaiObservation

except ImportError:
    from models import AxiomforgeaiAction, AxiomforgeaiObservation

# ── Heavy RL imports — fail gracefully so openenv validate passes even when
#    the ML stack is not installed (e.g. lightweight CI / schema validation).
try:
    import torch
    from src.rl.math_environment_curriculum import CurriculumMathEnvironment
    from src.rl.prm_scorer import ProcessRewardScorer
    from src.sft.solution_format import extract_final_answer_numeric_str

    _RL_AVAILABLE = True
except Exception as _rl_import_err:  # pragma: no cover
    torch = None  # type: ignore[assignment]
    _RL_AVAILABLE = False
    CurriculumMathEnvironment = None  # type: ignore[assignment,misc]
    ProcessRewardScorer = None  # type: ignore[assignment,misc]
    extract_final_answer_numeric_str = None  # type: ignore[assignment]


logger = logging.getLogger(name=__name__)

# Fixed episode served whenever no curriculum environment is available
# (schema-validation mode or failed init) and whenever curriculum sampling
# raises inside reset().  Expected answer: 4 * $2 + 3 * $3 = $17.
_VALIDATION_QUESTION = (
    "A store sells apples for $2 each and oranges for $3 each. If Sarah"
    " buys 4 apples and 3 oranges, how much does she spend in total?"
)
_VALIDATION_GOLD = "17"
_VALIDATION_TOPIC = "basic_arithmetic"
_VALIDATION_DIFFICULTY = 0.1


def _load_qa_pairs(data_path: str) -> List[Dict[str, str]]:
    """
    Load ``{"question", "gold_final"}`` records from a JSONL file.

    Each non-blank line is expected to be a JSON object.  Lines that are not
    valid JSON, are not objects, or lack a non-empty question / gold answer
    are skipped, and the number skipped is logged.  A numeric
    ``gold_final`` (e.g. ``17``) is accepted and converted to a string.

    Args:
        data_path: Path to the JSONL file.

    Returns:
        List of ``{"question": str, "gold_final": str}`` dicts with
        surrounding whitespace stripped.  Empty when the path does not exist
        or is not a regular file.
    """

    def _as_text(value: Any) -> str:
        # JSON may encode the gold answer as a number; booleans, null and
        # containers are treated as missing.
        if isinstance(value, str):
            return value.strip()
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return str(value)
        return ""

    pairs: List[Dict[str, str]] = []
    p = Path(data_path)
    # is_file() (not exists()) so a directory path is reported, not crashed on.
    if not p.is_file():
        logger.warning("AXIOMFORGE_DATA_PATH not found or not a file: %s", data_path)
        return pairs

    skipped = 0
    # "utf-8-sig" transparently drops a leading BOM, which would otherwise
    # make the first line invalid JSON; it is identical to utf-8 otherwise.
    with p.open(encoding="utf-8-sig") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                skipped += 1
                continue
            if not isinstance(rec, dict):
                skipped += 1
                continue
            q = _as_text(rec.get("question"))
            g = _as_text(rec.get("gold_final"))
            if q and g:
                pairs.append({"question": q, "gold_final": g})
            else:
                skipped += 1

    if skipped:
        logger.warning("Skipped %d unusable line(s) in %s", skipped, data_path)
    logger.info("Loaded %d QA pairs from %s", len(pairs), data_path)
    return pairs


class AxiomforgeaiEnvironment(Environment):
    """
    AxiomForgeAI math RL environment for OpenEnv.

    Uses CurriculumMathEnvironment from src/rl/ for adaptive question
    selection and reward computation.  When the ML stack is unavailable
    (e.g. during schema validation) or fails to initialise, falls back to a
    lightweight mode that serves a fixed validation question and scores
    answers with a standard-library numeric comparison.

    Supports concurrent WebSocket sessions — each client gets its own
    instance with independent episode state.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Last-resort answer extractor for validation mode: integers, decimals and
    # comma-grouped thousands ("1,234.5"), optionally negative.  Properly
    # grouped thousands are tried first so "1,2,3" yields "3", not "123".
    _NUMBER_RE = re.compile(
        r"-?\d{1,3}(?:,\d{3})+(?:\.\d+)?|-?\d+(?:\.\d+)?|-?\.\d+"
    )

    def __init__(self) -> None:
        """
        Build the environment and its optional scoring components.

        Reads AXIOMFORGE_DATA_PATH, AXIOMFORGE_PRM_PATH and
        AXIOMFORGE_CURRICULUM_DIR.  Each optional component degrades
        gracefully: a failed PRM load disables PRM scoring, and a failed
        CurriculumMathEnvironment init leaves ``self._math_env`` as None
        (validation mode).
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)

        # Per-episode state, populated by reset().
        self._current_question: str = ""
        self._gold_final: str = ""
        self._current_topic: str = ""
        self._current_difficulty: float = 0.5

        # CurriculumMathEnvironment instance, or None in validation mode.
        self._math_env: Optional[Any] = None

        if torch is not None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            device = "cpu"

        if not _RL_AVAILABLE:
            logger.warning(
                "RL stack (torch/transformers/sympy) not available — "
                "running in schema-validation mode with fixed fallback responses."
            )
            return

        # ── Load grounded QA pairs (optional) ─────────────────────────────
        grounded_qa_pairs: List[Dict[str, str]] = []
        data_path = os.environ.get("AXIOMFORGE_DATA_PATH", "")
        if data_path:
            grounded_qa_pairs = _load_qa_pairs(data_path)

        # ── Load PRM scorer (optional; disabled when the env var is unset) ─
        prm: Optional[Any] = None  # ProcessRewardScorer or None
        prm_path = os.environ.get("AXIOMFORGE_PRM_PATH", "")
        if prm_path:
            try:
                prm = ProcessRewardScorer(
                    model_name=prm_path,
                    device=device,
                    load_in_4bit=True,
                )
                logger.info("PRM loaded: %s", prm_path)
            except Exception as exc:
                logger.warning("PRM load failed (%s) — scoring uses SymPy only.", exc)

        # ── Create CurriculumMathEnvironment in scoring-only mode ─────────
        # policy_model=None + tokenizer=None is safe when only reward-computation
        # methods are called (compute_grounded_reward, sample_instruction).
        # Generation methods (generate_with_logging, format_solution_prompt)
        # are NOT called from the server step path — the agent supplies solutions.
        curriculum_dir = os.environ.get(
            "AXIOMFORGE_CURRICULUM_DIR", "checkpoints/curriculum"
        )
        try:
            self._math_env = CurriculumMathEnvironment(
                policy_model=None,
                value_model=None,
                tokenizer=None,
                reference_questions=[qa["question"] for qa in grounded_qa_pairs],
                grounded_qa_pairs=grounded_qa_pairs,
                prm_scorer=prm,
                curriculum_checkpoint_dir=curriculum_dir,
                device=device,
            )
            logger.info(
                "CurriculumMathEnvironment ready (scoring-only, %d QA pairs, PRM=%s)",
                len(grounded_qa_pairs),
                # Identity check: the scorer object may define __len__/__bool__.
                "yes" if prm is not None else "no",
            )
        except Exception as exc:
            logger.warning(
                "CurriculumMathEnvironment init failed (%s) — "
                "falling back to validation mode.",
                exc,
            )
            self._math_env = None

    # ------------------------------------------------------------------
    # OpenEnv interface
    # ------------------------------------------------------------------

    def reset(
        self,
        qa: Optional[Dict[str, str]] = None,
    ) -> AxiomforgeaiObservation:
        """
        Reset the environment and begin a new episode.

        Args:
            qa: Optional ``{"question": str, "gold_final": str}`` dict, which
                may also carry ``"topic"`` and ``"difficulty"``.  When
                supplied the environment is seeded with this specific
                question and gold answer — used by the training loop for
                difficulty-sampled grounded episodes.  Non-string question /
                gold values (e.g. a numeric gold answer) are converted with
                ``str``.  When omitted the environment draws from its
                internal grounded QA pool (if configured) or falls back to
                the curriculum instruction.

        Returns:
            AxiomforgeaiObservation with the question populated; reward=0.0.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)

        if qa is not None:
            # Caller-supplied episode — honour it exactly.
            self._current_question = self._coerce_text(qa.get("question"))
            self._gold_final = self._coerce_text(qa.get("gold_final"))
            self._current_topic = self._coerce_text(qa.get("topic")) or "grounded"
            self._current_difficulty = self._coerce_float(qa.get("difficulty"), 0.5)
        elif self._math_env is not None:
            try:
                instruction, topic, difficulty = self._math_env.sample_instruction()
                self._current_topic = topic
                self._current_difficulty = float(difficulty)
                if self._math_env.grounded_qa_pairs:
                    # NOTE(review): topic/difficulty come from the curriculum
                    # sample, but the question is drawn uniformly from the
                    # grounded pool, so the two need not correspond — confirm
                    # this is intended.
                    _qa = random.choice(self._math_env.grounded_qa_pairs)
                    self._current_question = _qa["question"]
                    self._gold_final = _qa["gold_final"]
                else:
                    # No grounded pool: serve the curriculum instruction with
                    # no gold answer.
                    self._current_question = instruction
                    self._gold_final = ""
            except Exception as exc:
                logger.warning("sample_instruction failed, using fallback: %s", exc)
                self._load_validation_episode()
        else:
            self._load_validation_episode()

        return AxiomforgeaiObservation(
            question=self._current_question,
            topic=self._current_topic,
            difficulty=self._current_difficulty,
            feedback="",
            done=False,
            reward=0.0,
        )

    def step(self, action: AxiomforgeaiAction) -> AxiomforgeaiObservation:  # type: ignore[override]
        """
        Score the agent's submitted solution.

        Uses ``compute_grounded_reward`` from CurriculumMathEnvironment when
        available (PRM + SymPy + format scoring).  If the curriculum
        environment is missing, no question is active, or the scoring call
        raises, falls back to final-answer extraction compared numerically
        against the gold answer.

        Args:
            action: AxiomforgeaiAction containing the solution text.

        Returns:
            AxiomforgeaiObservation with reward, feedback, and metadata.
            done=True — one question per episode.
        """
        self._state.step_count += 1
        solution = action.solution

        reward: float = 0.0
        feedback: str = ""
        metadata: Dict[str, Any] = {}

        if self._math_env is not None and self._current_question:
            try:
                reward_result = self._math_env.compute_grounded_reward(
                    question=self._current_question,
                    solution=solution,
                    gold_final=self._gold_final,
                )
                reward = float(reward_result.get("combined_score", 0.0))
                gt = reward_result.get("gt_match", False)
                # Display-only values: a None here must not discard the reward
                # already computed above, so coerce before formatting.
                step_acc = self._coerce_float(reward_result.get("step_accuracy"), 0.0)
                lccp = self._coerce_float(reward_result.get("lccp"), 0.0)
                pred = reward_result.get("pred_final", "")
                feedback = (
                    f"gt_match={gt} pred={pred!r} gold={self._gold_final!r} "
                    f"step_acc={step_acc:.2f} lccp={lccp:.2f}"
                )
                # Copy the reward breakdown into metadata, dropping list-valued
                # entries.  NOTE(review): presumably those hold per-step data
                # that is bulky or not JSON-serialisable — confirm.
                metadata = {
                    k: v
                    for k, v in reward_result.items()
                    if not isinstance(v, list)
                }
            except Exception as exc:
                logger.warning("compute_grounded_reward failed: %s", exc)
                reward, feedback, metadata = self._fallback_score(solution)
        else:
            reward, feedback, metadata = self._fallback_score(solution)

        return AxiomforgeaiObservation(
            question=self._current_question,
            topic=self._current_topic,
            difficulty=self._current_difficulty,
            feedback=feedback,
            done=True,
            reward=reward,
            metadata=metadata,
        )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _load_validation_episode(self) -> None:
        """Populate the episode with the fixed validation question and answer."""
        self._current_question = _VALIDATION_QUESTION
        self._gold_final = _VALIDATION_GOLD
        self._current_topic = _VALIDATION_TOPIC
        self._current_difficulty = _VALIDATION_DIFFICULTY

    @staticmethod
    def _coerce_text(value: Any) -> str:
        """Return ``str(value).strip()``, or ``""`` when *value* is None."""
        return "" if value is None else str(value).strip()

    @staticmethod
    def _coerce_float(value: Any, default: float) -> float:
        """Return ``float(value)``, or *default* when *value* is None."""
        return default if value is None else float(value)

    @classmethod
    def _extract_last_number(cls, solution: str) -> str:
        """Return the last number in *solution* with commas removed, or ``""``.

        Heuristic stand-in for the project extractor: fractions, units and
        symbolic answers are not understood.
        """
        matches = cls._NUMBER_RE.findall(solution or "")
        return matches[-1].replace(",", "") if matches else ""

    @staticmethod
    def _answers_match(pred: str, gold: str) -> bool:
        """Return True when *pred* equals *gold* exactly or numerically.

        Numeric comparison ignores thousands separators and ``$`` so that
        e.g. ``"17.0"`` / ``"$17"`` match ``"17"``.  Empty strings never match.
        """
        if not pred or not gold:
            return False
        if pred == gold:
            return True
        try:
            pred_num = float(str(pred).replace(",", "").replace("$", "").strip())
            gold_num = float(str(gold).replace(",", "").replace("$", "").strip())
        except ValueError:
            return False
        return math.isclose(pred_num, gold_num, rel_tol=1e-9, abs_tol=1e-9)

    def _fallback_score(
        self, solution: str
    ) -> tuple[float, str, Dict[str, Any]]:
        """
        Lightweight scoring used when the full reward pipeline is unavailable.

        Uses the project's ``extract_final_answer_numeric_str`` when it was
        importable; otherwise (RL stack missing) uses a stdlib last-number
        heuristic so validation mode can still reward a correct answer.

        Returns:
            ``(reward, feedback, metadata)`` where reward is 1.0 on a match
            with the gold answer and 0.0 otherwise.
        """
        if extract_final_answer_numeric_str is not None:
            pred: str = extract_final_answer_numeric_str(solution) or ""
        else:
            pred = self._extract_last_number(solution)
        reward = 1.0 if self._answers_match(pred, self._gold_final) else 0.0
        feedback = f"pred={pred!r} gold={self._gold_final!r}"
        return reward, feedback, {"pred_final": pred, "gold_final": self._gold_final}

    def close(self) -> None:
        """
        Persist curriculum state and release resources.

        Call once at the end of a training run so the CurriculumManager's
        per-topic statistics are saved to disk and can be resumed on the
        next run.  Safe to call multiple times (each call saves again);
        save failures are logged, not raised.
        """
        if self._math_env is not None:
            try:
                self._math_env.curriculum_manager.save_state(
                    iteration=self._math_env.curriculum_manager.current_iteration,
                    rollout=None,
                )
                logger.info(
                    "Curriculum state saved (iteration %d).",
                    self._math_env.curriculum_manager.current_iteration,
                )
            except Exception as exc:
                logger.warning("close(): curriculum save failed — %s", exc)

    @property
    def state(self) -> State:
        """Return the current episode state (episode_id + step_count)."""
        return self._state