| """ |
| In-process episode driver for the /custom cockpit. |
| |
| Spawned as an asyncio background task by `POST /custom/start_episode`. Drives |
| a fresh CricketEnvironment locally (no HTTP) and publishes events through the |
| spectator BusObserver against the loopback URL — that way the cockpit's |
| WebSocket subscribers see the same canonical event schema as the standalone |
| `spectator/run_with_ui.py` driver. |
| |
| Captain choice comes from `server.captain_policy.captain_presets()`. Opponent |
| choice maps through `OPPONENT_PRESETS`. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import asyncio |
| import logging |
| import os |
| import time |
| import uuid |
| from typing import Any |
|
|
| import httpx |
|
|
| from server.cricket_environment import CricketEnvironment |
| from server.captain_policy import ( |
| OPPONENT_PRESETS, |
| captain_presets, |
| pick_action, |
| ) |
| from spectator.heuristic_captain import HeuristicCaptain |
| from spectator.publisher import BusObserver |
| from spectator.tts import CartesiaTTS |
| from spectator.ui_frame import to_ui_frame |
| from models import CricketAction |
|
|
|
|
# Dedicated logger for the in-process episode driver. It gets its own stream
# handler and does not propagate, so records are emitted once in the
# "[episode]" format regardless of how the root logger is configured.
log = logging.getLogger("server.episode_runner")
if not log.handlers:
    # Guarded so a re-import does not stack a second handler on the logger.
    h = logging.StreamHandler()
    h.setFormatter(logging.Formatter("%(asctime)s [episode] %(message)s", datefmt="%H:%M:%S"))
    log.addHandler(h)
    _level_name = os.environ.get("EPISODE_LOG_LEVEL", "INFO").upper()
    try:
        log.setLevel(_level_name)
    except ValueError:
        # An unknown or empty level name must not abort the module import;
        # fall back to INFO and say so.
        log.setLevel(logging.INFO)
        log.warning("ignoring invalid EPISODE_LOG_LEVEL=%r; using INFO", _level_name)
    log.propagate = False


# Base URL handed to the BusObserver for publishing events; overridable via
# the CRICKET_LOCAL_BASE_URL environment variable.
_DEFAULT_BASE_URL = os.environ.get("CRICKET_LOCAL_BASE_URL", "http://127.0.0.1:8000")


# At-most-one-episode guard: run_episode_in_proc holds the lock for the whole
# episode and keeps the running episode's metadata in _RUNNING_INFO (an empty
# dict means idle).
_RUNNING_LOCK = asyncio.Lock()
_RUNNING_INFO: dict[str, Any] = {}
|
|
|
|
def is_running() -> dict[str, Any] | None:
    """Report the episode currently in flight.

    Returns:
        A shallow copy of the running-episode metadata, or ``None`` when the
        metadata dict is empty (no episode running). Because it is a copy,
        callers can mutate the result without affecting module state.
    """
    if not _RUNNING_INFO:
        return None
    return _RUNNING_INFO.copy()
|
|
|
|
def _ep_id() -> str:
    """Mint a fresh episode identifier.

    The id has the form ``custom-ep-<unix seconds>-<6 hex chars>``; the hex
    suffix comes from a random UUID so ids minted within the same second
    still differ.
    """
    stamp = int(time.time())
    suffix = uuid.uuid4().hex[:6]
    return f"custom-ep-{stamp}-{suffix}"
|
|
|
|
def _heuristic_action(captain: HeuristicCaptain):
    """Wrap a heuristic captain as an ``obs -> CricketAction`` callable.

    The returned closure is what gets passed to
    ``captain_policy.pick_action`` as its fallback.
    """
    def fallback(obs):
        # The captain works on UI frames, so convert the observation first
        # (with an empty episode id).
        decision = captain.act(to_ui_frame(obs, episode_id=""))
        tool_name = decision.get("tool", "analyze_situation")
        # A missing or falsy "arguments" entry becomes an empty dict.
        tool_args = decision.get("arguments", {}) or {}
        return CricketAction(tool=tool_name, arguments=tool_args)

    return fallback
|
|
|
|
async def _open_tts(enable_tts: bool) -> CartesiaTTS | None:
    """Return a started TTS client, or None when audio commentary is off."""
    if not enable_tts:
        return None
    tts = CartesiaTTS()
    if not tts.enabled:
        log.info("TTS requested but no CARTESIA_API_KEY; commentary will be silent")
        return None
    await tts.start()
    return tts


async def _drive(
    captain_preset: str,
    opponent_choice: str,
    max_overs: int,
    eval_pack_id: str,
    enable_tts: bool,
    bus,
    episode_id: str | None = None,
) -> dict[str, Any]:
    """Run one episode end-to-end and return a small summary dict.

    Args:
        captain_preset: Key into ``captain_presets()``; selects the captain
            and is also reported to the observer as ``mode``.
        opponent_choice: Key into ``OPPONENT_PRESETS``; unknown keys fall
            back to the ``"heuristic"`` opponent mode.
        max_overs: Passed to ``env.reset`` as the ``max_overs`` option.
        eval_pack_id: Passed to both the observer and ``env.reset``.
        enable_tts: When true, try to start a ``CartesiaTTS`` client; if the
            client reports it is not enabled the episode runs silently.
        bus: Accepted for interface compatibility; this function does not
            use it (events are published through ``BusObserver`` over HTTP).
        episode_id: Id to publish the episode under. When omitted a fresh
            one is minted, so existing callers are unaffected.

    Returns:
        ``{"episode_id": ..., "turns": ..., "deliveries": ...}``.

    Raises:
        ValueError: If ``captain_preset`` is not a known preset.
    """
    presets = captain_presets()
    if captain_preset not in presets:
        raise ValueError(f"unknown captain preset: {captain_preset!r} (have {sorted(presets)})")
    opponent_mode = OPPONENT_PRESETS.get(opponent_choice, "heuristic")
    ep_id = episode_id or _ep_id()

    async with httpx.AsyncClient(timeout=15.0) as http:
        tts = await _open_tts(enable_tts)
        try:
            observer = BusObserver(
                http_client=http,
                base_url=_DEFAULT_BASE_URL,
                episode_id=ep_id,
                mode=captain_preset,
                task="custom_episode",
                opponent_mode=opponent_mode,
                eval_pack_id=eval_pack_id,
                tts=tts,
            )

            env = CricketEnvironment()
            obs = env.reset(options={
                "random_start": False,
                "opponent_mode": opponent_mode,
                "eval_pack_id": eval_pack_id,
                "max_overs": max_overs,
            })

            await observer.start(obs)

            # Used both as the fallback handed to pick_action and as the
            # recovery path when pick_action itself raises.
            heuristic = HeuristicCaptain(seed=None)
            fallback = _heuristic_action(heuristic)

            turn = 0
            deliveries = 0
            max_steps = 800  # hard cap so an episode that never ends still stops
            log.info("started ep=%s captain=%s opponent=%s", ep_id, captain_preset, opponent_choice)

            loop = asyncio.get_running_loop()
            while turn < max_steps:
                try:
                    action = pick_action(obs, captain_preset, fallback)
                except Exception as exc:
                    log.warning("captain pick failed at turn %s: %r; using heuristic", turn, exc)
                    action = fallback(obs)

                action_dict = {"tool": action.tool, "arguments": action.arguments or {}}
                pre = await observer.before_step(action_dict, obs, turn=turn)

                try:
                    # env.step is synchronous; it runs in the default
                    # executor, presumably so a slow step does not stall the
                    # event loop.
                    new_obs = await loop.run_in_executor(None, env.step, action)
                except Exception as exc:
                    log.error("env.step failed at turn %s: %r; ending episode", turn, exc)
                    break

                # Read env state defensively, the same way the end-of-episode
                # call below does.
                state = getattr(env, "_state", None)
                done = bool(getattr(state, "match_over", False)) or bool(getattr(new_obs, "done", False))

                await observer.after_step(action_dict, pre, new_obs, None, turn=turn)

                if action.tool in ("play_delivery", "bowl_delivery"):
                    deliveries += 1

                obs = new_obs
                turn += 1
                if done:
                    break

                # Short pause between turns; also yields to other tasks.
                await asyncio.sleep(0.01)

            await observer.end(getattr(env, "_state", None))
        finally:
            # Close the TTS client even when the episode aborts with an
            # exception, so a failed episode does not leak it.
            if tts is not None:
                await tts.aclose()

    log.info("finished ep=%s turns=%s deliveries=%s", ep_id, turn, deliveries)
    return {"episode_id": ep_id, "turns": turn, "deliveries": deliveries}


async def run_episode_in_proc(
    bus,
    captain_preset: str,
    opponent_choice: str,
    max_overs: int = 5,
    eval_pack_id: str = "default",
    enable_tts: bool = True,
) -> dict[str, Any]:
    """Background entry point: run at most one episode at a time.

    Holds ``_RUNNING_LOCK`` for the episode's lifetime. A call made while
    another episode is running is rejected, not queued: it returns
    ``{"error": ..., "current": <running episode metadata>}`` immediately.

    Otherwise the episode metadata is recorded in ``_RUNNING_INFO`` (cleared
    again when the episode finishes or fails) and the summary dict from
    ``_drive`` is returned. Exceptions raised by ``_drive`` propagate.
    """
    if _RUNNING_LOCK.locked():
        return {"error": "an episode is already running", "current": dict(_RUNNING_INFO)}

    async with _RUNNING_LOCK:
        ep_id = _ep_id()
        _RUNNING_INFO.update({
            "episode_id": ep_id,
            "captain": captain_preset,
            "opponent": opponent_choice,
            "started_at": time.time(),
        })
        try:
            # Pass the id through so the id reported by is_running() is the
            # same one the episode is published and summarised under.
            return await _drive(
                captain_preset=captain_preset,
                opponent_choice=opponent_choice,
                max_overs=max_overs,
                eval_pack_id=eval_pack_id,
                enable_tts=enable_tts,
                bus=bus,
                episode_id=ep_id,
            )
        finally:
            _RUNNING_INFO.clear()
|
|