"""
RhythmEnv Inference Evaluation — Baseline vs Trained, with meta-RL eval suite.
Three evaluation conditions:
1. discrete-3-profiles: 3 hardcoded reference profiles. A sanity check
that the meta-trained agent still handles the original named profiles.
2. continuous-in-distribution: Sampled profiles from the training distribution
(did the agent learn the meta-policy?)
3. continuous-OOD: Profiles from a held-out region of the parameter space
(does the meta-policy generalize, or did the agent memorize?)
Usage:
# Baselines only (no trained model):
python training/inference_eval.py
# With trained model:
python training/inference_eval.py --model_path outputs/rhythmenv_trained
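    # Custom seed lists for the continuous conditions (comma-separated):
    python training/inference_eval.py --model_path outputs/rhythmenv_trained \
        --in_dist_seeds 100,101,102 --ood_seeds 20000,20001,20002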
"""
import argparse
import json
import os
import random
import sys
from typing import Optional
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from models import ActionType, RhythmAction
from server.rhythm_environment import RhythmEnvironment, MAX_STEPS
from training.dataset import heuristic_action
DISCRETE_PROFILES = ["introvert_morning", "extrovert_night_owl", "workaholic_stoic"]
SLOT_NAMES = ["Morning", "Afternoon", "Evening", "Night"]
DAY_NAMES = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
# Seed ranges: the in-distribution defaults draw from the same profile
# distribution as training but were not used during training; the 10000
# offset places sampled profiles in a held-out, statistically distinct
# OOD region of the parameter space.
IN_DIST_SEEDS_DEFAULT = list(range(100, 110))  # 10 unseen in-distribution seeds
OOD_SEEDS_DEFAULT = list(range(10000, 10010))  # 10 OOD seeds
def random_action(rng) -> ActionType:
return rng.choice(list(ActionType))
def model_action(obs, model, tokenizer, return_belief: bool = False):
"""Get action (and optionally belief) from trained model."""
# Lazy imports: keep the heavy training-stack imports out of module load
# so this script can run in baseline-only mode without unsloth/transformers.
from training.dataset import format_observation_prompt, SYSTEM_PROMPT
from training.reward_functions import extract_action_and_belief
prompt = format_observation_prompt(obs)
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": prompt},
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text, return_tensors="pt").to(model.device)
    # 256 tokens lets the SFT-distilled student emit its full
    # <reasoning>...</reasoning> block PLUS the final S M W ACTION_NAME line.
    # The earlier 20-token cap truncated generation mid-reasoning, so the
    # answer line was never reached and the parser fell back to extracting
    # action names from the partial reasoning text.
outputs = model.generate(**inputs, max_new_tokens=256, temperature=0.7, do_sample=True)
response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
action_type, belief, _ = extract_action_and_belief(response)
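    # Unparseable model output: fall back to SLEEP rather than dropping the step.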
if action_type is None:
action_type = ActionType.SLEEP
return (action_type, belief) if return_belief else action_type
def run_episode(
seed: int,
strategy: str,
profile_mode: str = "continuous",
profile: Optional[str] = None,
model=None,
tokenizer=None,
) -> dict:
"""Run a single episode and return per-episode metrics."""
rng = random.Random(seed + 500)
env = RhythmEnvironment()
if profile is not None:
obs = env.reset(seed=seed, profile=profile)
else:
obs = env.reset(seed=seed, profile_mode=profile_mode)
true_belief = env.get_belief_target()
profile_name = env.state.profile_name
total_reward = 0.0
step_rewards = []
actions_taken = []
beliefs_seen = [] # for trained model
for step in range(MAX_STEPS):
if obs.done:
break
if strategy == "heuristic":
action_type = heuristic_action(obs)
elif strategy == "random":
action_type = random_action(rng)
elif strategy == "model" and model is not None:
action_type, belief = model_action(obs, model, tokenizer, return_belief=True)
beliefs_seen.append(belief)
# Tell the env about the emitted belief so the grader can score
# belief_accuracy. Heuristic / random skip this — they get 0 on
# the belief component, by design.
env.record_belief(belief)
else:
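            # Unknown strategy, or "model" requested without a loaded model: fall back to random.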
action_type = random_action(rng)
action = RhythmAction(action_type=action_type)
actions_taken.append(action_type.value)
obs = env.step(action)
total_reward += obs.reward
step_rewards.append(obs.reward)
final_score = obs.reward_breakdown.get("final_score", 0.0)
    # Adaptation: late-half mean reward minus early-half mean reward
    # (positive = the agent improved within the episode).
half = max(len(step_rewards) // 2, 1)
early = step_rewards[:half]
late = step_rewards[half:]
adaptation = (sum(late) / len(late) - sum(early) / len(early)) if (early and late) else 0.0
# Belief tracking (only for trained model)
final_belief = beliefs_seen[-1] if beliefs_seen else None
    belief_mae = None
    if final_belief is not None:
        # MAE over the belief dimensions (S, M, W).
        belief_mae = sum(abs(b - t) for b, t in zip(final_belief, true_belief)) / len(true_belief)
return {
"profile_name": profile_name,
"profile_mode": profile_mode if profile is None else "discrete",
"strategy": strategy,
"seed": seed,
"final_score": round(final_score, 4),
"total_reward": round(total_reward, 2),
"adaptation": round(adaptation, 3),
"vitality": round(obs.vitality, 2),
"cognition": round(obs.cognition, 2),
"progress": round(obs.progress, 2),
"serenity": round(obs.serenity, 2),
"connection": round(obs.connection, 2),
"actions": actions_taken,
"true_belief": [round(x, 3) for x in true_belief],
"final_belief": [round(x, 3) for x in final_belief] if final_belief is not None else None,
"belief_mae": round(belief_mae, 3) if belief_mae is not None else None,
}
def eval_condition(
name: str,
strategies: list[str],
runs: list[dict],
model=None,
tokenizer=None,
) -> list[dict]:
"""Run an eval condition and print summary."""
print(f"\n{'=' * 60}")
print(f"Condition: {name}")
print(f"{'=' * 60}")
results = []
for strategy in strategies:
print(f"\n Strategy: {strategy.upper()}")
scores = []
adaptations = []
belief_maes = []
for run in runs:
r = run_episode(strategy=strategy, model=model, tokenizer=tokenizer, **run)
results.append({"condition": name, **r})
scores.append(r["final_score"])
adaptations.append(r["adaptation"])
if r["belief_mae"] is not None:
belief_maes.append(r["belief_mae"])
avg_score = sum(scores) / len(scores) if scores else 0.0
avg_adapt = sum(adaptations) / len(adaptations) if adaptations else 0.0
avg_mae = sum(belief_maes) / len(belief_maes) if belief_maes else None
line = f" avg_score={avg_score:.3f} avg_adaptation={avg_adapt:+.3f}"
if avg_mae is not None:
line += f" avg_belief_mae={avg_mae:.3f}"
print(line)
return results
def main():
parser = argparse.ArgumentParser(description="Evaluate RhythmEnv agent (meta-RL eval suite)")
parser.add_argument("--model_path", type=str, default=None,
help="Path to trained model (skip for baseline only)")
parser.add_argument("--num_episodes", type=int, default=5,
help="Episodes per condition per strategy (for discrete: per-profile)")
parser.add_argument("--output_file", type=str, default="eval_results.json")
parser.add_argument("--in_dist_seeds", type=str, default=None,
help="Comma-separated seeds for in-distribution eval")
parser.add_argument("--ood_seeds", type=str, default=None,
help="Comma-separated seeds for OOD eval")
args = parser.parse_args()
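    # Each continuous condition uses num_episodes * 2 seeds; the default lists
    # hold only 10, so runs with --num_episodes > 5 need explicit
    # --in_dist_seeds / --ood_seeds.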
in_dist_seeds = (
[int(s) for s in args.in_dist_seeds.split(",")] if args.in_dist_seeds
else IN_DIST_SEEDS_DEFAULT[:args.num_episodes * 2]
)
ood_seeds = (
[int(s) for s in args.ood_seeds.split(",")] if args.ood_seeds
else OOD_SEEDS_DEFAULT[:args.num_episodes * 2]
)
model, tokenizer = None, None
strategies = ["heuristic", "random"]
if args.model_path and os.path.exists(args.model_path):
try:
from unsloth import FastLanguageModel
            # max_seq_length=2048 must accommodate the user prompt with its
            # 7-step history and per-meter anomalies (~900-1200 tokens) PLUS
            # max_new_tokens=256 for the CoT response. The earlier value of
            # 768 silently truncated prompts on the LEFT (kept the end of the
            # prompt, lost the system instructions or older meter history),
            # producing incoherent model outputs.
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=args.model_path,
load_in_4bit=True,
max_seq_length=2048,
)
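            # Enable unsloth's fast inference mode.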
FastLanguageModel.for_inference(model)
strategies.append("model")
print(f"Loaded trained model from: {args.model_path}")
except Exception as e:
print(f"Warning: Could not load model: {e}")
print("Running baseline-only evaluation.")
all_results = []
# Condition 1: 3 hardcoded reference profiles (sanity-check the agent
# still handles the named profiles — no longer the primary eval signal).
discrete_runs = [
{"seed": ep, "profile": p}
for p in DISCRETE_PROFILES for ep in range(args.num_episodes)
]
all_results += eval_condition(
"discrete-3-profiles",
strategies, discrete_runs,
model=model, tokenizer=tokenizer,
)
# Condition 2: In-distribution sampled profiles
in_dist_runs = [{"seed": s, "profile_mode": "continuous"} for s in in_dist_seeds]
all_results += eval_condition(
"continuous-in-distribution",
strategies, in_dist_runs,
model=model, tokenizer=tokenizer,
)
# Condition 3: OOD sampled profiles (the meta-learning generalization test)
ood_runs = [{"seed": s, "profile_mode": "continuous"} for s in ood_seeds]
all_results += eval_condition(
"continuous-OOD (generalization)",
strategies, ood_runs,
model=model, tokenizer=tokenizer,
)
# Per-profile breakdown for the 3 reference profiles
print(f"\n{'=' * 70}")
print("DISCRETE-3-PROFILE BREAKDOWN")
print(f"{'=' * 70}")
print(f"{'Profile':<25} ", end="")
for s in strategies:
print(f"{s:>10}", end="")
print()
print("-" * 70)
discrete = [r for r in all_results if r["condition"] == "discrete-3-profiles"]
for profile in DISCRETE_PROFILES:
row = f"{profile:<25} "
for s in strategies:
rs = [r for r in discrete if r["profile_name"] == profile and r["strategy"] == s]
avg = sum(r["final_score"] for r in rs) / len(rs) if rs else 0.0
row += f"{avg:>10.3f}"
print(row)
# Save
with open(args.output_file, "w") as f:
json.dump(all_results, f, indent=2)
print(f"\nResults saved to: {args.output_file}")
if __name__ == "__main__":
main()