# rhythm_env/server/rubrics.py
# Commit f0ca22d: Refactor grader to use openenv.core.rubrics.WeightedSum
# + Rubric subclasses
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Composable Rubric implementation of the RhythmEnv episode grader.
Mirrors the original `_grade_episode` in `rhythm_environment.py` but built
on top of `openenv.core.rubrics.Rubric` + `WeightedSum` β€” the framework's
official scoring composition primitives. Each Rubric subclass wraps one
of the 6 grader components; `make_rubric(env)` composes them with their
weights.
The `forward(action, observation)` signature is required by the Rubric
ABC. Because RhythmEnv grades at episode end (after `done=True`) using
aggregated env state β€” not per-(action, observation) data β€” these
subclasses ignore the per-step args and read directly from the env they
were constructed with. This is the recommended pattern from RFC 004 for
trajectory-summary scoring.
Used by `RhythmEnvironment._grade_episode`. The original numerical
implementation is preserved in the legacy code path; this file is the
primary, conformant implementation.
"""
from __future__ import annotations
from typing import Any, TYPE_CHECKING
from openenv.core.rubrics import Rubric, WeightedSum
if TYPE_CHECKING:
from server.rhythm_environment import RhythmEnvironment
# ---------------------------------------------------------------------------
# Component rubrics — one per scored axis of the final grade.
# ---------------------------------------------------------------------------
class CrashFreeRubric(Rubric):
    """Reward for keeping every meter above the crash threshold.

    Computes 1 - (crashes / total possible per-meter step drops); a
    flawless episode (no meter ever below the crash threshold) scores 1.0.
    """

    def __init__(self, env: "RhythmEnvironment") -> None:
        super().__init__()
        self._env = env

    def forward(self, action: Any, observation: Any) -> float:
        # Imported locally rather than at module scope to avoid a cycle.
        from server.rhythm_environment import METERS

        elapsed = max(self._env._timestep, 1)
        crash_fraction = self._env._crash_count / (elapsed * len(METERS))
        return 1.0 - crash_fraction
class ProgressRubric(Rubric):
    """Career/skill growth axis: the progress meter's final value."""

    def __init__(self, env: "RhythmEnvironment") -> None:
        super().__init__()
        self._env = env

    def forward(self, action: Any, observation: Any) -> float:
        # Per-step args are ignored; grading reads aggregated env state.
        return float(self._env._progress)
class ConnectionRubric(Rubric):
    """Relationship-maintenance axis: the connection meter's final value."""

    def __init__(self, env: "RhythmEnvironment") -> None:
        super().__init__()
        self._env = env

    def forward(self, action: Any, observation: Any) -> float:
        # Per-step args are ignored; grading reads aggregated env state.
        return float(self._env._connection)
class AdaptationRubric(Rubric):
    """Implicit meta-learning axis: late-half mean reward minus early-half.

    Scaled to [0, 1]. Per-step rewards are profile-weighted, so a positive
    gain means the agent is exploiting profile-aware play it wasn't using
    early on. The score is gated by late-half quality so a "terrible then
    mediocre" trajectory cannot score highly.
    """

    def __init__(self, env: "RhythmEnvironment") -> None:
        super().__init__()
        self._env = env

    def forward(self, action: Any, observation: Any) -> float:
        elapsed = max(self._env._timestep, 1)
        midpoint = max(elapsed // 2, 1)
        rewards = self._env._step_rewards
        early, late = rewards[:midpoint], rewards[midpoint:]
        if not early or not late:
            return 0.0
        mean_early = sum(early) / len(early)
        mean_late = sum(late) / len(late)
        # Per-step rewards are clamped to [-3, +3] in step(), so the late
        # mean is normalized with that full range; otherwise the quality
        # gate would saturate at 1.0 for any mean_late >= 1 and the grader
        # couldn't distinguish good from excellent late-half play.
        late_quality = min(1.0, max(0.0, (mean_late + 3.0) / 6.0))
        # Gain lies in [-6, +6]; only positive gain counts, and a gain of
        # +3 or more earns the full score of 1.0.
        gain_norm = min(1.0, max(0.0, (mean_late - mean_early) / 3.0))
        return gain_norm * late_quality
class EfficiencyRubric(Rubric):
    """Average per-step reward across the episode, normalized to [0, 1]."""

    def __init__(self, env: "RhythmEnvironment") -> None:
        super().__init__()
        self._env = env

    def forward(self, action: Any, observation: Any) -> float:
        elapsed = max(self._env._timestep, 1)
        mean_reward = self._env._total_reward / elapsed
        # Map a [-1, +1] average reward onto [0, 1], clamping outliers.
        return min(1.0, max(0.0, (mean_reward + 1.0) / 2.0))
class BeliefAccuracyRubric(Rubric):
    """Explicit meta-RL inference axis.

    Score = max(0, 1 - MAE) between the agent's last-emitted belief and
    the true profile vector. Agents that never emit a belief (heuristic or
    random baselines) score 0 by design: only agents that actually attempt
    inference earn credit on this axis.
    """

    def __init__(self, env: "RhythmEnvironment") -> None:
        super().__init__()
        self._env = env

    def forward(self, action: Any, observation: Any) -> float:
        # Imported locally rather than at module scope to avoid a cycle.
        from server.rhythm_environment import profile_to_belief_vector

        emitted = self._env._final_belief
        if emitted is None:
            return 0.0
        true_belief = profile_to_belief_vector(self._env._profile)
        # NOTE(review): the divisor assumes a 3-component belief vector —
        # confirm against profile_to_belief_vector if that ever changes.
        abs_errors = (abs(b - t) for b, t in zip(emitted, true_belief))
        mae = sum(abs_errors) / 3.0
        return max(0.0, 1.0 - mae)
# ---------------------------------------------------------------------------
# Composition
# ---------------------------------------------------------------------------
# Weights matching the original _grade_episode formula; they sum to 1.0.
# Adaptation and belief accuracy dominate because this env grades
# meta-learning; keep any change here in sync with the legacy grader.
GRADE_WEIGHTS = {
    "crash_free": 0.15,
    "progress": 0.20,
    "connection": 0.10,
    "adaptation": 0.25,
    "efficiency": 0.10,
    "belief_accuracy": 0.20,
}
def make_grade_rubric(env: "RhythmEnvironment") -> WeightedSum:
"""Build the composed `WeightedSum` rubric for grading episodes.
Returns a single `Rubric` whose `forward(None, None)` reads the env's
aggregated state and returns the same final_score the original
`_grade_episode` would have computed.
"""
return WeightedSum(
rubrics=[
CrashFreeRubric(env),
ProgressRubric(env),
ConnectionRubric(env),
AdaptationRubric(env),
EfficiencyRubric(env),
BeliefAccuracyRubric(env),
],
weights=[
GRADE_WEIGHTS["crash_free"],
GRADE_WEIGHTS["progress"],
GRADE_WEIGHTS["connection"],
GRADE_WEIGHTS["adaptation"],
GRADE_WEIGHTS["efficiency"],
GRADE_WEIGHTS["belief_accuracy"],
],
)