"""
Generate teacher trajectories for Algorithm Distillation.
For each seed, plays one full RhythmEnv episode where the action at each step
is chosen by a teacher LLM (gpt-5.4 via Azure OpenAI). The teacher is prompted
to emit `<reasoning>...</reasoning>` followed by `S M W ACTION_NAME` on a final
line. We parse the answer line, step the env, save the full (prompt, response,
action, reward) tuple to JSONL, and aggregate per-episode metrics for gating.
Required env vars (no secrets in code):
AZURE_OPENAI_ENDPOINT e.g. https://metahackathon-resource.cognitiveservices.azure.com/
AZURE_OPENAI_API_KEY your Azure OpenAI key (do NOT paste in chat)
AZURE_OPENAI_DEPLOYMENT the deployment name you chose, e.g. gpt-5.4
AZURE_OPENAI_API_VERSION e.g. 2024-12-01-preview (default if unset)
Usage from rhythm_env root:
# Stage 1a: 30-episode validation (~$3-5)
python scripts/generate_teacher_trajectories.py \
--seeds 0-29 \
--output data/teacher_30ep_validation.jsonl \
--concurrency 3
# Stage 1b: scale to 150 episodes (~$15-20)
python scripts/generate_teacher_trajectories.py \
--seeds 0-99 \
--output data/teacher_150ep_indist.jsonl \
--concurrency 5
python scripts/generate_teacher_trajectories.py \
--seeds 10000-10049 \
--output data/teacher_150ep_ood.jsonl \
--concurrency 5
The script prints PASS/FAIL gate verdicts at the end so you can decide whether
to scale or fix the teacher prompt before spending more.
"""
import argparse
import asyncio
import json
import os
import re
import sys
import time
from collections import Counter
from pathlib import Path
# Make the repo root importable when this script is run directly from scripts/.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Load .env (repo root) before reading os.environ so credentials don't have
# to be exported in the shell. The .env file is in .gitignore.
try:
    from dotenv import load_dotenv

    # Repo root is one level above scripts/ (the directory holding this file).
    _ENV_PATH = Path(__file__).resolve().parent.parent / ".env"
    if _ENV_PATH.exists():
        load_dotenv(_ENV_PATH)
except ImportError:
    pass  # dotenv not installed → fall back to whatever's in the shell
from openai import AsyncAzureOpenAI
from openai import APIError, RateLimitError, APIConnectionError, APITimeoutError
from models import ActionType, RhythmAction
from server.rhythm_environment import (
MAX_STEPS,
RhythmEnvironment,
)
from training.dataset import format_observation_prompt
# ---------------------------------------------------------------------------
# Teacher system prompt
# ---------------------------------------------------------------------------
# The student will eventually be SFT'd to match this contract: emit a
# <reasoning>...</reasoning> block then a final answer line `S M W ACTION_NAME`.
# Keep this in sync with whatever SYSTEM_PROMPT the SFT'd student will use.
# NOTE: this is a runtime string sent to the model verbatim — do not reflow it.
TEACHER_SYSTEM_PROMPT: str = """You are a life-management agent helping a person whose preferences are HIDDEN.
You see 5 life meters and a rolling history of recent steps. The same action
affects different people differently — you must INFER who you're helping from
rewards, meter changes, and per-meter ANOMALY signals.
Each step, do TWO things:
1. Reason briefly about what the observations imply about the person.
Focus on:
- Anomalies (actual delta vs neutral-profile expectation): big positive
social_serenity / connection responses → high S; big morning cognition
gains → high M; productive work giving vitality back → high W
- Current meter state: any meter under 0.15 needs urgent recovery
- What action best fits BOTH the inferred profile and the current state
2. Output your final answer on the LAST line in this exact format:
S M W ACTION_NAME
where S, M, W are belief digits 0-9 (0=low, 9=high) representing your best
estimate of social_pref, morning_pref, work_pref. ACTION_NAME is one of:
DEEP_WORK, ADMIN_WORK, LEARN, SLEEP, EXERCISE, MEDITATE, FAMILY_TIME,
SOCIALIZE, ME_TIME, BINGE_WATCH
Wrap your reasoning in <reasoning>...</reasoning> tags. Keep reasoning under
120 tokens. The final answer line MUST be the last line of your response.
Belief→action quick reference:
- High S (extrovert): SOCIALIZE, FAMILY_TIME boost connection cheaply
- High M (morning person): DEEP_WORK / LEARN in early slots gets bonus cognition
- High W (workaholic): DEEP_WORK, LEARN drive progress and may energize
- Low S (introvert): MEDITATE, ME_TIME for solo recharge; avoid SOCIALIZE
- Low M (night owl): DEEP_WORK / LEARN in evening/night slots
- Watch crashes: any meter under 0.10 = -0.30 penalty per crashed meter
- Connection decays passively — actively maintain via SOCIALIZE/FAMILY_TIME
- Don't repeat the same action 3+ times in a row — repetition penalty applies
Strategy: probe varied actions in the first ~5 steps to gather profile evidence,
then exploit your sharpened belief by picking actions that match the inferred
profile + current meter state.
Example output:
<reasoning>
Last step's socialize gave V-0.12 (anom -0.06, much worse than neutral) — high
social drain, suggests low S. Morning DEEP_WORK earlier gave bonus cognition
(anom +0.04) → high M. Vitality at 0.6 still ok, serenity dropping. With low S +
high M, MEDITATE is the recovery play that fits.
</reasoning>
2 8 5 MEDITATE"""
# ---------------------------------------------------------------------------
# Answer parsing — find the LAST `S M W ACTION_NAME` pattern in the response
# ---------------------------------------------------------------------------
# Upper-cased names of every ActionType, used to build the answer-line regex.
VALID_ACTIONS = [at.value.upper() for at in ActionType]

# Matches `S M W ACTION_NAME`: three single belief digits then an action name.
# Case-insensitive; callers use finditer() and keep the LAST match so digits
# appearing earlier in the <reasoning> block cannot shadow the answer line.
ANSWER_PATTERN = re.compile(
    r'(\d)\s+(\d)\s+(\d)\s+(' + '|'.join(VALID_ACTIONS) + r')\b',
    re.IGNORECASE,
)
def parse_teacher_response(text: str):
    """Extract (action_type, belief_vector, raw_match) from teacher output.

    Scans for every `S M W ACTION_NAME` occurrence and keeps only the final
    one (the contract puts the answer on the last line). Returns
    (None, None, None) when the text is empty or no answer is parseable.
    """
    if not text:
        return None, None, None
    answer = None
    for answer in ANSWER_PATTERN.finditer(text):
        pass  # keep iterating; we only care about the last hit
    if answer is None:
        return None, None, None
    s_digit, m_digit, w_digit, action_name = answer.groups()
    try:
        action = ActionType(action_name.lower())
        belief = [int(d) / 9.0 for d in (s_digit, m_digit, w_digit)]
    except (ValueError, KeyError):
        return None, None, None
    return action, belief, answer.group(0)
# ---------------------------------------------------------------------------
# Async API calls with retry
# ---------------------------------------------------------------------------
async def call_teacher(
    client: AsyncAzureOpenAI,
    deployment: str,
    user_prompt: str,
    temperature: float = 0.5,
    max_completion_tokens: int = 400,
    max_retries: int = 4,
) -> str:
    """Call the teacher with retries on transient errors. Returns response text.

    Transient failures (RateLimitError, APIConnectionError, APITimeoutError)
    are retried with capped exponential backoff (1s, 2s, 4s, ... max 60s).
    Any other APIError is treated as non-transient and aborts immediately.

    Raises:
        RuntimeError: once retries are exhausted or on a non-transient error,
            chained to the underlying exception for debuggability.
    """
    last_err: Exception | None = None
    for attempt in range(max_retries):
        try:
            resp = await client.chat.completions.create(
                model=deployment,
                messages=[
                    {"role": "system", "content": TEACHER_SYSTEM_PROMPT},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=temperature,
                max_completion_tokens=max_completion_tokens,
            )
            return resp.choices[0].message.content or ""
        except (RateLimitError, APIConnectionError, APITimeoutError) as e:
            last_err = e
            # Bug fix: the old code slept (up to 60s) even after the FINAL
            # attempt, only to raise immediately afterwards. Skip the backoff
            # when no retry will follow.
            if attempt < max_retries - 1:
                await asyncio.sleep(min(60, 2 ** attempt))
        except APIError as e:
            # Non-transient API error — bail out rather than waste retries.
            last_err = e
            break
    # Chain the original exception so the stack trace isn't lost.
    raise RuntimeError(
        f"Teacher call failed after {max_retries} retries: {last_err}"
    ) from last_err
# ---------------------------------------------------------------------------
# Episode rollout
# ---------------------------------------------------------------------------
def _episode_summary(
    *,
    seed: int,
    profile_name: str,
    true_belief: list[float],
    final_belief: list[float] | None,
    final_score: float,
    rewards: list[float],
    step_rows: list[dict],
    actions_taken: list[str],
    aborted: bool,
    error: str | None = None,
) -> dict:
    """Build the per-episode summary dict shared by the abort and success paths.

    The old code duplicated this dict literal in both paths, and the abort
    path hard-coded belief_mae to None even when the teacher had already
    emitted a belief; here the MAE is computed whenever final_belief exists.
    """
    belief_mae = (
        sum(abs(b - t) for b, t in zip(final_belief, true_belief)) / 3.0
        if final_belief is not None else None
    )
    summary = {
        "seed": seed,
        "profile_name": profile_name,
        "true_belief": [round(x, 3) for x in true_belief],
        "final_belief": [round(x, 3) for x in final_belief] if final_belief else None,
        "belief_mae": round(belief_mae, 4) if belief_mae is not None else None,
        "final_score": round(final_score, 4),
        "total_reward": round(sum(rewards), 2),
        "n_steps": len(step_rows),
        "actions": actions_taken,
        "action_distribution": dict(Counter(actions_taken)),
        "n_parse_failures": sum(1 for r in step_rows if r["parse_failed"]),
        "aborted": aborted,
    }
    if error is not None:
        summary["error"] = error
    return summary


async def play_episode(
    client: AsyncAzureOpenAI,
    deployment: str,
    seed: int,
) -> tuple[list[dict], dict]:
    """Run a full episode with the teacher. Returns (per-step rows, summary)."""
    env = RhythmEnvironment()
    obs = env.reset(seed=seed)
    true_belief = env.get_belief_target()
    profile_name = env.state.profile_name
    step_rows: list[dict] = []
    actions_taken: list[str] = []
    rewards: list[float] = []
    final_belief: list[float] | None = None
    for step_idx in range(MAX_STEPS):
        if obs.done:
            break
        user_prompt = format_observation_prompt(obs)
        try:
            teacher_resp = await call_teacher(client, deployment, user_prompt)
        except RuntimeError as e:
            # Hard failure — abort this episode rather than corrupt the dataset.
            return step_rows, _episode_summary(
                seed=seed,
                profile_name=profile_name,
                true_belief=true_belief,
                final_belief=final_belief,
                final_score=0.0,
                rewards=rewards,
                step_rows=step_rows,
                actions_taken=actions_taken,
                aborted=True,
                error=str(e),
            )
        action, belief, raw_match = parse_teacher_response(teacher_resp)
        parse_failed = action is None
        if parse_failed:
            # Fallback: SLEEP keeps the episode alive without skewing exploration
            action = ActionType.SLEEP
            belief = [0.5, 0.5, 0.5]
        else:
            final_belief = belief
        # Tell the env about the emitted belief so the grader's belief_accuracy
        # component scores it. Without this call, final_score logged below is
        # artificially low (belief component scores 0 even when the teacher
        # actually emitted a belief).
        env.record_belief(belief)
        rhythm_action = RhythmAction(action_type=action)
        actions_taken.append(action.value)
        next_obs = env.step(rhythm_action)
        rewards.append(next_obs.reward)
        step_rows.append({
            "seed": seed,
            "step": step_idx,
            "profile_name": profile_name,
            "user_prompt": user_prompt,
            "teacher_response": teacher_resp,
            "parsed_action": action.value,
            "parsed_belief": belief,
            "answer_match": raw_match,
            "env_reward": round(next_obs.reward, 4),
            "parse_failed": parse_failed,
            "true_belief": [round(x, 3) for x in true_belief],
        })
        obs = next_obs
    # NOTE(review): assumes env.reset()/step() observations always carry
    # reward_breakdown — confirm against RhythmEnvironment.
    final_score = obs.reward_breakdown.get("final_score", 0.0)
    return step_rows, _episode_summary(
        seed=seed,
        profile_name=profile_name,
        true_belief=true_belief,
        final_belief=final_belief,
        final_score=final_score,
        rewards=rewards,
        step_rows=step_rows,
        actions_taken=actions_taken,
        aborted=False,
    )
# ---------------------------------------------------------------------------
# Resume helpers
# ---------------------------------------------------------------------------
def already_completed_seeds(jsonl_path: Path) -> set[int]:
    """Seeds whose final step (MAX_STEPS - 1 = 27) is already in the file."""
    if not jsonl_path.exists():
        return set()
    # Track the highest step index observed per seed; malformed JSON lines
    # and rows without a seed field are simply skipped.
    highest_step: dict[int, int] = {}
    with open(jsonl_path) as fh:
        for raw_line in fh:
            try:
                record = json.loads(raw_line)
            except json.JSONDecodeError:
                continue
            seed = record.get("seed")
            if seed is None:
                continue
            step = record.get("step", -1)
            if step > highest_step.get(seed, -1):
                highest_step[seed] = step
            else:
                highest_step.setdefault(seed, highest_step.get(seed, -1))
    done: set[int] = set()
    for seed, top_step in highest_step.items():
        if top_step >= MAX_STEPS - 1:
            done.add(seed)
    return done
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def parse_seed_arg(seed_str: str) -> list[int]:
    """Parse the --seeds argument into a list of ints.

    Accepts an inclusive range ('0-29'), a comma list ('0,1,5'), or — as a
    backward-compatible generalization — any comma-separated mix of the two
    ('0-5,9,20-22'). Whitespace around tokens is ignored and empty tokens
    (e.g. a trailing comma) are skipped.

    Raises:
        ValueError: if a token is not an integer or a well-formed range,
            matching the original int() parse behavior.
    """
    seeds: list[int] = []
    for token in seed_str.split(","):
        token = token.strip()
        if not token:
            continue  # tolerate trailing / duplicated commas
        # A '-' anywhere past position 0 marks a range; a '-' at position 0
        # alone is just a negative integer literal.
        if "-" in token[1:]:
            head, tail = token[1:].split("-", 1)
            lo = int(token[0] + head)
            hi = int(tail)
            seeds.extend(range(lo, hi + 1))
        else:
            seeds.append(int(token))
    return seeds
async def main() -> None:
    """CLI entry point: roll out teacher episodes concurrently, persist rows
    and summaries, then print PASS/FAIL gate verdicts for the batch."""
    # The first paragraph of the module docstring doubles as --help text.
    parser = argparse.ArgumentParser(description=__doc__.split("\n\n")[0])
    parser.add_argument("--seeds", type=str, required=True,
                        help="Seed range '0-29' or comma list '0,1,5'")
    parser.add_argument("--output", type=str, required=True,
                        help="Output JSONL path for per-step trajectories")
    parser.add_argument("--summary", type=str, default=None,
                        help="Output JSON path for episode summaries (default: <output>.summary.json)")
    parser.add_argument("--concurrency", type=int, default=3,
                        help="Episodes to run concurrently (default 3; 1500 RPM allows up to ~5)")
    parser.add_argument("--temperature", type=float, default=0.5,
                        help="Teacher sampling temperature (default 0.5; lower = more consistent)")
    parser.add_argument("--no-resume", action="store_true",
                        help="Do not skip seeds already in the output file")
    args = parser.parse_args()
    seeds = parse_seed_arg(args.seeds)
    output_path = Path(args.output)
    # e.g. data/run.jsonl -> data/run.summary.json when --summary is omitted.
    summary_path = Path(args.summary) if args.summary else output_path.with_suffix(".summary.json")
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Resume support: drop seeds whose full episode is already on disk.
    if not args.no_resume:
        completed = already_completed_seeds(output_path)
        if completed:
            print(f"Resume: {len(completed)} seeds already complete; "
                  f"{len(seeds) - len(completed & set(seeds))} remaining of {len(seeds)}")
        seeds = [s for s in seeds if s not in completed]
    # Azure config (read from env so secrets never touch the repo)
    try:
        endpoint = os.environ["AZURE_OPENAI_ENDPOINT"]
        api_key = os.environ["AZURE_OPENAI_API_KEY"]
        deployment = os.environ["AZURE_OPENAI_DEPLOYMENT"]
    except KeyError as e:
        # sys.exit with a string prints it to stderr and exits with status 1.
        sys.exit(f"ERROR: missing env var {e}. Set AZURE_OPENAI_ENDPOINT, "
                 f"AZURE_OPENAI_API_KEY, AZURE_OPENAI_DEPLOYMENT.")
    api_version = os.environ.get("AZURE_OPENAI_API_VERSION", "2024-12-01-preview")
    print(f"Endpoint: {endpoint}")
    print(f"Deployment: {deployment}")
    print(f"API version: {api_version}")
    print(f"Seeds: {len(seeds)} (concurrency={args.concurrency}, temp={args.temperature})")
    print(f"Output: {output_path}")
    print(f"Summary: {summary_path}")
    print()
    if not seeds:
        print("No seeds to process. Exiting.")
        return
    client = AsyncAzureOpenAI(
        azure_endpoint=endpoint,
        api_key=api_key,
        api_version=api_version,
    )
    # Semaphore caps in-flight episodes; the lock serializes JSONL appends.
    sem = asyncio.Semaphore(args.concurrency)
    file_lock = asyncio.Lock()
    summaries: list[dict] = []

    async def run_one(seed: int) -> dict | None:
        # Play one episode under the concurrency semaphore; returns the
        # episode summary, or None if the episode crashed outright.
        async with sem:
            t0 = time.time()
            print(f" [seed {seed}] starting", flush=True)
            try:
                step_rows, summary = await play_episode(client, deployment, seed)
            except Exception as e:
                print(f" [seed {seed}] CRASHED: {e}", flush=True)
                return None
            # Append per-step rows atomically (prevents interleaved writes)
            async with file_lock:
                with open(output_path, "a") as f:
                    for row in step_rows:
                        f.write(json.dumps(row) + "\n")
            dt = time.time() - t0
            mae_str = f"{summary['belief_mae']:.3f}" if summary['belief_mae'] is not None else "n/a"
            print(f" [seed {seed}] done in {dt:.1f}s: "
                  f"final={summary['final_score']:.3f} mae={mae_str} "
                  f"unique_actions={len(summary['action_distribution'])} "
                  f"parse_fails={summary['n_parse_failures']}", flush=True)
            return summary

    tasks = [run_one(s) for s in seeds]
    results = await asyncio.gather(*tasks)
    summaries = [r for r in results if r is not None]
    # Merge with any prior summaries (for resume)
    prior_summaries: list[dict] = []
    if summary_path.exists() and not args.no_resume:
        try:
            with open(summary_path) as f:
                prior_summaries = json.load(f).get("episodes", [])
        except (json.JSONDecodeError, KeyError):
            prior_summaries = []
    # This run's summaries win over prior ones for the same seed.
    seen = {s["seed"] for s in summaries}
    summaries = summaries + [s for s in prior_summaries if s["seed"] not in seen]
    # Aggregate
    n = len(summaries)
    if n == 0:
        print("No episodes completed.")
        return
    # Aborted episodes are excluded from quality metrics but counted below.
    valid = [s for s in summaries if not s.get("aborted")]
    avg_score = sum(s["final_score"] for s in valid) / max(len(valid), 1)
    valid_mae = [s["belief_mae"] for s in valid if s["belief_mae"] is not None]
    avg_mae = sum(valid_mae) / len(valid_mae) if valid_mae else None
    all_actions: Counter = Counter()
    for s in valid:
        all_actions.update(s["action_distribution"])
    n_unique = len(all_actions)
    n_parse_fails = sum(s["n_parse_failures"] for s in valid)
    n_aborted = sum(1 for s in summaries if s.get("aborted"))
    summary_blob = {
        "n_episodes": n,
        "n_aborted": n_aborted,
        "avg_final_score": round(avg_score, 4),
        "avg_belief_mae": round(avg_mae, 4) if avg_mae is not None else None,
        "n_unique_actions_overall": n_unique,
        "action_distribution_overall": dict(all_actions),
        "n_parse_failures_total": n_parse_fails,
        "deployment": deployment,
        "api_version": api_version,
        "episodes": summaries,
    }
    with open(summary_path, "w") as f:
        json.dump(summary_blob, f, indent=2)
    # Gates
    BAR_HEURISTIC = 0.587
    BAR_GATE_SCORE = 0.65
    BAR_GATE_MAE = 0.20
    BAR_GATE_ACTIONS = 6
    print()
    print("=" * 72)
    print("BATCH SUMMARY")
    print("=" * 72)
    print(f"Episodes completed: {n} (aborted: {n_aborted})")
    print(f"Avg final_score: {avg_score:.4f} "
          f"(heuristic baseline: {BAR_HEURISTIC}, random: 0.516)")
    if avg_mae is not None:
        print(f"Avg belief MAE: {avg_mae:.4f} (lower is better)")
    print(f"Unique actions: {n_unique} of 10")
    print(f"Parse failures: {n_parse_fails} (across all step calls)")
    print()
    print("VALIDATION GATES:")
    g_score = avg_score >= BAR_GATE_SCORE
    g_mae = avg_mae is not None and avg_mae < BAR_GATE_MAE
    g_actions = n_unique >= BAR_GATE_ACTIONS
    g_parse = n_parse_fails < 0.05 * n * MAX_STEPS  # < 5% parse failure rate
    print(f" [{'PASS' if g_score else 'FAIL'}] avg_final_score >= {BAR_GATE_SCORE}: "
          f"{avg_score:.3f}")
    mae_disp = f"{avg_mae:.3f}" if avg_mae is not None else "n/a"
    print(f" [{'PASS' if g_mae else 'FAIL'}] avg_belief_mae < {BAR_GATE_MAE}: {mae_disp}")
    print(f" [{'PASS' if g_actions else 'FAIL'}] unique_actions >= {BAR_GATE_ACTIONS}: "
          f"{n_unique}")
    print(f" [{'PASS' if g_parse else 'FAIL'}] parse_failures < 5% of calls: "
          f"{n_parse_fails}/{n * MAX_STEPS}")
    print()
    if g_score and g_mae and g_actions and g_parse:
        print("ALL GATES PASS — safe to scale to production batch.")
    else:
        print("ONE OR MORE GATES FAILED — investigate before scaling.")
        if not g_score:
            print(" -> Teacher quality too low. Consider escalating model "
                  "(e.g. gpt-5-pro) or refining the prompt.")
        if not g_mae:
            print(" -> Teacher's beliefs aren't tracking the true profile. "
                  "Check anomaly visibility in observation prompt.")
        if not g_actions:
            print(" -> Teacher converged on a narrow action set. Encourage "
                  "exploration in the prompt.")
        if not g_parse:
            print(" -> Many responses didn't end with the answer pattern. "
                  "Strengthen format instruction in the system prompt.")
if __name__ == "__main__":
    # Script entry point: asyncio.run creates and tears down the event loop.
    asyncio.run(main())