"""OpenEnv :class:`Environment` subclass for PhysiX-Live. Owns one episode's lifecycle (state + budget + termination) and orchestrates the parser/simulator/metrics/reward modules. No scoring logic lives here. """ from __future__ import annotations import logging import uuid from typing import Any, Optional import numpy as np from openenv.core.env_server import Environment from physix.models import ( CONVERGENCE_THRESHOLD, DEFAULT_MAX_TURNS, HistoryEntry, PhysiXAction, PhysiXObservation, PhysiXState, RewardBreakdown, ) from physix.systems import PhysicalSystem, SystemTier, get_system, list_systems_by_tier from physix.systems.base import TrajectoryData from physix.verifier import ( ParseError, SimulationError, compute_amplitude_match, compute_frequency_match, compute_match, compute_reward, compute_shape_match, parse_equation, residual_summary, simulate_hypothesis, summarize_mismatch, ) _log = logging.getLogger(__name__) class PhysiXEnvironment(Environment[PhysiXAction, PhysiXObservation, PhysiXState]): """OpenEnv environment that drives one episode of equation discovery.""" def get_metadata(self): # type: ignore[override] from openenv.core.env_server.types import EnvironmentMetadata return EnvironmentMetadata( name="PhysiX-Live", description=( "Equation-discovery RL environment. " "The agent is shown a noisy trajectory from an unknown physical system " "(damped spring, free fall, pendulum …) and must output an ODE whose " "numerical simulation matches the trajectory. " "Reward = R² trajectory match + simplicity + format validity. " "Trained with GRPO on Qwen2.5-3B-Instruct." ), version="1.0.0", author="Pratyush-01", documentation_url="https://huggingface.co/spaces/Pratyush-01/physix-live", ) def __init__( self, *, max_turns: int = DEFAULT_MAX_TURNS, train_tiers: tuple[SystemTier, ...] = (SystemTier.TIER_1, SystemTier.TIER_2), seed: Optional[int] = None, ) -> None: super().__init__() self._max_turns = max_turns self._train_tiers = train_tiers self._rng = np.random.default_rng(seed) self._state = PhysiXState(max_turns=max_turns) self._system: Optional[PhysicalSystem] = None self._trajectory: Optional[TrajectoryData] = None self._history: list[HistoryEntry] = [] def reset( self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs: Any, ) -> PhysiXObservation: """Start a new episode. Pass ``system_id=`` to force a specific system.""" if seed is not None: self._rng = np.random.default_rng(seed) forced_id = kwargs.get("system_id") chosen_id = forced_id or self._sample_training_system_id() self._system = get_system(chosen_id) self._trajectory = self._system.simulate(self._rng) self._history = [] self._state = PhysiXState( episode_id=episode_id or str(uuid.uuid4()), step_count=0, system_id=chosen_id, ground_truth_equation=self._system.ground_truth_equation(), ground_truth_params=dict(self._system.parameters), last_reward_total=0.0, converged=False, max_turns=self._max_turns, ) return self._build_observation( mismatch_summary="", reward_breakdown=RewardBreakdown(), ) def step( self, action: PhysiXAction, timeout_s: Optional[float] = None, **kwargs: Any, ) -> PhysiXObservation: del timeout_s, kwargs # accepted for OpenEnv API conformance, unused if self._system is None or self._trajectory is None: raise RuntimeError("step() called before reset(); call reset() first.") self._state.step_count = self._state.step_count + 1 breakdown, mismatch_text = self._score_hypothesis(action) self._record_history(action, breakdown, mismatch_text) self._state.last_reward_total = breakdown.total self._state.last_r_match = breakdown.match if breakdown.match >= CONVERGENCE_THRESHOLD: self._state.converged = True return self._build_observation( mismatch_summary=mismatch_text, reward_breakdown=breakdown, ) @property def state(self) -> PhysiXState: return self._state def current_observation(self) -> Optional[PhysiXObservation]: """Re-render the observation an external driver should feed to the agent for the *next* turn (i.e. before calling :meth:`step`). Used by the interactive HTTP router to build prompts mid-session. Returns ``None`` before :meth:`reset` has been called. """ if self._system is None or self._trajectory is None: return None last = self._history[-1] if self._history else None breakdown = ( RewardBreakdown(**last.reward_components) if last is not None else RewardBreakdown() ) mismatch = last.mismatch_summary if last is not None else "" return self._build_observation( mismatch_summary=mismatch, reward_breakdown=breakdown, ) @property def current_trajectory(self) -> Optional[TrajectoryData]: return self._trajectory @property def current_system(self) -> Optional[PhysicalSystem]: return self._system def _is_done(self) -> bool: if self._state.converged: return True return self._state.step_count >= self._max_turns def _score_hypothesis( self, action: PhysiXAction, ) -> tuple[RewardBreakdown, str]: assert self._system is not None assert self._trajectory is not None parameter_names = frozenset(action.params or {}) try: parsed = parse_equation( action.equation, state_variables=self._system.state_variables, parameter_names=parameter_names, ) except ParseError as exc: _log.debug("parse_equation failed: %s", exc) breakdown = compute_reward( parse_succeeded=False, r_match=0.0, operator_count=0, previous_r_match=self._state.last_r_match, ) return breakdown, f"Parse error: {exc}" try: predicted = simulate_hypothesis( parsed, state_variables=self._system.state_variables, parameters=dict(action.params or {}), initial_conditions=self._trajectory.initial_conditions, timestamps=self._trajectory.timestamps, ) except SimulationError as exc: _log.debug("simulate_hypothesis failed: %s", exc) breakdown = compute_reward( parse_succeeded=True, simulation_succeeded=False, r_match=0.0, operator_count=parsed.operator_count, previous_r_match=self._state.last_r_match, ) return breakdown, f"Simulation error: {exc}" r_match = compute_match( observed=self._trajectory.states, predicted=predicted, state_variables=self._system.state_variables, ) # Diagnostic-only — see RewardBreakdown class docstring. These do # not flow into ``total`` or the trainer; they exist to give the # demo UI a non-collapsing signal of "visual closeness" and to # power the mismatch summariser's frequency-mismatch hint. r_shape = compute_shape_match( observed=self._trajectory.states, predicted=predicted, state_variables=self._system.state_variables, ) r_amplitude = compute_amplitude_match( observed=self._trajectory.states, predicted=predicted, state_variables=self._system.state_variables, ) r_frequency = compute_frequency_match( observed=self._trajectory.states, predicted=predicted, state_variables=self._system.state_variables, ) residuals = residual_summary( timestamps=self._trajectory.timestamps, observed=self._trajectory.states, predicted=predicted, state_variables=self._system.state_variables, ) mismatch_text = summarize_mismatch( observed=self._trajectory.states, predicted=predicted, state_variables=self._system.state_variables, timestamps=self._trajectory.timestamps, summary=residuals, shape_score=r_shape, frequency_score=r_frequency, amplitude_score=r_amplitude, ) breakdown = compute_reward( parse_succeeded=True, simulation_succeeded=True, r_match=r_match, operator_count=parsed.operator_count, previous_r_match=self._state.last_r_match, ) # Attach diagnostic-only sub-scores. These intentionally bypass # ``compute_reward`` so the reward arithmetic is byte-identical # to the trained checkpoint's calibration. breakdown = breakdown.model_copy( update={ "shape": r_shape, "freq": r_frequency, "amplitude": r_amplitude, } ) return breakdown, mismatch_text def _record_history( self, action: PhysiXAction, breakdown: RewardBreakdown, mismatch_text: str, ) -> None: entry = HistoryEntry( turn=self._state.step_count, equation=action.equation, params=dict(action.params or {}), reward_total=breakdown.total, reward_components=breakdown.as_dict(), mismatch_summary=mismatch_text, ) self._history.append(entry) def _build_observation( self, *, mismatch_summary: str, reward_breakdown: RewardBreakdown, ) -> PhysiXObservation: assert self._system is not None assert self._trajectory is not None return PhysiXObservation( done=self._is_done(), reward=reward_breakdown.total, trajectory=self._trajectory.to_observation_samples(), state_variables=list(self._system.state_variables), hint=self._system.hint(self._state.ground_truth_params), history=[entry.as_dict() for entry in self._history], mismatch_summary=mismatch_summary, turn=self._state.step_count, turn_remaining=max(0, self._max_turns - self._state.step_count), system_id=self._state.system_id, stats=self._trajectory.stats(), reward_breakdown=reward_breakdown.as_dict(), ) def _sample_training_system_id(self) -> str: candidates: list[str] = [] for tier in self._train_tiers: candidates.extend(list_systems_by_tier(tier)) if not candidates: raise RuntimeError( f"No training systems found for tiers {self._train_tiers!r}." ) idx = int(self._rng.integers(0, len(candidates))) return candidates[idx]