# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Tests for the RhythmEnv Life Simulator."""
import pytest

from models import ActionType, RhythmAction, RhythmObservation
from server.rhythm_environment import (
    MAX_STEPS,
    METERS,
    PROFILES,
    RhythmEnvironment,
)


@pytest.fixture
def env():
    return RhythmEnvironment()


def make_action(action_type: ActionType) -> RhythmAction:
    return RhythmAction(action_type=action_type)


# ---------------------------------------------------------------------------
# TestReset
# ---------------------------------------------------------------------------
class TestReset:
    def test_returns_valid_observation(self, env):
        obs = env.reset(seed=0)
        assert isinstance(obs, RhythmObservation)
        assert obs.timestep == 0
        assert obs.day == 0
        assert obs.slot == 0
        assert obs.remaining_steps == MAX_STEPS
        assert obs.done is False
        assert obs.reward == 0.0

    def test_meters_initialized(self, env):
        obs = env.reset(seed=0)
        assert 0.0 <= obs.vitality <= 1.0
        assert 0.0 <= obs.cognition <= 1.0
        assert obs.progress == 0.0
        assert 0.0 <= obs.serenity <= 1.0
        assert 0.0 <= obs.connection <= 1.0

    def test_seed_selects_profile(self, env):
        """Different seeds select different profiles."""
        profiles_seen = set()
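        # Without an explicit profile, reset samples a continuous profile per
        # seed (names like "sampled_0"; see TestEdgeCases below). Three seeds
        # are assumed to yield three distinct samples.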
        for seed in range(3):
            env.reset(seed=seed)
            profiles_seen.add(env.state.profile_name)
        assert len(profiles_seen) == 3

    def test_deterministic_with_same_seed(self, env):
        obs1 = env.reset(seed=42)
        obs2 = env.reset(seed=42)
        assert obs1.vitality == obs2.vitality
        assert obs1.cognition == obs2.cognition
        assert obs1.serenity == obs2.serenity
        assert obs1.connection == obs2.connection

    def test_explicit_profile_selection(self, env):
        env.reset(seed=0, profile="workaholic_stoic")
        assert env.state.profile_name == "workaholic_stoic"

    def test_reset_clears_state(self, env):
        env.reset(seed=0)
        for _ in range(5):
            env.step(make_action(ActionType.DEEP_WORK))
        obs = env.reset(seed=0)
        assert obs.timestep == 0
        assert obs.progress == 0.0


# ---------------------------------------------------------------------------
# TestStep
# ---------------------------------------------------------------------------
class TestStep:
    def test_timestep_advances(self, env):
        env.reset(seed=0)
        obs = env.step(make_action(ActionType.DEEP_WORK))
        assert obs.timestep == 1

    def test_day_and_slot_correct(self, env):
        env.reset(seed=0)
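        # Assuming four slots per day (morning/afternoon/evening/night, as the
        # slot comments later in this file suggest), five steps should land on
        # day 1, slot 1: 5 // 4 == 1 and 5 % 4 == 1.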
        for _ in range(5):
            obs = env.step(make_action(ActionType.SLEEP))
        assert obs.day == 1
        assert obs.slot == 1

    def test_deep_work_increases_progress(self, env):
        env.reset(seed=0)
        obs = env.step(make_action(ActionType.DEEP_WORK))
        assert obs.progress > 0.0

    def test_deep_work_drains_vitality(self, env):
        env.reset(seed=0)
        initial_vitality = env.state.vitality
        obs = env.step(make_action(ActionType.DEEP_WORK))
        assert obs.vitality < initial_vitality

    def test_sleep_recovers_vitality(self, env):
        env.reset(seed=0)
        for _ in range(3):
            env.step(make_action(ActionType.DEEP_WORK))
        vitality_before_sleep = env.state.vitality
        obs = env.step(make_action(ActionType.SLEEP))
        assert obs.vitality > vitality_before_sleep
    def test_family_time_builds_connection(self, env):
        env.reset(seed=0)
        initial_connection = env.state.connection
        obs = env.step(make_action(ActionType.FAMILY_TIME))
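        # FAMILY_TIME should at least offset any passive connection decay; the
        # 0.05 slack is assumed to absorb one step of that decay.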
        assert obs.connection > initial_connection - 0.05

    def test_episode_ends_at_max_steps(self, env):
        env.reset(seed=0)
        for _ in range(MAX_STEPS):
            obs = env.step(make_action(ActionType.SLEEP))
        assert obs.done is True
        assert obs.timestep == MAX_STEPS

    def test_not_done_before_max_steps(self, env):
        env.reset(seed=0)
        for _ in range(MAX_STEPS - 1):
            obs = env.step(make_action(ActionType.SLEEP))
        assert obs.done is False

    def test_meters_stay_in_bounds(self, env):
        """No meter exceeds [0.0, 1.0] regardless of actions."""
        env.reset(seed=0)
        for _ in range(MAX_STEPS):
            obs = env.step(make_action(ActionType.DEEP_WORK))
            for meter in METERS:
                val = getattr(obs, meter)
                assert 0.0 <= val <= 1.0, f"{meter}={val} out of bounds"

    def test_low_vitality_reduces_effectiveness(self, env):
        """Progress gain should be lower when vitality is low."""
        env.reset(seed=0, profile="introvert_morning")
        obs_high = env.step(make_action(ActionType.DEEP_WORK))
        progress_high = obs_high.progress

        env.reset(seed=0, profile="introvert_morning")
        for _ in range(6):
            env.step(make_action(ActionType.DEEP_WORK))
        progress_before = env.state.progress
        env.step(make_action(ActionType.DEEP_WORK))
        progress_gained_low = env.state.progress - progress_before

        assert progress_high > progress_gained_low
# ---------------------------------------------------------------------------
# TestProfiles
# ---------------------------------------------------------------------------
class TestProfiles:
    def test_introvert_social_drains_more(self, env):
        """Introvert loses more vitality from socializing than extrovert."""
        env.reset(seed=0, profile="introvert_morning")
        v_before_intro = env.state.vitality
        env.step(make_action(ActionType.SOCIALIZE))
        intro_drain = v_before_intro - env.state.vitality

        env2 = RhythmEnvironment()
        env2.reset(seed=0, profile="extrovert_night_owl")
        v_before_extro = env2.state.vitality
        env2.step(make_action(ActionType.SOCIALIZE))
        extro_drain = v_before_extro - env2.state.vitality

        assert intro_drain > extro_drain

    def test_workaholic_progress_gives_serenity(self, env):
        """Workaholic has better serenity outcome from deep work than introvert."""
        env.reset(seed=0, profile="workaholic_stoic")
        serenity_before = env.state.serenity
        env.step(make_action(ActionType.DEEP_WORK))
        workaholic_change = env.state.serenity - serenity_before

        env2 = RhythmEnvironment()
        env2.reset(seed=0, profile="introvert_morning")
        serenity_before_intro = env2.state.serenity
        env2.step(make_action(ActionType.DEEP_WORK))
        introvert_change = env2.state.serenity - serenity_before_intro

        assert workaholic_change > introvert_change

    def test_binge_shame_introvert(self, env):
        """Introvert suffers extra serenity loss from binge watching."""
        env.reset(seed=0, profile="introvert_morning")
        serenity_before = env.state.serenity
        env.step(make_action(ActionType.BINGE_WATCH))
        intro_change = env.state.serenity - serenity_before

        env2 = RhythmEnvironment()
        env2.reset(seed=0, profile="extrovert_night_owl")
        serenity_before_ext = env2.state.serenity
        env2.step(make_action(ActionType.BINGE_WATCH))
        ext_change = env2.state.serenity - serenity_before_ext

        assert intro_change < ext_change

    def test_different_rewards_same_action(self, env):
        """Same action produces different rewards for different profiles."""
        rewards = {}
        for profile_name in ["introvert_morning", "extrovert_night_owl", "workaholic_stoic"]:
            e = RhythmEnvironment()
            e.reset(seed=0, profile=profile_name)
            obs = e.step(make_action(ActionType.DEEP_WORK))
            rewards[profile_name] = obs.reward
        values = list(rewards.values())
        assert len({round(v, 3) for v in values}) > 1

    def test_extrovert_night_cognition_bonus(self, env):
        """Extrovert gets better cognition gains in evening vs morning."""
        env.reset(seed=0, profile="extrovert_night_owl")
        env.step(make_action(ActionType.SLEEP))  # morning
        env.step(make_action(ActionType.SLEEP))  # afternoon
        cognition_before = env.state.cognition
        env.step(make_action(ActionType.MEDITATE))  # evening
        evening_gain = env.state.cognition - cognition_before

        env.reset(seed=0, profile="extrovert_night_owl")
        cognition_before_m = env.state.cognition
        env.step(make_action(ActionType.MEDITATE))  # morning
        morning_gain = env.state.cognition - cognition_before_m

        assert evening_gain > morning_gain


# ---------------------------------------------------------------------------
# TestEvents
# ---------------------------------------------------------------------------
class TestEvents:
    def test_events_deterministic_with_seed(self, env):
        """Same seed produces same event sequence."""
        events1 = []
        env.reset(seed=99)
        for _ in range(MAX_STEPS):
            obs = env.step(make_action(ActionType.SLEEP))
            events1.append(obs.active_event)

        events2 = []
        env.reset(seed=99)
        for _ in range(MAX_STEPS):
            obs = env.step(make_action(ActionType.SLEEP))
            events2.append(obs.active_event)

        assert events1 == events2

    def test_event_visible_in_observation(self, env):
        """When an event fires, active_event is set in observation."""
        found_event = False
        for seed in range(100):
            env.reset(seed=seed)
            for _ in range(MAX_STEPS):
                obs = env.step(make_action(ActionType.SLEEP))
                if obs.active_event is not None:
                    found_event = True
                    assert obs.active_event in [
                        "prod_crash", "family_emergency", "illness", "good_news"
                    ]
                    break
            if found_event:
                break
        assert found_event, "No events triggered in 100 episodes"

    def test_no_event_when_none(self, env):
        """Most steps should have no event."""
        env.reset(seed=0)
        no_event_count = 0
        for _ in range(MAX_STEPS):
            obs = env.step(make_action(ActionType.SLEEP))
            if obs.active_event is None:
                no_event_count += 1
        assert no_event_count > MAX_STEPS * 0.7


# ---------------------------------------------------------------------------
# TestGrader
# ---------------------------------------------------------------------------
class TestGrader:
    def test_final_score_in_range(self, env):
        env.reset(seed=0)
        for _ in range(MAX_STEPS):
            obs = env.step(make_action(ActionType.SLEEP))
        assert "final_score" in obs.reward_breakdown
        score = obs.reward_breakdown["final_score"]
        assert 0.0 <= score <= 1.0

    def test_balanced_play_beats_all_sleep(self, env):
        """A balanced strategy should score higher than just sleeping."""
        env.reset(seed=0)
        for _ in range(MAX_STEPS):
            obs_sleep = env.step(make_action(ActionType.SLEEP))
        score_sleep = obs_sleep.reward_breakdown["final_score"]

        balanced_actions = [
            ActionType.DEEP_WORK, ActionType.LEARN,
            ActionType.EXERCISE, ActionType.FAMILY_TIME,
        ] * 7
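        # 4 actions repeated 7 times = 28 steps, which is assumed to equal
        # MAX_STEPS (7 days x 4 slots), so the loop below plays a full episode.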
        env.reset(seed=0)
        for action_type in balanced_actions:
            obs_balanced = env.step(make_action(action_type))
        score_balanced = obs_balanced.reward_breakdown["final_score"]

        assert score_balanced > score_sleep

    def test_deterministic_grading(self, env):
        """Same actions produce same final score."""
        scores = []
        for _ in range(2):
            env.reset(seed=42)
            for _ in range(MAX_STEPS):
                obs = env.step(make_action(ActionType.DEEP_WORK))
            scores.append(obs.reward_breakdown["final_score"])
        assert scores[0] == scores[1]

    def test_all_binge_scores_low(self, env):
        """Binge watching everything should produce a low score."""
        env.reset(seed=0)
        for _ in range(MAX_STEPS):
            obs = env.step(make_action(ActionType.BINGE_WATCH))
        score = obs.reward_breakdown["final_score"]
        assert score < 0.5


# ---------------------------------------------------------------------------
# TestEdgeCases
# ---------------------------------------------------------------------------
class TestEdgeCases:
    def test_observation_hides_profile(self, env):
        """Observation should not expose profile_name."""
        obs = env.reset(seed=0)
        obs_dict = obs.model_dump()
        assert "profile_name" not in obs_dict

    def test_state_exposes_profile(self, env):
        """State should include profile_name for debugging."""
        # Default: continuous profile (name like 'sampled_0')
        env.reset(seed=0)
        assert env.state.profile_name != ""
        assert env.state.profile_name.startswith("sampled_")

        # Explicit profile: name matches the requested reference profile
        env.reset(seed=0, profile="workaholic_stoic")
        assert env.state.profile_name == "workaholic_stoic"
        assert env.state.profile_name in [p["name"] for p in PROFILES]

    def test_all_action_types_valid(self, env):
        """Every ActionType should be processable without error."""
        # A fresh environment per action type keeps the checks independent.
        for action_type in ActionType:
            e = RhythmEnvironment()
            e.reset(seed=0)
            obs = e.step(make_action(action_type))
            assert isinstance(obs, RhythmObservation)


# ---------------------------------------------------------------------------
# Belief-accuracy grader component
# ---------------------------------------------------------------------------
class TestBeliefAccuracyGrader:
    """The grader awards 0.20 weight to belief_accuracy. Agents that don't
    emit beliefs get 0 on this component; agents whose final belief matches
    the true profile vector get up to 0.20 added to final_score.
    """
    def _run_episode_with_belief(self, seed, belief, profile=None):
        env = RhythmEnvironment()
        if profile:
            obs = env.reset(seed=seed, profile=profile)
        else:
            obs = env.reset(seed=seed)
        for _ in range(MAX_STEPS):
            if obs.done:
                break
            env.record_belief(belief)
            obs = env.step(make_action(ActionType.SLEEP))
        return obs.reward_breakdown.get("final_score", 0.0)

    def _run_episode_no_belief(self, seed, profile=None):
        env = RhythmEnvironment()
        if profile:
            obs = env.reset(seed=seed, profile=profile)
        else:
            obs = env.reset(seed=seed)
        for _ in range(MAX_STEPS):
            if obs.done:
                break
            obs = env.step(make_action(ActionType.SLEEP))
        return obs.reward_breakdown.get("final_score", 0.0)
    def test_no_belief_means_zero_belief_component(self, env):
        """Agent that never calls record_belief gets 0 on the belief component."""
        score = self._run_episode_no_belief(seed=42)
        # Without belief, the maximum possible score is 0.80 (the sum of all
        # weights excluding belief). The realistic ceiling is much lower,
        # since SLEEP-only play does not max out the meters.
        assert score <= 0.80

    def test_perfect_belief_lifts_score(self, env):
        """An agent that emits the TRUE belief vector should score higher
        than the same actions with no belief, by up to +0.20."""
        # Use a known reference profile so we can hand-pick the perfect belief.
        from server.rhythm_environment import (
            PROFILE_MAP,
            profile_to_belief_vector,
        )

        profile_name = "workaholic_stoic"
        true_belief = profile_to_belief_vector(PROFILE_MAP[profile_name])
        no_belief_score = self._run_episode_no_belief(seed=7, profile=profile_name)
        perfect_score = self._run_episode_with_belief(
            seed=7, belief=true_belief, profile=profile_name
        )
        # A perfect belief contributes the full 0.20 to final_score.
        assert perfect_score > no_belief_score
        assert (perfect_score - no_belief_score) == pytest.approx(0.20, abs=0.01)

    def test_wrong_belief_scores_less_than_perfect(self, env):
        """Wrong belief still counts (0 ≤ score ≤ 1) but less than perfect."""
        from server.rhythm_environment import (
            PROFILE_MAP,
            profile_to_belief_vector,
        )

        profile_name = "introvert_morning"
        true_belief = profile_to_belief_vector(PROFILE_MAP[profile_name])
        wrong_belief = [1.0 - b for b in true_belief]  # opposite of the truth
        perfect_score = self._run_episode_with_belief(
            seed=7, belief=true_belief, profile=profile_name
        )
        wrong_score = self._run_episode_with_belief(
            seed=7, belief=wrong_belief, profile=profile_name
        )
        assert perfect_score > wrong_score

    def test_record_belief_validates_length(self, env):
        env.reset(seed=0)
        with pytest.raises(ValueError):
            env.record_belief([0.5, 0.5])  # too short
        with pytest.raises(ValueError):
            env.record_belief([0.5, 0.5, 0.5, 0.5])  # too long

    def test_record_belief_clamps_to_unit_interval(self, env):
        """Beliefs outside [0, 1] should be clamped, not rejected."""
        env.reset(seed=0)
        env.record_belief([-0.5, 1.5, 0.5])
        # Internal state should be clamped
        assert env._final_belief == [0.0, 1.0, 0.5]
    def test_grader_uses_openenv_weighted_sum_rubric(self, env):
        """Grader composes child rubrics via openenv.core.rubrics.WeightedSum."""
        from openenv.core.rubrics import Rubric, WeightedSum
        from server.rubrics import (
            AdaptationRubric,
            BeliefAccuracyRubric,
            ConnectionRubric,
            CrashFreeRubric,
            EfficiencyRubric,
            GRADE_WEIGHTS,
            ProgressRubric,
        )

        # Trigger a full episode so _grade_episode runs and builds the rubric.
        obs = env.reset(seed=0)
        for _ in range(MAX_STEPS):
            if obs.done:
                break
            obs = env.step(make_action(ActionType.SLEEP))

        rubric = env._grade_rubric
        assert isinstance(rubric, WeightedSum), "grader must use WeightedSum"
        assert isinstance(rubric, Rubric)

        # Six children, one per scoring component; compare by class identity
        # rather than by name strings.
        children = list(rubric.children())
        assert len(children) == 6
        assert {type(c) for c in children} == {
            CrashFreeRubric, ProgressRubric, ConnectionRubric,
            AdaptationRubric, EfficiencyRubric, BeliefAccuracyRubric,
        }

        # Weights must sum to 1.0 (WeightedSum enforces; sanity check the keys).
        assert abs(sum(GRADE_WEIGHTS.values()) - 1.0) < 1e-6

    def test_make_grade_rubric_is_pure_function(self, env):
        """make_grade_rubric should produce equivalent rubrics across calls."""
        from server.rubrics import make_grade_rubric

        env.reset(seed=42)
        r1 = make_grade_rubric(env)
        r2 = make_grade_rubric(env)
        # Same shape, fresh object
        assert len(list(r1.children())) == len(list(r2.children())) == 6
        assert r1 is not r2
        # Same weights
        assert r1._weights == r2._weights
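

if __name__ == "__main__":
    # Optional convenience entry point; the suite is normally run with plain
    # `pytest` from the repo root.
    import sys

    sys.exit(pytest.main([__file__, "-v"]))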