# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
RhythmEnv Life Simulator — Inference Script
===========================================
MANDATORY
- Before submitting, ensure the following variables are defined in your environment configuration:
API_BASE_URL The API endpoint for the LLM.
MODEL_NAME The model identifier to use for inference.
HF_TOKEN Your Hugging Face / API key.
LOCAL_IMAGE_NAME The name of the local image to use for the environment if you are using from_docker_image()
- Defaults are set only for API_BASE_URL and MODEL_NAME
(and should reflect your active inference setup):
API_BASE_URL = os.getenv("API_BASE_URL", "<your-active-endpoint>")
MODEL_NAME = os.getenv("MODEL_NAME", "<your-active-model>")
- The inference script must be named `inference.py` and placed in the root directory of the project
- Participants must use the OpenAI client for all LLM calls, using the above variables
STDOUT FORMAT
- The script must emit exactly three line types to stdout, in this order:
[START] task=<task_name> env=<benchmark> model=<model_name>
[STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
[END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
Rules:
- One [START] line at episode begin.
- One [STEP] line per step, immediately after env.step() returns.
- One [END] line after env.close(), always emitted (even on exception).
- reward and rewards are formatted to 2 decimal places.
- done and success are lowercase booleans: true or false.
- error is the raw last_action_error string, or null if none.
- All fields on a single line with no newlines within a line.
- Each task should return a score in [0, 1]
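Illustrative transcript (placeholder values, not a real run):
[START] task=profile_0 env=rhythm_env model=<your-active-model>
[STEP] step=1 action=DEEP_WORK reward=0.45 done=false error=null
[STEP] step=2 action=SLEEP reward=0.10 done=true error=null
[END] success=true steps=2 score=0.550 rewards=0.45,0.10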
"""
import asyncio
import os
import sys
import textwrap
from typing import List, Optional
from openai import OpenAI
# Add current directory to path for local imports
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from client import RhythmEnv
from models import ActionType, RhythmAction
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# The docstring mandates LOCAL_IMAGE_NAME; keep IMAGE_NAME as a fallback.
IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME") or os.getenv("IMAGE_NAME")
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
BASE_URL = os.getenv("RHYTHM_ENV_URL", "https://InosLihka-rhythm-env.hf.space")
BENCHMARK = "rhythm_env"
# Tasks map to seed values: seed 0 = introvert_morning, 1 = extrovert_night_owl, 2 = workaholic_stoic
TASKS = ["profile_0", "profile_1", "profile_2"]
TASK_SEEDS = {"profile_0": 0, "profile_1": 1, "profile_2": 2}
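# 7 days x 4 slots per day = 28 decision steps per episode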
MAX_STEPS = 28
SCORE_THRESHOLD = 0.1
SLOT_NAMES = ["Morning", "Afternoon", "Evening", "Night"]
DAY_NAMES = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"]
SYSTEM_PROMPT = textwrap.dedent("""\
You are a life-management agent helping a person with HIDDEN preferences.
You see 5 life meters and a rolling history. The same action affects different
people differently — you must INFER who you're helping from the rewards and
meter changes you observe.
Each step, output ONE LINE in this exact format:
S M W ACTION_NAME
First write your BELIEF as 3 digits 0-9, then the ACTION that fits:
S = social preference (0=hates social, 9=loves social)
M = morning preference (0=night owl, 9=morning person)
W = work preference (0=avoids work, 9=workaholic)
ACTION choices:
DEEP_WORK, ADMIN_WORK, LEARN, SLEEP, EXERCISE, MEDITATE,
FAMILY_TIME, SOCIALIZE, ME_TIME, BINGE_WATCH
Example: 3 8 7 DEEP_WORK
Belief-action coupling guide:
- High S: SOCIALIZE, FAMILY_TIME (extrovert boosts)
- High M: DEEP_WORK in morning slots (morning-person bonus)
- High W: DEEP_WORK, LEARN (workaholic energy)
- Low S: MEDITATE, ME_TIME (introvert recharge)
- Low M: DEEP_WORK in evening/night (night-owl bonus)
Tactics:
- Early week: PROBE varied actions to gather information.
- Late week: EXPLOIT — pick actions matching your sharpened belief.
- Don't repeat the same action; you'll get a repetition penalty.
- Watch for crashes: any meter under 0.1 = big penalty.
- Connection decays passively — actively maintain it.
Respond with ONLY the format line, no other text.""")
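# The "S M W ACTION" reply format requested above is parsed by parse_llm_action() below.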
# ---------------------------------------------------------------------------
# Logging helpers
# ---------------------------------------------------------------------------
def log_start(task: str, env: str, model: str) -> None:
print(f"[START] task={task} env={env} model={model}", flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
error_val = error if error else "null"
done_val = str(done).lower()
print(
f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
flush=True,
)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
rewards_str = ",".join(f"{r:.2f}" for r in rewards)
print(
f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}",
flush=True,
)
# ---------------------------------------------------------------------------
# Heuristic action selection
# ---------------------------------------------------------------------------
def choose_action_heuristic(obs) -> RhythmAction:
"""Priority-based heuristic: critical recovery → time-appropriate → balance."""
slot = obs.slot
vitality = obs.vitality
cognition = obs.cognition
serenity = obs.serenity
connection = obs.connection
progress = obs.progress
# Critical recovery: prevent any meter from crashing
if vitality < 0.15:
return RhythmAction(action_type=ActionType.SLEEP)
if serenity < 0.15:
return RhythmAction(action_type=ActionType.MEDITATE)
if connection < 0.15:
return RhythmAction(action_type=ActionType.FAMILY_TIME)
# Night slot: prioritize sleep unless critical
if slot == 3:
if vitality < 0.5:
return RhythmAction(action_type=ActionType.SLEEP)
if connection < 0.3:
return RhythmAction(action_type=ActionType.FAMILY_TIME)
return RhythmAction(action_type=ActionType.SLEEP)
# Morning: productivity if able
if slot == 0:
if vitality > 0.4 and cognition > 0.3:
return RhythmAction(action_type=ActionType.DEEP_WORK)
if vitality < 0.4:
return RhythmAction(action_type=ActionType.EXERCISE)
return RhythmAction(action_type=ActionType.ADMIN_WORK)
# Afternoon: balanced mix
if slot == 1:
if connection < 0.3:
return RhythmAction(action_type=ActionType.FAMILY_TIME)
if progress < 0.3 and vitality > 0.3:
return RhythmAction(action_type=ActionType.LEARN)
if serenity < 0.4:
return RhythmAction(action_type=ActionType.MEDITATE)
return RhythmAction(action_type=ActionType.ADMIN_WORK)
# Evening: social and recovery
if connection < 0.4:
return RhythmAction(action_type=ActionType.SOCIALIZE)
if serenity < 0.5:
return RhythmAction(action_type=ActionType.ME_TIME)
if vitality < 0.4:
return RhythmAction(action_type=ActionType.EXERCISE)
return RhythmAction(action_type=ActionType.MEDITATE)
def choose_action_llm(obs, llm_client: OpenAI) -> RhythmAction:
"""Use LLM to pick an action (and emit belief), fall back to heuristic on failure."""
day_name = DAY_NAMES[obs.day] if obs.day < 7 else f"Day {obs.day}"
slot_name = SLOT_NAMES[obs.slot] if obs.slot < 4 else f"Slot {obs.slot}"
event_str = f"\nActive event: {obs.active_event}" if obs.active_event else ""
history_lines = []
for h in (getattr(obs, "step_history", None) or [])[-5:]:
# Iter 4 fix: include anomalies for profile-inference signal
va = getattr(h, "vitality_anomaly", 0.0)
ca = getattr(h, "cognition_anomaly", 0.0)
pa = getattr(h, "progress_anomaly", 0.0)
sa = getattr(h, "serenity_anomaly", 0.0)
cna = getattr(h, "connection_anomaly", 0.0)
history_lines.append(
f" step {h.step}: {h.action} -> reward {h.reward:+.2f} "
f"(V{h.vitality_delta:+.2f} C{h.cognition_delta:+.2f} "
f"P{h.progress_delta:+.2f} S{h.serenity_delta:+.2f} Cn{h.connection_delta:+.2f})"
f" [anom V{va:+.2f} C{ca:+.2f} P{pa:+.2f} S{sa:+.2f} Cn{cna:+.2f}]"
)
history_str = ""
if history_lines:
history_str = "\n\nRecent history (anom = profile-inference signal):\n" + "\n".join(history_lines)
user_prompt = textwrap.dedent(f"""\
Step: {obs.timestep}/{MAX_STEPS} ({day_name} {slot_name})
Remaining steps: {obs.remaining_steps}
Meters:
Vitality: {obs.vitality:.2f}
Cognition: {obs.cognition:.2f}
Progress: {obs.progress:.2f}
Serenity: {obs.serenity:.2f}
Connection: {obs.connection:.2f}{event_str}{history_str}
Output belief then action (format: S M W ACTION_NAME):""")
try:
completion = llm_client.chat.completions.create(
model=MODEL_NAME,
messages=[
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": user_prompt},
],
temperature=0.3,
max_tokens=20,
stream=False,
)
text = (completion.choices[0].message.content or "").strip()
return parse_llm_action(text)
    except Exception as e:
        # Any API or parse failure degrades gracefully to the heuristic policy.
        print(f"[DEBUG] LLM call failed ({e}); falling back to heuristic", flush=True)
        return choose_action_heuristic(obs)
def parse_llm_action(text: str) -> RhythmAction:
"""Parse LLM response (action+belief format) into a RhythmAction.
Belief digits are ignored at inference time — only used as a demo signal.
"""
    # Reuse the training parser for consistency; guard sys.path so repeated
    # calls don't keep prepending the same entry.
    training_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "training")
    if training_dir not in sys.path:
        sys.path.insert(0, training_dir)
    try:
        from reward_functions import extract_action_and_belief
        action, _belief, _provided = extract_action_and_belief(text)
        if action is not None:
            return RhythmAction(action_type=action)
    except Exception:
        # Fall through to legacy parsing on any import or parse failure.
        pass
# Fallback: legacy parsing
text = text.strip().upper().replace(" ", "_")
for action_type in ActionType:
if action_type.value.upper() == text:
return RhythmAction(action_type=action_type)
for action_type in ActionType:
if action_type.value.upper() in text:
return RhythmAction(action_type=action_type)
return RhythmAction(action_type=ActionType.SLEEP)
# ---------------------------------------------------------------------------
# Main loop
# ---------------------------------------------------------------------------
async def run_task(task_name: str, llm_client: Optional[OpenAI]) -> float:
"""Run a single task (profile) and return the score."""
seed = TASK_SEEDS.get(task_name, 0)
    env = None
    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False
    log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
    try:
        # Create the env inside the try block so the [END] line is still
        # emitted even if construction or connection fails.
        if IMAGE_NAME:
            env = await RhythmEnv.from_docker_image(IMAGE_NAME)
        else:
            env = RhythmEnv(base_url=BASE_URL)
        async with env:
result = await env.reset(seed=seed)
for step in range(1, MAX_STEPS + 1):
if result.done:
break
# Use LLM if available, otherwise heuristic
if llm_client is not None:
action = choose_action_llm(result.observation, llm_client)
else:
action = choose_action_heuristic(result.observation)
action_str = action.action_type.value
result = await env.step(action)
                reward = result.reward or 0.0
                done = result.done
                rewards.append(reward)
                steps_taken = step
                # Spec: error is the raw last_action_error string, or null. The
                # attribute's location is an assumption; default to None if absent.
                error = getattr(result, "last_action_error", None) or getattr(
                    result.observation, "last_action_error", None
                )
                log_step(step=step, action=action_str, reward=reward, done=done, error=error)
if done:
break
            # Get final score from the grader; guard against a missing breakdown
            breakdown = getattr(result.observation, "reward_breakdown", None) or {}
            score = breakdown.get("final_score", 0.0)
            score = max(0.0, min(1.0, score))
success = score >= SCORE_THRESHOLD
except Exception as e:
print(f"[DEBUG] Error running task {task_name}: {e}", flush=True)
finally:
        if env is not None:
            # Idempotent close: the async context manager may have closed it already.
            try:
                await env.close()
            except Exception as e:
                print(f"[DEBUG] env.close() error: {e}", flush=True)
log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
return score
async def main() -> None:
    llm_client = None
    if API_KEY:
        llm_client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    else:
        print("[DEBUG] No HF_TOKEN/API_KEY set; using the heuristic policy only", flush=True)
scores = []
for task_name in TASKS:
s = await run_task(task_name, llm_client)
scores.append(s)
avg = sum(scores) / len(scores) if scores else 0.0
print(f"\n[SUMMARY] avg_score={avg:.3f} scores={','.join(f'{s:.3f}' for s in scores)}", flush=True)
if __name__ == "__main__":
asyncio.run(main())