# rhythm_env/scripts/generate_teacher_trajectories.py
# (Hugging Face viewer header — author: InosLihka; commit ece0bbe,
#  "Algorithm Distillation: grader v2 with belief_accuracy + SFT pipeline")
"""
Generate teacher trajectories for Algorithm Distillation.
For each seed, plays one full RhythmEnv episode where the action at each step
is chosen by a teacher LLM (gpt-5.4 via Azure OpenAI). The teacher is prompted
to emit `<reasoning>...</reasoning>` followed by `S M W ACTION_NAME` on a final
line. We parse the answer line, step the env, save the full (prompt, response,
action, reward) tuple to JSONL, and aggregate per-episode metrics for gating.
Required env vars (no secrets in code):
AZURE_OPENAI_ENDPOINT e.g. https://metahackathon-resource.cognitiveservices.azure.com/
AZURE_OPENAI_API_KEY your Azure OpenAI key (do NOT paste in chat)
AZURE_OPENAI_DEPLOYMENT the deployment name you chose, e.g. gpt-5.4
AZURE_OPENAI_API_VERSION e.g. 2024-12-01-preview (default if unset)
Usage from rhythm_env root:
# Stage 1a: 30-episode validation (~$3-5)
python scripts/generate_teacher_trajectories.py \
--seeds 0-29 \
--output data/teacher_30ep_validation.jsonl \
--concurrency 3
# Stage 1b: scale to 150 episodes (~$15-20)
python scripts/generate_teacher_trajectories.py \
--seeds 0-99 \
--output data/teacher_150ep_indist.jsonl \
--concurrency 5
python scripts/generate_teacher_trajectories.py \
--seeds 10000-10049 \
--output data/teacher_150ep_ood.jsonl \
--concurrency 5
The script prints PASS/FAIL gate verdicts at the end so you can decide whether
to scale or fix the teacher prompt before spending more.
"""
import argparse
import asyncio
import json
import os
import re
import sys
import time
from collections import Counter
from pathlib import Path
# Make the repo root importable when this script is run directly: scripts/
# lives one level below the root, so prepend the parent of this file's dir.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Load .env (repo root) before reading os.environ so credentials don't have
# to be exported in the shell. The .env file is in .gitignore.
try:
    from dotenv import load_dotenv

    # Resolve the .env that sits in the repo root (parent of scripts/).
    _ENV_PATH = Path(__file__).resolve().parent.parent / ".env"
    if _ENV_PATH.exists():
        load_dotenv(_ENV_PATH)
except ImportError:
    pass  # dotenv not installed → fall back to whatever's in the shell
from openai import AsyncAzureOpenAI
from openai import APIError, RateLimitError, APIConnectionError, APITimeoutError
from models import ActionType, RhythmAction
from server.rhythm_environment import (
MAX_STEPS,
RhythmEnvironment,
)
from training.dataset import format_observation_prompt
# ---------------------------------------------------------------------------
# Teacher system prompt
# ---------------------------------------------------------------------------
# The student will eventually be SFT'd to match this contract: emit a
# <reasoning>...</reasoning> block then a final answer line `S M W ACTION_NAME`.
# Keep this in sync with whatever SYSTEM_PROMPT the SFT'd student will use,
# and with ANSWER_PATTERN below, which parses the answer line this prompt
# mandates. The action-name list must match the ActionType enum values.
TEACHER_SYSTEM_PROMPT = """You are a life-management agent helping a person whose preferences are HIDDEN.
You see 5 life meters and a rolling history of recent steps. The same action
affects different people differently — you must INFER who you're helping from
rewards, meter changes, and per-meter ANOMALY signals.
Each step, do TWO things:
1. Reason briefly about what the observations imply about the person.
Focus on:
- Anomalies (actual delta vs neutral-profile expectation): big positive
social_serenity / connection responses → high S; big morning cognition
gains → high M; productive work giving vitality back → high W
- Current meter state: any meter under 0.15 needs urgent recovery
- What action best fits BOTH the inferred profile and the current state
2. Output your final answer on the LAST line in this exact format:
S M W ACTION_NAME
where S, M, W are belief digits 0-9 (0=low, 9=high) representing your best
estimate of social_pref, morning_pref, work_pref. ACTION_NAME is one of:
DEEP_WORK, ADMIN_WORK, LEARN, SLEEP, EXERCISE, MEDITATE, FAMILY_TIME,
SOCIALIZE, ME_TIME, BINGE_WATCH
Wrap your reasoning in <reasoning>...</reasoning> tags. Keep reasoning under
120 tokens. The final answer line MUST be the last line of your response.
Belief→action quick reference:
- High S (extrovert): SOCIALIZE, FAMILY_TIME boost connection cheaply
- High M (morning person): DEEP_WORK / LEARN in early slots gets bonus cognition
- High W (workaholic): DEEP_WORK, LEARN drive progress and may energize
- Low S (introvert): MEDITATE, ME_TIME for solo recharge; avoid SOCIALIZE
- Low M (night owl): DEEP_WORK / LEARN in evening/night slots
- Watch crashes: any meter under 0.10 = -0.30 penalty per crashed meter
- Connection decays passively — actively maintain via SOCIALIZE/FAMILY_TIME
- Don't repeat the same action 3+ times in a row — repetition penalty applies
Strategy: probe varied actions in the first ~5 steps to gather profile evidence,
then exploit your sharpened belief by picking actions that match the inferred
profile + current meter state.
Example output:
<reasoning>
Last step's socialize gave V-0.12 (anom -0.06, much worse than neutral) — high
social drain, suggests low S. Morning DEEP_WORK earlier gave bonus cognition
(anom +0.04) → high M. Vitality at 0.6 still ok, serenity dropping. With low S +
high M, MEDITATE is the recovery play that fits.
</reasoning>
2 8 5 MEDITATE"""
# ---------------------------------------------------------------------------
# Answer parsing — find the LAST `S M W ACTION_NAME` pattern in the response
# ---------------------------------------------------------------------------
# Upper-cased ActionType values, used to build the alternation in the regex.
VALID_ACTIONS = [at.value.upper() for at in ActionType]
# Matches `S M W ACTION_NAME`: three single digits separated by whitespace,
# then a known action name. Case-insensitive so lowercase answers still parse;
# the trailing \b stops partial matches (e.g. a prefix of a longer word).
ANSWER_PATTERN = re.compile(
    r'(\d)\s+(\d)\s+(\d)\s+(' + '|'.join(VALID_ACTIONS) + r')\b',
    re.IGNORECASE,
)
def parse_teacher_response(text: str):
    """Extract (action_type, belief_vector, raw_match) from teacher output.

    Scans for every `S M W ACTION_NAME` occurrence and keeps only the last
    one, since the contract puts the answer on the final line. Belief digits
    0-9 are normalized to floats in [0, 1] by dividing by 9.

    Returns (None, None, None) if no answer line is parseable.
    """
    no_parse = (None, None, None)
    if not text:
        return no_parse

    # Walk all matches; keep whichever comes last in the response.
    last_match = None
    for last_match in ANSWER_PATTERN.finditer(text):
        pass
    if last_match is None:
        return no_parse

    s_digit, m_digit, w_digit, action_name = last_match.groups()
    try:
        belief_vec = [int(d) / 9.0 for d in (s_digit, m_digit, w_digit)]
        return ActionType(action_name.lower()), belief_vec, last_match.group(0)
    except (ValueError, KeyError):
        # Unknown action name (shouldn't happen given the regex) → unparseable.
        return no_parse
# ---------------------------------------------------------------------------
# Async API calls with retry
# ---------------------------------------------------------------------------
async def call_teacher(
    client: AsyncAzureOpenAI,
    deployment: str,
    user_prompt: str,
    temperature: float = 0.5,
    max_completion_tokens: int = 400,
    max_retries: int = 4,
) -> str:
    """Call the teacher with retries on transient errors. Returns response text.

    Args:
        client: Configured async Azure OpenAI client.
        deployment: Azure deployment name to route the request to.
        user_prompt: Per-step observation prompt (user message).
        temperature: Sampling temperature for the teacher.
        max_completion_tokens: Response token budget.
        max_retries: Max attempts for transient errors (rate limit,
            connection, timeout) before giving up.

    Raises:
        RuntimeError: after exhausting retries, or immediately on a
            non-transient APIError. The underlying exception is chained.
    """
    last_err: Exception | None = None
    attempts = 0
    for attempt in range(max_retries):
        attempts = attempt + 1
        try:
            resp = await client.chat.completions.create(
                model=deployment,
                messages=[
                    {"role": "system", "content": TEACHER_SYSTEM_PROMPT},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=temperature,
                max_completion_tokens=max_completion_tokens,
            )
            # API may return None content; normalize to "" for the caller.
            return resp.choices[0].message.content or ""
        except (RateLimitError, APIConnectionError, APITimeoutError) as e:
            last_err = e
            # Exponential backoff capped at 60s. Bug fix: don't sleep after
            # the final attempt — the loop is about to give up anyway, so the
            # old code wasted up to 8s before raising.
            if attempt < max_retries - 1:
                await asyncio.sleep(min(60, 2 ** attempt))
        except APIError as e:
            # Non-transient API error — bail immediately (don't waste retries)
            last_err = e
            break
    # Bug fix: report the real attempt count (the old message always claimed
    # max_retries even when we broke out early), and chain the cause.
    raise RuntimeError(
        f"Teacher call failed after {attempts} attempt(s): {last_err}"
    ) from last_err
# ---------------------------------------------------------------------------
# Episode rollout
# ---------------------------------------------------------------------------
def _episode_summary(
    seed: int,
    profile_name: str,
    true_belief: list[float],
    final_belief: list[float] | None,
    rewards: list[float],
    step_rows: list[dict],
    actions_taken: list[str],
    *,
    final_score: float = 0.0,
    belief_mae: float | None = None,
    aborted: bool = False,
    error: str | None = None,
) -> dict:
    """Build the per-episode summary dict shared by the success and abort paths.

    Previously this dict was duplicated verbatim in both return paths of
    play_episode, which risked the two copies drifting apart.
    """
    summary = {
        "seed": seed,
        "profile_name": profile_name,
        "true_belief": [round(x, 3) for x in true_belief],
        "final_belief": [round(x, 3) for x in final_belief] if final_belief else None,
        "belief_mae": round(belief_mae, 4) if belief_mae is not None else None,
        "final_score": round(final_score, 4),
        "total_reward": round(sum(rewards), 2),
        "n_steps": len(step_rows),
        "actions": actions_taken,
        "action_distribution": dict(Counter(actions_taken)),
        "n_parse_failures": sum(1 for r in step_rows if r["parse_failed"]),
        "aborted": aborted,
    }
    if error is not None:
        summary["error"] = error
    return summary


async def play_episode(
    client: AsyncAzureOpenAI,
    deployment: str,
    seed: int,
) -> tuple[list[dict], dict]:
    """Run a full episode with the teacher. Returns (per-step rows, summary).

    Each step: format the observation into a prompt, query the teacher, parse
    its `S M W ACTION_NAME` answer, report the belief to the env, step the env,
    and record the full (prompt, response, action, reward) row for SFT data.
    """
    env = RhythmEnvironment()
    obs = env.reset(seed=seed)
    true_belief = env.get_belief_target()
    profile_name = env.state.profile_name

    step_rows: list[dict] = []
    actions_taken: list[str] = []
    rewards: list[float] = []
    final_belief: list[float] | None = None  # last successfully parsed belief

    for step_idx in range(MAX_STEPS):
        if obs.done:
            break
        user_prompt = format_observation_prompt(obs)
        try:
            teacher_resp = await call_teacher(client, deployment, user_prompt)
        except RuntimeError as e:
            # Hard failure — abort this episode rather than corrupt the dataset
            return step_rows, _episode_summary(
                seed, profile_name, true_belief, final_belief,
                rewards, step_rows, actions_taken,
                aborted=True, error=str(e),
            )

        action, belief, raw_match = parse_teacher_response(teacher_resp)
        parse_failed = action is None
        if parse_failed:
            # Fallback: SLEEP keeps the episode alive without skewing exploration
            action = ActionType.SLEEP
            belief = [0.5, 0.5, 0.5]
        else:
            final_belief = belief
            # Tell the env about the emitted belief so the grader's belief_accuracy
            # component scores it. Without this call, final_score logged below is
            # artificially low (belief component scores 0 even when the teacher
            # actually emitted a belief).
            env.record_belief(belief)

        rhythm_action = RhythmAction(action_type=action)
        actions_taken.append(action.value)
        next_obs = env.step(rhythm_action)
        rewards.append(next_obs.reward)
        step_rows.append({
            "seed": seed,
            "step": step_idx,
            "profile_name": profile_name,
            "user_prompt": user_prompt,
            "teacher_response": teacher_resp,
            "parsed_action": action.value,
            "parsed_belief": belief,
            "answer_match": raw_match,
            "env_reward": round(next_obs.reward, 4),
            "parse_failed": parse_failed,
            "true_belief": [round(x, 3) for x in true_belief],
        })
        obs = next_obs

    # Grader's final score comes from the terminal observation's breakdown.
    final_score = obs.reward_breakdown.get("final_score", 0.0)
    # Mean absolute error of the last emitted belief vs the hidden profile.
    belief_mae = (
        sum(abs(b - t) for b, t in zip(final_belief, true_belief)) / 3.0
        if final_belief is not None else None
    )
    return step_rows, _episode_summary(
        seed, profile_name, true_belief, final_belief,
        rewards, step_rows, actions_taken,
        final_score=final_score, belief_mae=belief_mae,
    )
# ---------------------------------------------------------------------------
# Resume helpers
# ---------------------------------------------------------------------------
def already_completed_seeds(jsonl_path: Path) -> set[int]:
    """Seeds whose final step (MAX_STEPS - 1 = 27) is already in the file."""
    if not jsonl_path.exists():
        return set()

    # Track the highest step index seen per seed across the JSONL rows.
    last_step_by_seed: dict[int, int] = {}
    with open(jsonl_path) as fh:
        for raw_line in fh:
            try:
                record = json.loads(raw_line)
            except json.JSONDecodeError:
                continue  # tolerate a torn line from an interrupted run
            seed = record.get("seed")
            if seed is None:
                continue
            step = record.get("step", -1)
            if step > last_step_by_seed.get(seed, -1):
                last_step_by_seed[seed] = step

    # A seed is complete iff its trajectory reaches the final step index.
    return {
        seed
        for seed, last_step in last_step_by_seed.items()
        if last_step >= MAX_STEPS - 1
    }
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def parse_seed_arg(seed_str: str) -> list[int]:
    """Parse a seed specification into a list of ints.

    Accepts a range ('0-29'), a comma list ('0,1,5'), and — generalized from
    the original, which raised ValueError here — any comma-separated mix of
    the two ('0-5,10,20-22'). Ranges are inclusive on both ends. Whitespace
    around items is ignored; empty items are skipped.
    """
    seeds: list[int] = []
    for part in seed_str.split(","):
        part = part.strip()
        if not part:
            continue
        if "-" in part:
            lo, hi = part.split("-", 1)
            seeds.extend(range(int(lo), int(hi) + 1))
        else:
            seeds.append(int(part))
    return seeds
async def main() -> None:
    """CLI entry point.

    Parses args, resumes past partial runs, rolls out one teacher episode per
    seed (bounded concurrency), appends per-step rows to the output JSONL,
    writes an aggregate summary JSON, and prints PASS/FAIL validation gates.
    """
    parser = argparse.ArgumentParser(description=__doc__.split("\n\n")[0])
    parser.add_argument("--seeds", type=str, required=True,
                        help="Seed range '0-29' or comma list '0,1,5'")
    parser.add_argument("--output", type=str, required=True,
                        help="Output JSONL path for per-step trajectories")
    parser.add_argument("--summary", type=str, default=None,
                        help="Output JSON path for episode summaries (default: <output>.summary.json)")
    parser.add_argument("--concurrency", type=int, default=3,
                        help="Episodes to run concurrently (default 3; 1500 RPM allows up to ~5)")
    parser.add_argument("--temperature", type=float, default=0.5,
                        help="Teacher sampling temperature (default 0.5; lower = more consistent)")
    parser.add_argument("--no-resume", action="store_true",
                        help="Do not skip seeds already in the output file")
    args = parser.parse_args()

    seeds = parse_seed_arg(args.seeds)
    output_path = Path(args.output)
    summary_path = Path(args.summary) if args.summary else output_path.with_suffix(".summary.json")
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Resume: drop seeds whose full trajectory is already on disk.
    if not args.no_resume:
        completed = already_completed_seeds(output_path)
        if completed:
            print(f"Resume: {len(completed)} seeds already complete; "
                  f"{len(seeds) - len(completed & set(seeds))} remaining of {len(seeds)}")
            seeds = [s for s in seeds if s not in completed]

    # Azure config (read from env so secrets never touch the repo)
    try:
        endpoint = os.environ["AZURE_OPENAI_ENDPOINT"]
        api_key = os.environ["AZURE_OPENAI_API_KEY"]
        deployment = os.environ["AZURE_OPENAI_DEPLOYMENT"]
    except KeyError as e:
        sys.exit(f"ERROR: missing env var {e}. Set AZURE_OPENAI_ENDPOINT, "
                 f"AZURE_OPENAI_API_KEY, AZURE_OPENAI_DEPLOYMENT.")
    api_version = os.environ.get("AZURE_OPENAI_API_VERSION", "2024-12-01-preview")

    print(f"Endpoint: {endpoint}")
    print(f"Deployment: {deployment}")
    print(f"API version: {api_version}")
    print(f"Seeds: {len(seeds)} (concurrency={args.concurrency}, temp={args.temperature})")
    print(f"Output: {output_path}")
    print(f"Summary: {summary_path}")
    print()
    if not seeds:
        print("No seeds to process. Exiting.")
        return

    client = AsyncAzureOpenAI(
        azure_endpoint=endpoint,
        api_key=api_key,
        api_version=api_version,
    )
    # Semaphore bounds concurrent episodes (each one is a chain of API calls);
    # the lock serializes file appends so rows from episodes don't interleave.
    sem = asyncio.Semaphore(args.concurrency)
    file_lock = asyncio.Lock()
    # NOTE: removed dead `summaries: list[dict] = []` initializer — it was
    # unconditionally overwritten after asyncio.gather below.

    async def run_one(seed: int) -> dict | None:
        """Roll out one episode, append its rows to the JSONL, return summary."""
        async with sem:
            t0 = time.time()
            print(f" [seed {seed}] starting", flush=True)
            try:
                step_rows, summary = await play_episode(client, deployment, seed)
            except Exception as e:
                # Keep the batch alive — one crashed episode shouldn't kill the run.
                print(f" [seed {seed}] CRASHED: {e}", flush=True)
                return None
            # Append per-step rows atomically (prevents interleaved writes)
            async with file_lock:
                with open(output_path, "a") as f:
                    for row in step_rows:
                        f.write(json.dumps(row) + "\n")
            dt = time.time() - t0
            mae_str = f"{summary['belief_mae']:.3f}" if summary['belief_mae'] is not None else "n/a"
            print(f" [seed {seed}] done in {dt:.1f}s: "
                  f"final={summary['final_score']:.3f} mae={mae_str} "
                  f"unique_actions={len(summary['action_distribution'])} "
                  f"parse_fails={summary['n_parse_failures']}", flush=True)
            return summary

    tasks = [run_one(s) for s in seeds]
    results = await asyncio.gather(*tasks)
    summaries = [r for r in results if r is not None]

    # Merge with any prior summaries (for resume)
    prior_summaries: list[dict] = []
    if summary_path.exists() and not args.no_resume:
        try:
            with open(summary_path) as f:
                prior_summaries = json.load(f).get("episodes", [])
        except (json.JSONDecodeError, KeyError):
            prior_summaries = []
    # Fresh results win; prior summaries fill in seeds not re-run this time.
    seen = {s["seed"] for s in summaries}
    summaries = summaries + [s for s in prior_summaries if s["seed"] not in seen]

    # Aggregate (aborted episodes are counted but excluded from the averages)
    n = len(summaries)
    if n == 0:
        print("No episodes completed.")
        return
    valid = [s for s in summaries if not s.get("aborted")]
    avg_score = sum(s["final_score"] for s in valid) / max(len(valid), 1)
    valid_mae = [s["belief_mae"] for s in valid if s["belief_mae"] is not None]
    avg_mae = sum(valid_mae) / len(valid_mae) if valid_mae else None
    all_actions: Counter = Counter()
    for s in valid:
        all_actions.update(s["action_distribution"])
    n_unique = len(all_actions)
    n_parse_fails = sum(s["n_parse_failures"] for s in valid)
    n_aborted = sum(1 for s in summaries if s.get("aborted"))

    summary_blob = {
        "n_episodes": n,
        "n_aborted": n_aborted,
        "avg_final_score": round(avg_score, 4),
        "avg_belief_mae": round(avg_mae, 4) if avg_mae is not None else None,
        "n_unique_actions_overall": n_unique,
        "action_distribution_overall": dict(all_actions),
        "n_parse_failures_total": n_parse_fails,
        "deployment": deployment,
        "api_version": api_version,
        "episodes": summaries,
    }
    with open(summary_path, "w") as f:
        json.dump(summary_blob, f, indent=2)

    # Gates — thresholds that decide whether it's safe to scale the batch.
    BAR_HEURISTIC = 0.587
    BAR_GATE_SCORE = 0.65
    BAR_GATE_MAE = 0.20
    BAR_GATE_ACTIONS = 6
    print()
    print("=" * 72)
    print("BATCH SUMMARY")
    print("=" * 72)
    print(f"Episodes completed: {n} (aborted: {n_aborted})")
    print(f"Avg final_score: {avg_score:.4f} "
          f"(heuristic baseline: {BAR_HEURISTIC}, random: 0.516)")
    if avg_mae is not None:
        print(f"Avg belief MAE: {avg_mae:.4f} (lower is better)")
    print(f"Unique actions: {n_unique} of 10")
    print(f"Parse failures: {n_parse_fails} (across all step calls)")
    print()
    print("VALIDATION GATES:")
    g_score = avg_score >= BAR_GATE_SCORE
    g_mae = avg_mae is not None and avg_mae < BAR_GATE_MAE
    g_actions = n_unique >= BAR_GATE_ACTIONS
    g_parse = n_parse_fails < 0.05 * n * MAX_STEPS  # < 5% parse failure rate
    print(f" [{'PASS' if g_score else 'FAIL'}] avg_final_score >= {BAR_GATE_SCORE}: "
          f"{avg_score:.3f}")
    mae_disp = f"{avg_mae:.3f}" if avg_mae is not None else "n/a"
    print(f" [{'PASS' if g_mae else 'FAIL'}] avg_belief_mae < {BAR_GATE_MAE}: {mae_disp}")
    print(f" [{'PASS' if g_actions else 'FAIL'}] unique_actions >= {BAR_GATE_ACTIONS}: "
          f"{n_unique}")
    print(f" [{'PASS' if g_parse else 'FAIL'}] parse_failures < 5% of calls: "
          f"{n_parse_fails}/{n * MAX_STEPS}")
    print()
    if g_score and g_mae and g_actions and g_parse:
        print("ALL GATES PASS — safe to scale to production batch.")
    else:
        print("ONE OR MORE GATES FAILED — investigate before scaling.")
        if not g_score:
            print(" -> Teacher quality too low. Consider escalating model "
                  "(e.g. gpt-5-pro) or refining the prompt.")
        if not g_mae:
            print(" -> Teacher's beliefs aren't tracking the true profile. "
                  "Check anomaly visibility in observation prompt.")
        if not g_actions:
            print(" -> Teacher converged on a narrow action set. Encourage "
                  "exploration in the prompt.")
        if not g_parse:
            print(" -> Many responses didn't end with the answer pattern. "
                  "Strengthen format instruction in the system prompt.")
# Script entry point: run the async driver on a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())