"""
Generate teacher trajectories for Algorithm Distillation.
For each seed, plays one full RhythmEnv episode where the action at each step
is chosen by a teacher LLM (gpt-5.4 via Azure OpenAI). The teacher is prompted
to emit a reasoning block (wrapped in think tags) followed by `S M W ACTION_NAME` on a final
line. We parse the answer line, step the env, save the full (prompt, response,
action, reward) tuple to JSONL, and aggregate per-episode metrics for gating.
Required env vars (no secrets in code):
AZURE_OPENAI_ENDPOINT e.g. https://metahackathon-resource.cognitiveservices.azure.com/
AZURE_OPENAI_API_KEY your Azure OpenAI key (do NOT paste in chat)
AZURE_OPENAI_DEPLOYMENT the deployment name you chose, e.g. gpt-5.4
AZURE_OPENAI_API_VERSION e.g. 2024-12-01-preview (default if unset)
Usage from rhythm_env root:
# Stage 1a: 30-episode validation (~$3-5)
python scripts/generate_teacher_trajectories.py \
--seeds 0-29 \
--output data/teacher_30ep_validation.jsonl \
--concurrency 3
# Stage 1b: scale to 150 episodes (~$15-20)
python scripts/generate_teacher_trajectories.py \
--seeds 0-99 \
--output data/teacher_150ep_indist.jsonl \
--concurrency 5
python scripts/generate_teacher_trajectories.py \
--seeds 10000-10049 \
--output data/teacher_150ep_ood.jsonl \
--concurrency 5
The script prints PASS/FAIL gate verdicts at the end so you can decide whether
to scale or fix the teacher prompt before spending more.
"""
import argparse
import asyncio
import json
import os
import re
import sys
import time
from collections import Counter
from pathlib import Path
# Make the repo root importable so `models`, `server`, and `training` resolve
# when the script is launched from scripts/ (or anywhere else).
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Load .env (repo root) before reading os.environ so credentials don't have
# to be exported in the shell. The .env file is in .gitignore.
try:
    from dotenv import load_dotenv
    # NOTE(review): assumes this file lives one level below the repo root
    # (scripts/…), so parent.parent is the root — confirm if moved.
    _ENV_PATH = Path(__file__).resolve().parent.parent / ".env"
    if _ENV_PATH.exists():
        load_dotenv(_ENV_PATH)
except ImportError:
    pass  # dotenv not installed → fall back to whatever's in the shell
from openai import AsyncAzureOpenAI
from openai import APIError, RateLimitError, APIConnectionError, APITimeoutError
from models import ActionType, RhythmAction
from server.rhythm_environment import (
MAX_STEPS,
RhythmEnvironment,
)
from training.dataset import format_observation_prompt
# ---------------------------------------------------------------------------
# Teacher system prompt
# ---------------------------------------------------------------------------
# The student will eventually be SFT'd to match this contract: emit a
# ... block then a final answer line `S M W ACTION_NAME`.
# Keep this in sync with whatever SYSTEM_PROMPT the SFT'd student will use.
# Runtime prompt text — must stay byte-identical to what the student is SFT'd
# against (see comment above). Only the annotation below is an edit.
TEACHER_SYSTEM_PROMPT: str = """You are a life-management agent helping a person whose preferences are HIDDEN.
You see 5 life meters and a rolling history of recent steps. The same action
affects different people differently — you must INFER who you're helping from
rewards, meter changes, and per-meter ANOMALY signals.
Each step, do TWO things:
1. Reason briefly about what the observations imply about the person.
Focus on:
- Anomalies (actual delta vs neutral-profile expectation): big positive
social_serenity / connection responses → high S; big morning cognition
gains → high M; productive work giving vitality back → high W
- Current meter state: any meter under 0.15 needs urgent recovery
- What action best fits BOTH the inferred profile and the current state
2. Output your final answer on the LAST line in this exact format:
S M W ACTION_NAME
where S, M, W are belief digits 0-9 (0=low, 9=high) representing your best
estimate of social_pref, morning_pref, work_pref. ACTION_NAME is one of:
DEEP_WORK, ADMIN_WORK, LEARN, SLEEP, EXERCISE, MEDITATE, FAMILY_TIME,
SOCIALIZE, ME_TIME, BINGE_WATCH
Wrap your reasoning in ... tags. Keep reasoning under
120 tokens. The final answer line MUST be the last line of your response.
Belief→action quick reference:
- High S (extrovert): SOCIALIZE, FAMILY_TIME boost connection cheaply
- High M (morning person): DEEP_WORK / LEARN in early slots gets bonus cognition
- High W (workaholic): DEEP_WORK, LEARN drive progress and may energize
- Low S (introvert): MEDITATE, ME_TIME for solo recharge; avoid SOCIALIZE
- Low M (night owl): DEEP_WORK / LEARN in evening/night slots
- Watch crashes: any meter under 0.10 = -0.30 penalty per crashed meter
- Connection decays passively — actively maintain via SOCIALIZE/FAMILY_TIME
- Don't repeat the same action 3+ times in a row — repetition penalty applies
Strategy: probe varied actions in the first ~5 steps to gather profile evidence,
then exploit your sharpened belief by picking actions that match the inferred
profile + current meter state.
Example output:
Last step's socialize gave V-0.12 (anom -0.06, much worse than neutral) — high
social drain, suggests low S. Morning DEEP_WORK earlier gave bonus cognition
(anom +0.04) → high M. Vitality at 0.6 still ok, serenity dropping. With low S +
high M, MEDITATE is the recovery play that fits.
2 8 5 MEDITATE"""
# ---------------------------------------------------------------------------
# Answer parsing — find the LAST `S M W ACTION_NAME` pattern in the response
# ---------------------------------------------------------------------------
# Uppercased action names the teacher may emit on the answer line.
VALID_ACTIONS = [at.value.upper() for at in ActionType]
# Matches `S M W ACTION_NAME`: three single belief digits then a known action.
# IGNORECASE tolerates lowercase answers; callers take the LAST match in the
# response, so digits appearing inside the reasoning text are harmless.
ANSWER_PATTERN = re.compile(
    r'(\d)\s+(\d)\s+(\d)\s+(' + '|'.join(VALID_ACTIONS) + r')\b',
    re.IGNORECASE,
)
def parse_teacher_response(text: str):
    """Extract (action_type, belief_vector, raw_match) from teacher output.

    Scans for every `S M W ACTION_NAME` pattern and keeps only the final
    occurrence, since the contract puts the answer on the last line.
    Returns (None, None, None) if no answer line is parseable.
    """
    if not text:
        return None, None, None
    last_hit = None
    for hit in ANSWER_PATTERN.finditer(text):
        last_hit = hit
    if last_hit is None:
        return None, None, None
    s_digit, m_digit, w_digit, name = last_hit.groups()
    try:
        # ActionType() raises ValueError on an unknown name (regex should
        # prevent that, but stay defensive).
        parsed_action = ActionType(name.lower())
    except (ValueError, KeyError):
        return None, None, None
    # Map belief digits 0-9 onto [0, 1].
    belief_vector = [int(d) / 9.0 for d in (s_digit, m_digit, w_digit)]
    return parsed_action, belief_vector, last_hit.group(0)
# ---------------------------------------------------------------------------
# Async API calls with retry
# ---------------------------------------------------------------------------
async def call_teacher(
    client: AsyncAzureOpenAI,
    deployment: str,
    user_prompt: str,
    temperature: float = 0.5,
    max_completion_tokens: int = 400,
    max_retries: int = 4,
) -> str:
    """Call the teacher with retries on transient errors. Returns response text.

    Args:
        client: configured Azure OpenAI async client.
        deployment: Azure deployment name to route the request to.
        user_prompt: per-step observation prompt (system prompt is fixed).
        temperature: sampling temperature for the teacher.
        max_completion_tokens: cap on the teacher's response length.
        max_retries: attempts allowed for transient errors (rate limit,
            connection, timeout), with capped exponential backoff between them.

    Returns:
        The teacher's response text ("" if the API returned empty content).

    Raises:
        RuntimeError: chained from the last underlying error — either after
            exhausting retries on transient failures, or immediately on a
            non-transient APIError.
    """
    last_err: Exception | None = None
    attempts = 0
    for attempt in range(max_retries):
        attempts = attempt + 1
        try:
            resp = await client.chat.completions.create(
                model=deployment,
                messages=[
                    {"role": "system", "content": TEACHER_SYSTEM_PROMPT},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=temperature,
                max_completion_tokens=max_completion_tokens,
            )
            return resp.choices[0].message.content or ""
        except (RateLimitError, APIConnectionError, APITimeoutError) as e:
            last_err = e
            wait = min(60, 2 ** attempt)  # 1s, 2s, 4s, ... capped at 60s
            await asyncio.sleep(wait)
        except APIError as e:
            # Non-transient API error — bail immediately (don't waste retries)
            last_err = e
            break
    # Fix: the old message always claimed `max_retries` retries even when a
    # non-transient APIError aborted on attempt 1; report the real attempt
    # count and chain the cause so full tracebacks survive.
    raise RuntimeError(
        f"Teacher call failed after {attempts} attempt(s): {last_err}"
    ) from last_err
# ---------------------------------------------------------------------------
# Episode rollout
# ---------------------------------------------------------------------------
async def play_episode(
    client: AsyncAzureOpenAI,
    deployment: str,
    seed: int,
) -> tuple[list[dict], dict]:
    """Run a full episode with the teacher. Returns (per-step rows, summary).

    Each row records the full (prompt, response, parsed action/belief, reward)
    tuple for later SFT; the summary aggregates per-episode metrics used by
    the PASS/FAIL gates. If a teacher call hard-fails, the partial rows
    collected so far are returned with an "aborted": True summary so the
    dataset is never silently corrupted.
    """
    env = RhythmEnvironment()
    obs = env.reset(seed=seed)
    # Ground truth used only for logging/metrics — never shown to the teacher.
    true_belief = env.get_belief_target()
    profile_name = env.state.profile_name
    step_rows: list[dict] = []
    actions_taken: list[str] = []
    rewards: list[float] = []
    # Last successfully parsed belief; stays None if every step failed to parse.
    final_belief: list[float] | None = None
    for step_idx in range(MAX_STEPS):
        if obs.done:
            break
        user_prompt = format_observation_prompt(obs)
        try:
            teacher_resp = await call_teacher(client, deployment, user_prompt)
        except RuntimeError as e:
            # Hard failure — abort this episode rather than corrupt the dataset
            return step_rows, {
                "seed": seed,
                "profile_name": profile_name,
                "true_belief": [round(x, 3) for x in true_belief],
                "final_belief": [round(x, 3) for x in final_belief] if final_belief else None,
                "belief_mae": None,
                "final_score": 0.0,
                "total_reward": round(sum(rewards), 2),
                "n_steps": len(step_rows),
                "actions": actions_taken,
                "action_distribution": dict(Counter(actions_taken)),
                "n_parse_failures": sum(1 for r in step_rows if r["parse_failed"]),
                "aborted": True,
                "error": str(e),
            }
        action, belief, raw_match = parse_teacher_response(teacher_resp)
        parse_failed = action is None
        if parse_failed:
            # Fallback: SLEEP keeps the episode alive without skewing exploration
            action = ActionType.SLEEP
            belief = [0.5, 0.5, 0.5]
        else:
            final_belief = belief
        # Tell the env about the emitted belief so the grader's belief_accuracy
        # component scores it. Without this call, final_score logged below is
        # artificially low (belief component scores 0 even when the teacher
        # actually emitted a belief).
        env.record_belief(belief)
        rhythm_action = RhythmAction(action_type=action)
        actions_taken.append(action.value)
        next_obs = env.step(rhythm_action)
        rewards.append(next_obs.reward)
        step_rows.append({
            "seed": seed,
            "step": step_idx,
            "profile_name": profile_name,
            "user_prompt": user_prompt,
            "teacher_response": teacher_resp,
            "parsed_action": action.value,
            "parsed_belief": belief,
            "answer_match": raw_match,
            "env_reward": round(next_obs.reward, 4),
            "parse_failed": parse_failed,
            "true_belief": [round(x, 3) for x in true_belief],
        })
        obs = next_obs
    # `obs` is the terminal observation here; the grader's aggregate lives in
    # its reward_breakdown (0.0 if the key is absent).
    final_score = obs.reward_breakdown.get("final_score", 0.0)
    # Mean absolute error over the 3 belief dims, vs the last parsed belief.
    belief_mae = (
        sum(abs(b - t) for b, t in zip(final_belief, true_belief)) / 3.0
        if final_belief is not None else None
    )
    return step_rows, {
        "seed": seed,
        "profile_name": profile_name,
        "true_belief": [round(x, 3) for x in true_belief],
        "final_belief": [round(x, 3) for x in final_belief] if final_belief else None,
        "belief_mae": round(belief_mae, 4) if belief_mae is not None else None,
        "final_score": round(final_score, 4),
        "total_reward": round(sum(rewards), 2),
        "n_steps": len(step_rows),
        "actions": actions_taken,
        "action_distribution": dict(Counter(actions_taken)),
        "n_parse_failures": sum(1 for r in step_rows if r["parse_failed"]),
        "aborted": False,
    }
# ---------------------------------------------------------------------------
# Resume helpers
# ---------------------------------------------------------------------------
def already_completed_seeds(jsonl_path: Path) -> set[int]:
    """Return seeds whose final step (MAX_STEPS - 1) already exists in the file.

    Used to resume an interrupted run: fully-written episodes are skipped,
    partially-written ones are replayed. Unparseable lines and rows without
    a seed are ignored.
    """
    if not jsonl_path.exists():
        return set()
    highest_step: dict[int, int] = {}
    with open(jsonl_path) as f:
        for raw_line in f:
            try:
                record = json.loads(raw_line)
            except json.JSONDecodeError:
                continue  # tolerate a truncated trailing line
            seed = record.get("seed")
            if seed is None:
                continue
            step = record.get("step", -1)
            if step > highest_step.get(seed, -1):
                highest_step[seed] = step
    completed = set()
    for seed, top_step in highest_step.items():
        if top_step >= MAX_STEPS - 1:
            completed.add(seed)
    return completed
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def parse_seed_arg(seed_str: str) -> list[int]:
    """Parse a --seeds spec into a list of ints.

    Accepts an inclusive range 'LO-HI' (e.g. '0-29'), a comma list '0,1,5',
    or — generalized — any mix of the two ('0-2,5,10-12'), which previously
    raised ValueError. Blank entries and surrounding whitespace are ignored.
    """
    seeds: list[int] = []
    for part in seed_str.split(","):
        part = part.strip()
        if not part:
            continue
        lo, sep, hi = part.partition("-")
        if sep and lo and hi:
            # 'LO-HI' range, inclusive of both endpoints.
            seeds.extend(range(int(lo), int(hi) + 1))
        else:
            seeds.append(int(part))
    return seeds
async def main() -> None:
parser = argparse.ArgumentParser(description=__doc__.split("\n\n")[0])
parser.add_argument("--seeds", type=str, required=True,
help="Seed range '0-29' or comma list '0,1,5'")
parser.add_argument("--output", type=str, required=True,
help="Output JSONL path for per-step trajectories")
parser.add_argument("--summary", type=str, default=None,
help="Output JSON path for episode summaries (default: