# physix-live — physix/models.py
# (commit 605f6fe: "cleanup: strip verbose comments from physix/models.py")
"""Pydantic schemas + constants. Behaviour lives elsewhere."""
from __future__ import annotations
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from openenv.core.env_server import Action, Observation, State
#: Hard cap on turns per episode; an episode can also end early once
#: r_match exceeds :data:`CONVERGENCE_THRESHOLD`.
DEFAULT_MAX_TURNS: int = 8
#: r_match value at or above which the episode is considered solved.
CONVERGENCE_THRESHOLD: float = 0.93
#: Weights for the four reward components (they sum to 1.0).
#: Ablation experiments tweak only this mapping.
REWARD_WEIGHTS: dict[str, float] = dict(
    match=0.5,
    progress=0.2,
    simplicity=0.2,
    format=0.1,
)
class PhysiXAction(Action):
    """One agent step.
    Fields have defaults so tests can construct partial actions, and so
    the env can fabricate a ``format=0`` no-op action for completions
    that fail to parse JSON. The LLM is expected to fill all three;
    an empty string / dict is fine when irrelevant.
    """
    # Candidate governing equation, expressed in the verifier's DSL.
    equation: str = Field(default="", description="ODE in the verifier DSL")
    # Concrete numeric values for free symbols appearing in ``equation``.
    params: dict[str, float] = Field(
        default_factory=dict,
        description="Numerical substitutions for free symbols on the RHS",
    )
    # Free-text justification from the agent; presumably informational
    # only (no description/weight visible here — TODO confirm with env).
    rationale: str = Field(default="")
class PhysiXObservation(Observation):
    """What the agent sees per step. Inherits ``done`` / ``reward`` from
    :class:`openenv.core.env_server.Observation`."""
    # Each element is one timestep; keys are "t" plus the state variables.
    trajectory: list[dict[str, float]] = Field(
        default_factory=list,
        description="Observed (noisy) trajectory as list of timestep dicts",
    )
    state_variables: list[str] = Field(
        default_factory=list,
        description="Names of state-variable keys present in trajectory[i] (excluding t)",
    )
    hint: str = Field(default="", description="One-sentence physical-context string")
    # Entries presumably follow the HistoryEntry.as_dict() shape — TODO confirm
    # against the env that populates this.
    history: list[dict[str, Any]] = Field(
        default_factory=list,
        description="Prior turns surfaced back so the agent can refine",
    )
    mismatch_summary: str = Field(
        default="",
        description="English description of where last prediction diverged",
    )
    turn: int = Field(default=0, ge=0, description="0-indexed turn counter")
    turn_remaining: int = Field(
        default=DEFAULT_MAX_TURNS,
        ge=0,
        description="Turns left in the episode budget",
    )
    system_id: str = Field(default="", description="Stable id of underlying system")
    stats: dict[str, float] = Field(
        default_factory=dict, description="Aggregate trajectory statistics"
    )
    reward_breakdown: dict[str, float] = Field(
        default_factory=dict,
        description="Four reward components from the previous step",
    )
class PhysiXState(State):
    """Episode-level state. The ground-truth equation lives here for logging;
    it is *never* surfaced to the agent. ``last_reward_total`` feeds the
    per-turn ``progress`` reward delta."""
    # Stable id of the underlying system (same notion as the observation field).
    system_id: str = Field(default="")
    # Hidden answer: true equation and parameter values, for logging only.
    ground_truth_equation: str = Field(default="")
    ground_truth_params: dict[str, float] = Field(default_factory=dict)
    # Previous step's totals, kept so the next step can compute deltas.
    last_reward_total: float = Field(default=0.0)
    last_r_match: float = Field(default=0.0)
    # True once r_match has crossed the convergence threshold.
    converged: bool = Field(default=False)
    # Per-episode turn budget; at least one turn is always allowed.
    max_turns: int = Field(default=DEFAULT_MAX_TURNS, ge=1)
class HistoryEntry(BaseModel):
    """One previous turn, surfaced back to the agent on the next step."""
    # Reject unknown keys so stale/extra fields fail loudly at parse time.
    model_config = ConfigDict(extra="forbid")
    turn: int
    equation: str
    params: dict[str, float]
    reward_total: float
    reward_components: dict[str, float]
    mismatch_summary: str
    def as_dict(self) -> dict[str, Any]:
        """Serialize this entry to a plain dict.

        Reward numbers are rounded to 4 decimal places; all other fields
        pass through unchanged.
        """
        return {
            "turn": self.turn,
            "equation": self.equation,
            "params": self.params,
            "reward_total": round(self.reward_total, 4),
            "reward_components": {
                k: round(v, 4) for k, v in self.reward_components.items()
            },
            "mismatch_summary": self.mismatch_summary,
        }
class RewardBreakdown(BaseModel):
    """Per-step reward components. shape/freq/amplitude are diagnostic only (not used in training)."""
    model_config = ConfigDict(extra="forbid")
    match: float = 0.0
    progress: float = 0.0
    simplicity: float = 0.0
    format: float = 0.0
    total: float = 0.0
    # Diagnostic-only. Not part of the reward total. See class docstring.
    shape: float = 0.0
    freq: float = 0.0
    amplitude: float = 0.0
    def as_dict(self) -> dict[str, float]:
        """Flatten all components into a plain dict, in declaration order."""
        component_names = (
            "match",
            "progress",
            "simplicity",
            "format",
            "total",
            "shape",
            "freq",
            "amplitude",
        )
        return {name: getattr(self, name) for name in component_names}