"""
Generate teacher trajectories for Algorithm Distillation.

For each seed, plays one full RhythmEnv episode where the action at each step
is chosen by a teacher LLM (gpt-5.4 via Azure OpenAI). The teacher is prompted
to emit `<reasoning>...</reasoning>` followed by `S M W ACTION_NAME` on a final
line. We parse the answer line, step the env, save the full (prompt, response,
action, reward) tuple to JSONL, and aggregate per-episode metrics for gating.

Required env vars (no secrets in code):
    AZURE_OPENAI_ENDPOINT      e.g. https://metahackathon-resource.cognitiveservices.azure.com/
    AZURE_OPENAI_API_KEY       your Azure OpenAI key (do NOT paste in chat)
    AZURE_OPENAI_DEPLOYMENT    the deployment name you chose, e.g. gpt-5.4
    AZURE_OPENAI_API_VERSION   e.g. 2024-12-01-preview (default if unset)

Usage from rhythm_env root:

    # Stage 1a: 30-episode validation (~$3-5)
    python scripts/generate_teacher_trajectories.py \
        --seeds 0-29 \
        --output data/teacher_30ep_validation.jsonl \
        --concurrency 3

    # Stage 1b: scale to 150 episodes (~$15-20)
    python scripts/generate_teacher_trajectories.py \
        --seeds 0-99 \
        --output data/teacher_150ep_indist.jsonl \
        --concurrency 5
    python scripts/generate_teacher_trajectories.py \
        --seeds 10000-10049 \
        --output data/teacher_150ep_ood.jsonl \
        --concurrency 5

The script prints PASS/FAIL gate verdicts at the end so you can decide whether
to scale or fix the teacher prompt before spending more.
"""
import argparse
import asyncio
import json
import os
import re
import sys
import time
from collections import Counter
from pathlib import Path

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Load .env (repo root) before reading os.environ so credentials don't have
# to be exported in the shell. The .env file is in .gitignore.
try:
    from dotenv import load_dotenv

    _ENV_PATH = Path(__file__).resolve().parent.parent / ".env"
    if _ENV_PATH.exists():
        load_dotenv(_ENV_PATH)
except ImportError:
    pass  # dotenv not installed → fall back to whatever's in the shell

from openai import AsyncAzureOpenAI
from openai import APIError, RateLimitError, APIConnectionError, APITimeoutError

from models import ActionType, RhythmAction
from server.rhythm_environment import (
    MAX_STEPS,
    RhythmEnvironment,
)
from training.dataset import format_observation_prompt

# ---------------------------------------------------------------------------
# Teacher system prompt
# ---------------------------------------------------------------------------
# The student will eventually be SFT'd to match this contract: emit a
# <reasoning>...</reasoning> block then a final answer line `S M W ACTION_NAME`.
# Keep this in sync with whatever SYSTEM_PROMPT the SFT'd student will use.
TEACHER_SYSTEM_PROMPT = """You are a life-management agent helping a person whose preferences are HIDDEN.

You see 5 life meters and a rolling history of recent steps. The same action
affects different people differently — you must INFER who you're helping from
rewards, meter changes, and per-meter ANOMALY signals.

Each step, do TWO things:

1. Reason briefly about what the observations imply about the person.
   Focus on:
   - Anomalies (actual delta vs neutral-profile expectation): big positive
     social_serenity / connection responses → high S; big morning cognition
     gains → high M; productive work giving vitality back → high W
   - Current meter state: any meter under 0.15 needs urgent recovery
   - What action best fits BOTH the inferred profile and the current state

2. Output your final answer on the LAST line in this exact format:

   S M W ACTION_NAME

   where S, M, W are belief digits 0-9 (0=low, 9=high) representing your best
   estimate of social_pref, morning_pref, work_pref. ACTION_NAME is one of:
   DEEP_WORK, ADMIN_WORK, LEARN, SLEEP, EXERCISE, MEDITATE, FAMILY_TIME,
   SOCIALIZE, ME_TIME, BINGE_WATCH

Wrap your reasoning in <reasoning>...</reasoning> tags. Keep reasoning under
120 tokens. The final answer line MUST be the last line of your response.

Belief→action quick reference:
- High S (extrovert): SOCIALIZE, FAMILY_TIME boost connection cheaply
- High M (morning person): DEEP_WORK / LEARN in early slots gets bonus cognition
- High W (workaholic): DEEP_WORK, LEARN drive progress and may energize
- Low S (introvert): MEDITATE, ME_TIME for solo recharge; avoid SOCIALIZE
- Low M (night owl): DEEP_WORK / LEARN in evening/night slots
- Watch crashes: any meter under 0.10 = -0.30 penalty per crashed meter
- Connection decays passively — actively maintain via SOCIALIZE/FAMILY_TIME
- Don't repeat the same action 3+ times in a row — repetition penalty applies

Strategy: probe varied actions in the first ~5 steps to gather profile evidence,
then exploit your sharpened belief by picking actions that match the inferred
profile + current meter state.

Example output:
<reasoning>
Last step's socialize gave V-0.12 (anom -0.06, much worse than neutral) — high
social drain, suggests low S. Morning DEEP_WORK earlier gave bonus cognition
(anom +0.04) → high M. Vitality at 0.6 still ok, serenity dropping. With low S +
high M, MEDITATE is the recovery play that fits.
</reasoning>
2 8 5 MEDITATE"""

# ---------------------------------------------------------------------------
# Answer parsing — find the LAST `S M W ACTION_NAME` pattern in the response
# ---------------------------------------------------------------------------
VALID_ACTIONS = [at.value.upper() for at in ActionType]
ANSWER_PATTERN = re.compile(
    r'(\d)\s+(\d)\s+(\d)\s+(' + '|'.join(VALID_ACTIONS) + r')\b',
    re.IGNORECASE,
)


def parse_teacher_response(text: str):
    """Extract (action_type, belief_vector, raw_match) from teacher output.

    Returns (None, None, None) if no answer line is parseable.
    """
    if not text:
        return None, None, None
    matches = list(ANSWER_PATTERN.finditer(text))
    if not matches:
        return None, None, None
    last = matches[-1]
    s, m, w, action_name = last.groups()
    try:
        belief = [int(s) / 9.0, int(m) / 9.0, int(w) / 9.0]
        action = ActionType(action_name.lower())
        return action, belief, last.group(0)
    except (ValueError, KeyError):
        return None, None, None
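# Sketch of the expected parse (assuming ActionType values are the lowercase
# action names, as the VALID_ACTIONS construction above implies):
#
#     parse_teacher_response("<reasoning>low S, high M</reasoning>\n2 8 5 MEDITATE")
#     # → (ActionType.MEDITATE, [2/9, 8/9, 5/9], "2 8 5 MEDITATE")
#
# Because finditer takes the LAST match, stray digit runs inside the
# <reasoning> block are ignored as long as the answer line comes last.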


# ---------------------------------------------------------------------------
# Async API calls with retry
# ---------------------------------------------------------------------------
async def call_teacher(
    client: AsyncAzureOpenAI,
    deployment: str,
    user_prompt: str,
    temperature: float = 0.5,
    max_completion_tokens: int = 400,
    max_retries: int = 4,
) -> str:
    """Call the teacher with retries on transient errors. Returns response text."""
    last_err: Exception | None = None
    for attempt in range(max_retries):
        try:
            resp = await client.chat.completions.create(
                model=deployment,
                messages=[
                    {"role": "system", "content": TEACHER_SYSTEM_PROMPT},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=temperature,
                max_completion_tokens=max_completion_tokens,
            )
            return resp.choices[0].message.content or ""
        except (RateLimitError, APIConnectionError, APITimeoutError) as e:
            last_err = e
            wait = min(60, 2 ** attempt)
            await asyncio.sleep(wait)
        except APIError as e:
            # Non-transient API error — log and bail (don't waste retries)
            last_err = e
            break
    raise RuntimeError(f"Teacher call failed after {max_retries} retries: {last_err}")


# ---------------------------------------------------------------------------
# Episode rollout
# ---------------------------------------------------------------------------
async def play_episode(
    client: AsyncAzureOpenAI,
    deployment: str,
    seed: int,
) -> tuple[list[dict], dict]:
    """Run a full episode with the teacher. Returns (per-step rows, summary)."""
    env = RhythmEnvironment()
    obs = env.reset(seed=seed)
    true_belief = env.get_belief_target()
    profile_name = env.state.profile_name

    step_rows: list[dict] = []
    actions_taken: list[str] = []
    rewards: list[float] = []
    final_belief: list[float] | None = None

    for step_idx in range(MAX_STEPS):
        if obs.done:
            break
        user_prompt = format_observation_prompt(obs)
        try:
            teacher_resp = await call_teacher(client, deployment, user_prompt)
        except RuntimeError as e:
            # Hard failure — abort this episode rather than corrupt the dataset
            return step_rows, {
                "seed": seed,
                "profile_name": profile_name,
                "true_belief": [round(x, 3) for x in true_belief],
                "final_belief": [round(x, 3) for x in final_belief] if final_belief else None,
                "belief_mae": None,
                "final_score": 0.0,
                "total_reward": round(sum(rewards), 2),
                "n_steps": len(step_rows),
                "actions": actions_taken,
                "action_distribution": dict(Counter(actions_taken)),
                "n_parse_failures": sum(1 for r in step_rows if r["parse_failed"]),
                "aborted": True,
                "error": str(e),
            }

        action, belief, raw_match = parse_teacher_response(teacher_resp)
        parse_failed = action is None
        if parse_failed:
            # Fallback: SLEEP keeps the episode alive without skewing exploration
            action = ActionType.SLEEP
            belief = [0.5, 0.5, 0.5]
        else:
            final_belief = belief
            # Tell the env about the emitted belief so the grader's belief_accuracy
            # component scores it. Without this call, final_score logged below is
            # artificially low (belief component scores 0 even when the teacher
            # actually emitted a belief).
            env.record_belief(belief)

        rhythm_action = RhythmAction(action_type=action)
        actions_taken.append(action.value)
        next_obs = env.step(rhythm_action)
        rewards.append(next_obs.reward)

        step_rows.append({
            "seed": seed,
            "step": step_idx,
            "profile_name": profile_name,
            "user_prompt": user_prompt,
            "teacher_response": teacher_resp,
            "parsed_action": action.value,
            "parsed_belief": belief,
            "answer_match": raw_match,
            "env_reward": round(next_obs.reward, 4),
            "parse_failed": parse_failed,
            "true_belief": [round(x, 3) for x in true_belief],
        })
        obs = next_obs

    final_score = obs.reward_breakdown.get("final_score", 0.0)
    belief_mae = (
        sum(abs(b - t) for b, t in zip(final_belief, true_belief)) / 3.0
        if final_belief is not None else None
    )
    return step_rows, {
        "seed": seed,
        "profile_name": profile_name,
        "true_belief": [round(x, 3) for x in true_belief],
        "final_belief": [round(x, 3) for x in final_belief] if final_belief else None,
        "belief_mae": round(belief_mae, 4) if belief_mae is not None else None,
        "final_score": round(final_score, 4),
        "total_reward": round(sum(rewards), 2),
        "n_steps": len(step_rows),
        "actions": actions_taken,
        "action_distribution": dict(Counter(actions_taken)),
        "n_parse_failures": sum(1 for r in step_rows if r["parse_failed"]),
        "aborted": False,
    }


# ---------------------------------------------------------------------------
# Resume helpers
# ---------------------------------------------------------------------------
def already_completed_seeds(jsonl_path: Path) -> set[int]:
    """Seeds whose final step (MAX_STEPS - 1 = 27) is already in the file."""
    if not jsonl_path.exists():
        return set()
    seed_max_step: dict[int, int] = {}
    with open(jsonl_path) as f:
        for line in f:
            try:
                row = json.loads(line)
            except json.JSONDecodeError:
                continue
            sd = row.get("seed")
            st = row.get("step", -1)
            if sd is None:
                continue
            seed_max_step[sd] = max(seed_max_step.get(sd, -1), st)
    return {s for s, mx in seed_max_step.items() if mx >= MAX_STEPS - 1}
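# Note on resume semantics (a caveat, not enforced by this function): a seed
# with only a partial episode in the file (max step < MAX_STEPS - 1) is NOT
# counted complete, so it gets re-run from scratch — but its stale partial
# rows stay in the JSONL and the re-run appends fresh ones. If that matters
# downstream, dedupe by (seed, step), keeping the last occurrence.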


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def parse_seed_arg(seed_str: str) -> list[int]:
    if "-" in seed_str and "," not in seed_str:
        lo, hi = seed_str.split("-")
        return list(range(int(lo), int(hi) + 1))
    return [int(s.strip()) for s in seed_str.split(",") if s.strip()]
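# Examples of the two accepted seed spellings:
#
#     parse_seed_arg("0-2")    # → [0, 1, 2]   (inclusive range)
#     parse_seed_arg("0,1,5")  # → [0, 1, 5]   (comma list)
#
# Mixing forms ("0-2,5") falls through to the comma branch and raises
# ValueError on "0-2" — use one form per invocation.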


async def main() -> None:
    parser = argparse.ArgumentParser(description=__doc__.split("\n\n")[0])
    parser.add_argument("--seeds", type=str, required=True,
                        help="Seed range '0-29' or comma list '0,1,5'")
    parser.add_argument("--output", type=str, required=True,
                        help="Output JSONL path for per-step trajectories")
    parser.add_argument("--summary", type=str, default=None,
                        help="Output JSON path for episode summaries (default: <output>.summary.json)")
    parser.add_argument("--concurrency", type=int, default=3,
                        help="Episodes to run concurrently (default 3; 1500 RPM allows up to ~5)")
    parser.add_argument("--temperature", type=float, default=0.5,
                        help="Teacher sampling temperature (default 0.5; lower = more consistent)")
    parser.add_argument("--no-resume", action="store_true",
                        help="Do not skip seeds already in the output file")
    args = parser.parse_args()

    seeds = parse_seed_arg(args.seeds)
    output_path = Path(args.output)
    summary_path = Path(args.summary) if args.summary else output_path.with_suffix(".summary.json")
    output_path.parent.mkdir(parents=True, exist_ok=True)

    if not args.no_resume:
        completed = already_completed_seeds(output_path)
        if completed:
            print(f"Resume: {len(completed)} seeds already complete; "
                  f"{len(seeds) - len(completed & set(seeds))} remaining of {len(seeds)}")
            seeds = [s for s in seeds if s not in completed]

    # Azure config (read from env so secrets never touch the repo)
    try:
        endpoint = os.environ["AZURE_OPENAI_ENDPOINT"]
        api_key = os.environ["AZURE_OPENAI_API_KEY"]
        deployment = os.environ["AZURE_OPENAI_DEPLOYMENT"]
    except KeyError as e:
        sys.exit(f"ERROR: missing env var {e}. Set AZURE_OPENAI_ENDPOINT, "
                 "AZURE_OPENAI_API_KEY, AZURE_OPENAI_DEPLOYMENT.")
    api_version = os.environ.get("AZURE_OPENAI_API_VERSION", "2024-12-01-preview")

    print(f"Endpoint:    {endpoint}")
    print(f"Deployment:  {deployment}")
    print(f"API version: {api_version}")
    print(f"Seeds:       {len(seeds)} (concurrency={args.concurrency}, temp={args.temperature})")
    print(f"Output:      {output_path}")
    print(f"Summary:     {summary_path}")
    print()

    if not seeds:
        print("No seeds to process. Exiting.")
        return

    client = AsyncAzureOpenAI(
        azure_endpoint=endpoint,
        api_key=api_key,
        api_version=api_version,
    )
    sem = asyncio.Semaphore(args.concurrency)
    file_lock = asyncio.Lock()
    summaries: list[dict] = []

    async def run_one(seed: int) -> dict | None:
        async with sem:
            t0 = time.time()
            print(f"  [seed {seed}] starting", flush=True)
            try:
                step_rows, summary = await play_episode(client, deployment, seed)
            except Exception as e:
                print(f"  [seed {seed}] CRASHED: {e}", flush=True)
                return None
            # Append per-step rows atomically (prevents interleaved writes)
            async with file_lock:
                with open(output_path, "a") as f:
                    for row in step_rows:
                        f.write(json.dumps(row) + "\n")
            dt = time.time() - t0
            mae_str = f"{summary['belief_mae']:.3f}" if summary['belief_mae'] is not None else "n/a"
            print(f"  [seed {seed}] done in {dt:.1f}s: "
                  f"final={summary['final_score']:.3f} mae={mae_str} "
                  f"unique_actions={len(summary['action_distribution'])} "
                  f"parse_fails={summary['n_parse_failures']}", flush=True)
            return summary

    tasks = [run_one(s) for s in seeds]
    results = await asyncio.gather(*tasks)
    summaries = [r for r in results if r is not None]
    # Merge with any prior summaries (for resume)
    prior_summaries: list[dict] = []
    if summary_path.exists() and not args.no_resume:
        try:
            with open(summary_path) as f:
                prior_summaries = json.load(f).get("episodes", [])
        except (json.JSONDecodeError, KeyError):
            prior_summaries = []
    seen = {s["seed"] for s in summaries}
    summaries = summaries + [s for s in prior_summaries if s["seed"] not in seen]

    # Aggregate
    n = len(summaries)
    if n == 0:
        print("No episodes completed.")
        return
    valid = [s for s in summaries if not s.get("aborted")]
    avg_score = sum(s["final_score"] for s in valid) / max(len(valid), 1)
    valid_mae = [s["belief_mae"] for s in valid if s["belief_mae"] is not None]
    avg_mae = sum(valid_mae) / len(valid_mae) if valid_mae else None
    all_actions: Counter = Counter()
    for s in valid:
        all_actions.update(s["action_distribution"])
    n_unique = len(all_actions)
    n_parse_fails = sum(s["n_parse_failures"] for s in valid)
    n_aborted = sum(1 for s in summaries if s.get("aborted"))

    summary_blob = {
        "n_episodes": n,
        "n_aborted": n_aborted,
        "avg_final_score": round(avg_score, 4),
        "avg_belief_mae": round(avg_mae, 4) if avg_mae is not None else None,
        "n_unique_actions_overall": n_unique,
        "action_distribution_overall": dict(all_actions),
        "n_parse_failures_total": n_parse_fails,
        "deployment": deployment,
        "api_version": api_version,
        "episodes": summaries,
    }
    with open(summary_path, "w") as f:
        json.dump(summary_blob, f, indent=2)
    # Gates
    BAR_HEURISTIC = 0.587
    BAR_GATE_SCORE = 0.65
    BAR_GATE_MAE = 0.20
    BAR_GATE_ACTIONS = 6

    print()
    print("=" * 72)
    print("BATCH SUMMARY")
    print("=" * 72)
    print(f"Episodes completed: {n} (aborted: {n_aborted})")
    print(f"Avg final_score:    {avg_score:.4f} "
          f"(heuristic baseline: {BAR_HEURISTIC}, random: 0.516)")
    if avg_mae is not None:
        print(f"Avg belief MAE:     {avg_mae:.4f} (lower is better)")
    print(f"Unique actions:     {n_unique} of 10")
    print(f"Parse failures:     {n_parse_fails} (across all step calls)")
    print()
    print("VALIDATION GATES:")
    g_score = avg_score >= BAR_GATE_SCORE
    g_mae = avg_mae is not None and avg_mae < BAR_GATE_MAE
    g_actions = n_unique >= BAR_GATE_ACTIONS
    g_parse = n_parse_fails < 0.05 * n * MAX_STEPS  # < 5% parse failure rate
    print(f"  [{'PASS' if g_score else 'FAIL'}] avg_final_score >= {BAR_GATE_SCORE}: "
          f"{avg_score:.3f}")
    mae_disp = f"{avg_mae:.3f}" if avg_mae is not None else "n/a"
    print(f"  [{'PASS' if g_mae else 'FAIL'}] avg_belief_mae < {BAR_GATE_MAE}: {mae_disp}")
    print(f"  [{'PASS' if g_actions else 'FAIL'}] unique_actions >= {BAR_GATE_ACTIONS}: "
          f"{n_unique}")
    print(f"  [{'PASS' if g_parse else 'FAIL'}] parse_failures < 5% of calls: "
          f"{n_parse_fails}/{n * MAX_STEPS}")
    print()
    if g_score and g_mae and g_actions and g_parse:
        print("ALL GATES PASS — safe to scale to production batch.")
    else:
        print("ONE OR MORE GATES FAILED — investigate before scaling.")
        if not g_score:
            print("  -> Teacher quality too low. Consider escalating model "
                  "(e.g. gpt-5-pro) or refining the prompt.")
        if not g_mae:
            print("  -> Teacher's beliefs aren't tracking the true profile. "
                  "Check anomaly visibility in observation prompt.")
        if not g_actions:
            print("  -> Teacher converged on a narrow action set. Encourage "
                  "exploration in the prompt.")
        if not g_parse:
            print("  -> Many responses didn't end with the answer pattern. "
                  "Strengthen format instruction in the system prompt.")


if __name__ == "__main__":
    asyncio.run(main())