"""
Dataset generator for RhythmEnv GRPO training (meta-RL version).
Plays episodes under a continuously sampled profile per seed and emits
observation prompts at each step, paired with the replay metadata
(seed, step_index, action_history) the reward functions need to
reconstruct env state deterministically.
The system prompt asks for "S M W ACTION_NAME" — three belief digits then
the action. A `hint_fraction` slice of episodes carries a true-belief hint
in the prompt as a curriculum warmup; the rest force pure inference.
"""
import sys
import os
import random
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from models import ActionType, RhythmAction
from server.rhythm_environment import RhythmEnvironment, MAX_STEPS, METERS
SLOT_NAMES = ["Morning", "Afternoon", "Evening", "Night"]
DAY_NAMES = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"]
SYSTEM_PROMPT = """You are a life-management agent helping a person whose preferences are HIDDEN.
You see 5 life meters and a rolling history of recent steps. The same action
affects different people differently — you must INFER who you're helping from
rewards, meter changes, and per-meter ANOMALY signals.
Each step, do TWO things:
1. Reason briefly about what the observations imply about the person.
Focus on:
- Anomalies (actual delta vs neutral-profile expectation): big positive
social_serenity / connection responses → high S; big morning cognition
gains → high M; productive work giving vitality back → high W
- Current meter state: any meter under 0.15 needs urgent recovery
- What action best fits BOTH the inferred profile and the current state
2. Output your final answer on the LAST line in this exact format:
S M W ACTION_NAME
where S, M, W are belief digits 0-9 (0=low, 9=high) representing your best
estimate of social_pref, morning_pref, work_pref. ACTION_NAME is one of:
DEEP_WORK, ADMIN_WORK, LEARN, SLEEP, EXERCISE, MEDITATE, FAMILY_TIME,
SOCIALIZE, ME_TIME, BINGE_WATCH
Wrap your reasoning in ... tags. Keep reasoning under
120 tokens. The final answer line MUST be the last line of your response.
Belief→action quick reference:
- High S (extrovert): SOCIALIZE, FAMILY_TIME boost connection cheaply
- High M (morning person): DEEP_WORK / LEARN in early slots gets bonus cognition
- High W (workaholic): DEEP_WORK, LEARN drive progress and may energize
- Low S (introvert): MEDITATE, ME_TIME for solo recharge; avoid SOCIALIZE
- Low M (night owl): DEEP_WORK / LEARN in evening/night slots
- Watch crashes: any meter under 0.10 = -0.30 penalty per crashed meter
- Connection decays passively — actively maintain via SOCIALIZE/FAMILY_TIME
- Don't repeat the same action 3+ times in a row — repetition penalty applies
Strategy: probe varied actions in the first ~5 steps to gather profile evidence,
then exploit your sharpened belief by picking actions that match the inferred
profile + current meter state.
Example output:
Last step's socialize gave V-0.12 (anom -0.06, much worse than neutral) — high
social drain, suggests low S. Morning DEEP_WORK earlier gave bonus cognition
(anom +0.04) → high M. Vitality at 0.6 still ok, serenity dropping. With low S +
high M, MEDITATE is the recovery play that fits.
2 8 5 MEDITATE"""
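# Hedged sketch (not used by the trainer): one way a reward function might
# parse the "S M W ACTION_NAME" final line the system prompt asks for. The
# helper name _parse_final_answer and its return convention are assumptions,
# not an API defined elsewhere in this repo.
def _parse_final_answer(completion: str) -> tuple[int, int, int, str] | None:
    """Return (s, m, w, action_name) from the last non-empty line, else None."""
    lines = [ln.strip() for ln in completion.strip().splitlines() if ln.strip()]
    if not lines:
        return None
    parts = lines[-1].split()
    if len(parts) != 4:
        return None
    s, m, w, action_name = parts
    # Belief digits must be single characters 0-9; the action must be a member.
    if any(len(d) != 1 or not d.isdigit() for d in (s, m, w)):
        return None
    if action_name not in ActionType.__members__:
        return None
    return int(s), int(m), int(w), action_name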
def format_observation_prompt(obs, profile_hint: dict | None = None) -> str:
"""Format an observation into a user prompt for the LLM.
If profile_hint is provided (curriculum's "visible" phase), include it in
the prompt so the agent learns the *skill* of using profile signals
before having to infer them from scratch.
"""
day_name = DAY_NAMES[obs.day] if obs.day < 7 else f"Day {obs.day}"
slot_name = SLOT_NAMES[obs.slot] if obs.slot < 4 else f"Slot {obs.slot}"
event_str = f"\nActive event: {obs.active_event}" if obs.active_event else ""
history_lines = []
for h in (obs.step_history or [])[-5:]: # last 5 only to fit prompt budget
# Per-meter anomalies (actual_delta − expected_under_neutral_profile)
# are the cleanest profile-inference signal — they show how this person's
# response DEVIATES from the average person. Surfacing them here in the
# prompt is what gives the agent a fingerprint to learn from.
anom_str = (
f" [anom V{h.vitality_anomaly:+.2f} C{h.cognition_anomaly:+.2f} "
f"P{h.progress_anomaly:+.2f} S{h.serenity_anomaly:+.2f} "
f"Cn{h.connection_anomaly:+.2f}]"
)
history_lines.append(
f" step {h.step}: {h.action} -> reward {h.reward:+.2f} "
f"(V{h.vitality_delta:+.2f} C{h.cognition_delta:+.2f} "
f"P{h.progress_delta:+.2f} S{h.serenity_delta:+.2f} Cn{h.connection_delta:+.2f})"
f"{anom_str}"
)
history_str = ""
if history_lines:
history_str = (
"\n\nRecent history (anom = how this person deviated from neutral baseline):\n"
+ "\n".join(history_lines)
)
hint_str = ""
if profile_hint is not None:
hint_str = (
f"\n\nKnown about this person (training hint):\n"
f" social_pref={profile_hint['social_pref']:.2f}, "
f"morning_pref={profile_hint['morning_pref']:.2f}, "
f"work_pref={profile_hint['work_pref']:.2f}"
)
return (
f"Step: {obs.timestep}/{MAX_STEPS} ({day_name} {slot_name})\n"
f"Remaining steps: {obs.remaining_steps}\n\n"
f"Meters:\n"
f" Vitality: {obs.vitality:.2f}\n"
f" Cognition: {obs.cognition:.2f}\n"
f" Progress: {obs.progress:.2f}\n"
f" Serenity: {obs.serenity:.2f}\n"
f" Connection: {obs.connection:.2f}"
f"{event_str}"
f"{history_str}"
f"{hint_str}\n\n"
f"Output your belief, then your action (format: S M W ACTION_NAME):"
)
def generate_episode_samples(
seed: int,
strategy: str = "random",
profile_mode: str = "continuous",
show_profile_hint: bool = False,
) -> list[dict]:
"""Play one episode and return a list of training samples.
Each sample includes the prompt + replay metadata (seed, step_index,
action_history, profile_mode) so reward functions can deterministically
reconstruct the env state.
Args:
seed: Episode seed (also determines profile when profile_mode=continuous).
strategy: "random" or "heuristic" — used to roll out the episode for
state diversity. The agent's training generations replace these
actions; we only need the prefix history for replay.
profile_mode: "continuous" (sampled per seed) or "discrete" (1 of 3
hardcoded profiles).
show_profile_hint: If True, include the true belief vector in the prompt.
Use during the curriculum's "visible" warmup phase.
"""
env = RhythmEnvironment()
obs = env.reset(seed=seed, profile_mode=profile_mode)
profile_hint = env.get_profile_hint() if show_profile_hint else None
    rng = random.Random(seed + 1000)  # separate RNG stream for rollout actions, decoupled from the env seed
actions_taken = []
samples = []
all_actions = list(ActionType)
for step in range(MAX_STEPS):
if obs.done:
break
prompt = format_observation_prompt(obs, profile_hint=profile_hint)
samples.append({
"prompt": [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": prompt},
],
"seed": seed,
"step_index": step,
"action_history": list(actions_taken),
"profile_mode": profile_mode,
"show_profile_hint": show_profile_hint,
})
if strategy == "random":
action_type = rng.choice(all_actions)
elif strategy == "heuristic":
action_type = heuristic_action(obs)
        else:
            # Unrecognized strategy: fall back to a random rollout.
            action_type = rng.choice(all_actions)
action = RhythmAction(action_type=action_type)
actions_taken.append(action_type.value)
obs = env.step(action)
return samples
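# Hedged sketch of the replay contract these samples encode. The real reward
# functions live elsewhere; this hypothetical helper only illustrates how
# (seed, profile_mode, action_history) deterministically rebuild the env
# state that produced a sample's prompt.
def _replay_sample(sample: dict):
    """Re-create the env and fast-forward through the recorded action prefix."""
    env = RhythmEnvironment()
    obs = env.reset(seed=sample["seed"], profile_mode=sample["profile_mode"])
    for action_value in sample["action_history"]:
        obs = env.step(RhythmAction(action_type=ActionType(action_value)))
    # env now sits at sample["step_index"], ready to score a fresh generation.
    return env, obs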
def heuristic_action(obs) -> ActionType:
"""Priority-based heuristic baseline (profile-blind).
Used both during dataset generation (to roll out diverse states) and
by inference_eval as the heuristic baseline strategy.
"""
slot = obs.slot
v, c, p, s, cn = obs.vitality, obs.cognition, obs.progress, obs.serenity, obs.connection
if v < 0.15:
return ActionType.SLEEP
if s < 0.15:
return ActionType.MEDITATE
if cn < 0.15:
return ActionType.FAMILY_TIME
if slot == 3:
return ActionType.SLEEP
if slot == 0:
return ActionType.DEEP_WORK if (v > 0.4 and c > 0.3) else ActionType.EXERCISE
if slot == 1:
if cn < 0.3:
return ActionType.FAMILY_TIME
if p < 0.3 and v > 0.3:
return ActionType.LEARN
return ActionType.ADMIN_WORK
if cn < 0.4:
return ActionType.SOCIALIZE
if s < 0.5:
return ActionType.ME_TIME
return ActionType.MEDITATE
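# Quick illustration of the priority cascade above. SimpleNamespace stands in
# for the real observation object, which carries more fields than shown here:
#
#   from types import SimpleNamespace
#   obs = SimpleNamespace(slot=1, vitality=0.10, cognition=0.5, progress=0.5,
#                         serenity=0.5, connection=0.5)
#   heuristic_action(obs)  # -> ActionType.SLEEP (vitality crash outranks slot logic)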
def generate_dataset(
num_episodes: int = 200,
strategy: str = "mixed",
max_samples: int = 2000,
profile_mode: str = "continuous",
hint_fraction: float = 0.2,
) -> list[dict]:
"""Generate a training dataset by playing multiple episodes.
Curriculum is baked into the dataset: hint_fraction of samples have the
true profile visible (visible-phase warmup). After shuffle, GRPOTrainer
sees a mix early on; we can sort to put hint samples first if needed.
Args:
num_episodes: Number of episodes to play.
strategy: "random", "heuristic", or "mixed" (alternating).
max_samples: Maximum samples to return.
profile_mode: "continuous" (default, meta-RL) or "discrete" (3 profiles).
hint_fraction: Fraction of episodes to play with profile hint visible.
"""
all_samples = []
n_hint_episodes = int(num_episodes * hint_fraction)
for i in range(num_episodes):
seed = i
if strategy == "mixed":
s = "heuristic" if i % 2 == 0 else "random"
else:
s = strategy
show_hint = i < n_hint_episodes
episode_samples = generate_episode_samples(
seed=seed,
strategy=s,
profile_mode=profile_mode,
show_profile_hint=show_hint,
)
all_samples.extend(episode_samples)
if len(all_samples) >= max_samples:
break
    # Shuffle (curriculum is per-sample via the show_profile_hint flag, not order).
    # Note: this uses the unseeded global RNG, so sample ORDER varies run-to-run;
    # the samples themselves stay deterministic via their replay metadata.
    random.shuffle(all_samples)
all_samples = all_samples[:max_samples]
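    # Sketch of the "sort hint samples first" option the docstring mentions: a
    # stable sort on the flag would front-load the visible-phase warmup instead
    # of mixing it in uniformly:
    #   all_samples.sort(key=lambda s: not s["show_profile_hint"])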
n_hint = sum(1 for s in all_samples if s["show_profile_hint"])
    episodes_played = (i + 1) if num_episodes > 0 else 0
    print(
        f"Generated {len(all_samples)} samples from {episodes_played} episodes "
        f"({n_hint} with profile hint, {len(all_samples) - n_hint} without)"
    )
return all_samples
if __name__ == "__main__":
samples = generate_dataset(num_episodes=20, strategy="mixed", max_samples=80, hint_fraction=0.5)
print(f"\nFirst sample (with hint):")
hinted = next((s for s in samples if s["show_profile_hint"]), None)
if hinted:
print(hinted["prompt"][1]["content"])
print(f"\nseed={hinted['seed']}, step={hinted['step_index']}, mode={hinted['profile_mode']}")
print(f"\nFirst sample (without hint):")
plain = next((s for s in samples if not s["show_profile_hint"]), None)
if plain:
print(plain["prompt"][1]["content"])