""" In-process episode driver for the /custom cockpit. Spawned as an asyncio background task by `POST /custom/start_episode`. Drives a fresh CricketEnvironment locally (no HTTP) and publishes events through the spectator BusObserver against the loopback URL — that way the cockpit's WebSocket subscribers see the same canonical event schema as the standalone `spectator/run_with_ui.py` driver. Captain choice comes from `server.captain_policy.captain_presets()`. Opponent choice maps through `OPPONENT_PRESETS`. """ from __future__ import annotations import asyncio import logging import os import time import uuid from typing import Any import httpx from server.cricket_environment import CricketEnvironment from server.captain_policy import ( OPPONENT_PRESETS, captain_presets, pick_action, ) from spectator.heuristic_captain import HeuristicCaptain from spectator.publisher import BusObserver from spectator.tts import CartesiaTTS from spectator.ui_frame import to_ui_frame from models import CricketAction log = logging.getLogger("server.episode_runner") if not log.handlers: h = logging.StreamHandler() h.setFormatter(logging.Formatter("%(asctime)s [episode] %(message)s", datefmt="%H:%M:%S")) log.addHandler(h) log.setLevel(os.environ.get("EPISODE_LOG_LEVEL", "INFO").upper()) log.propagate = False # Loopback base URL the BusObserver POSTs to. The Space's openenv.yaml sets # app_port=8000; can be overridden via env if needed. _DEFAULT_BASE_URL = os.environ.get("CRICKET_LOCAL_BASE_URL", "http://127.0.0.1:8000") # At-most-one-running guard so a double-click doesn't spawn parallel matches # competing for the bus and Cartesia quota. 
_RUNNING_LOCK = asyncio.Lock()
# Metadata for the episode that currently holds the lock; empty when idle.
_RUNNING_INFO: dict[str, Any] = {}


def is_running() -> dict[str, Any] | None:
    """Return a copy of the current episode metadata, or None if idle."""
    return dict(_RUNNING_INFO) if _RUNNING_INFO else None


def _ep_id() -> str:
    """Build an episode id from the current epoch second plus 6 random hex chars."""
    return f"custom-ep-{int(time.time())}-{uuid.uuid4().hex[:6]}"


def _heuristic_action(captain: HeuristicCaptain):
    """Build a heuristic-fallback closure that takes an obs and returns a
    CricketAction (the shape captain_policy.pick_action expects)."""

    def fallback(obs):
        frame = to_ui_frame(obs, episode_id="")
        decision = captain.act(frame)
        return CricketAction(
            tool=decision.get("tool", "analyze_situation"),
            # A missing or falsy "arguments" entry becomes an empty dict.
            arguments=decision.get("arguments", {}) or {},
        )

    return fallback


async def _drive(
    captain_preset: str,
    opponent_choice: str,
    max_overs: int,
    eval_pack_id: str,
    enable_tts: bool,
    bus,
    episode_id: str | None = None,
) -> dict[str, Any]:
    """One episode end-to-end. Returns a small summary dict.

    Args:
        captain_preset: Key into ``captain_presets()``; also passed to the
            observer as ``mode``.
        opponent_choice: Key into ``OPPONENT_PRESETS``; unknown keys fall back
            to ``"heuristic"``.
        max_overs: Forwarded to ``env.reset`` as the ``max_overs`` option.
        eval_pack_id: Forwarded to both ``env.reset`` and the observer.
        enable_tts: When true, a CartesiaTTS session is started if it reports
            itself as enabled.
        bus: Accepted for interface compatibility; not used by this function
            (events are published through the BusObserver over HTTP).
        episode_id: Id to publish under. When omitted a fresh one is
            generated, which is the previous behaviour.

    Returns:
        ``{"episode_id", "turns", "deliveries"}`` for the finished episode.

    Raises:
        ValueError: If ``captain_preset`` is not a known preset.
    """
    presets = captain_presets()
    if captain_preset not in presets:
        raise ValueError(f"unknown captain preset: {captain_preset!r} (have {sorted(presets)})")
    opponent_mode = OPPONENT_PRESETS.get(opponent_choice, "heuristic")
    ep_id = episode_id or _ep_id()

    tts: CartesiaTTS | None = None
    turn = 0
    deliveries = 0
    max_steps = 800

    # The observer expects a real httpx client, even if we're hitting loopback.
    async with httpx.AsyncClient(timeout=15.0) as http:
        try:
            if enable_tts:
                candidate = CartesiaTTS()
                if candidate.enabled:
                    await candidate.start()
                    # Only track the session once it has actually started, so
                    # the cleanup below never closes a half-initialised one.
                    tts = candidate
                else:
                    log.info("TTS requested but no CARTESIA_API_KEY; commentary will be silent")

            observer = BusObserver(
                http_client=http,
                base_url=_DEFAULT_BASE_URL,
                episode_id=ep_id,
                mode=captain_preset,
                task="custom_episode",
                opponent_mode=opponent_mode,
                eval_pack_id=eval_pack_id,
                tts=tts,
            )

            env = CricketEnvironment()
            obs = env.reset(options={
                "random_start": False,
                "opponent_mode": opponent_mode,
                "eval_pack_id": eval_pack_id,
                "max_overs": max_overs,
            })
            await observer.start(obs)

            # Captain policy: tries the LLM, falls back to HeuristicCaptain on failure.
            fallback = _heuristic_action(HeuristicCaptain(seed=None))
            loop = asyncio.get_running_loop()

            log.info("started ep=%s captain=%s opponent=%s", ep_id, captain_preset, opponent_choice)

            while turn < max_steps:
                # NOTE(review): pick_action runs synchronously on the event
                # loop; if a preset makes a slow call here it will stall other
                # tasks. Consider offloading once its thread-safety is confirmed.
                try:
                    action = pick_action(obs, captain_preset, fallback)
                except Exception as exc:
                    log.warning("captain pick failed at turn %s: %r; using heuristic", turn, exc)
                    action = fallback(obs)

                action_dict = {"tool": action.tool, "arguments": action.arguments or {}}
                pre = await observer.before_step(action_dict, obs, turn=turn)

                try:
                    # CricketEnvironment.step is sync — keep loop responsive by
                    # offloading. Most tools are fast, but the LLM opponent for a
                    # single ball can take a couple of seconds.
                    new_obs = await loop.run_in_executor(None, env.step, action)
                except Exception as exc:
                    log.error("env.step failed at turn %s: %r; ending episode", turn, exc)
                    break

                state = getattr(env, "_state", None)
                done = bool(
                    getattr(state, "match_over", False) or getattr(new_obs, "done", False)
                )

                # BusObserver.after_step expects (action, pre, next_obs, reward, *, turn, ...).
                # The local CricketEnvironment doesn't surface a per-step reward; pass None.
                await observer.after_step(action_dict, pre, new_obs, None, turn=turn)

                if action.tool in ("play_delivery", "bowl_delivery"):
                    deliveries += 1

                obs = new_obs
                turn += 1
                if done:
                    break

                # Tiny breather so the WebSocket fanout can flush.
                await asyncio.sleep(0.01)
            else:
                # Loop ran out of steps without the match reporting completion.
                log.warning("ep=%s hit the %s-step cap before the match finished", ep_id, max_steps)

            await observer.end(getattr(env, "_state", None))
        finally:
            # Release the TTS session even when the episode aborts part-way.
            if tts is not None:
                await tts.aclose()

        log.info("finished ep=%s turns=%s deliveries=%s", ep_id, turn, deliveries)
        return {"episode_id": ep_id, "turns": turn, "deliveries": deliveries}


async def run_episode_in_proc(
    bus,
    captain_preset: str,
    opponent_choice: str,
    max_overs: int = 5,
    eval_pack_id: str = "default",
    enable_tts: bool = True,
) -> dict[str, Any]:
    """Background entry point.

    Holds the at-most-one lock for the episode's lifetime. A call made while
    another episode holds the lock is rejected rather than queued: it returns
    ``{"error": ..., "current": <running episode metadata>}`` immediately.

    Otherwise the running-episode metadata is published for ``is_running()``,
    the episode is driven to completion, and the metadata is cleared again
    whether the episode succeeds or raises.

    Returns:
        The summary dict from ``_drive`` on success, or the error dict
        described above when an episode is already running.
    """
    if _RUNNING_LOCK.locked():
        return {"error": "an episode is already running", "current": dict(_RUNNING_INFO)}

    async with _RUNNING_LOCK:
        ep_id = _ep_id()
        _RUNNING_INFO.update({
            "episode_id": ep_id,
            "captain": captain_preset,
            "opponent": opponent_choice,
            "started_at": time.time(),
        })
        try:
            return await _drive(
                captain_preset=captain_preset,
                opponent_choice=opponent_choice,
                max_overs=max_overs,
                eval_pack_id=eval_pack_id,
                enable_tts=enable_tts,
                bus=bus,
                # Publish under the same id that is_running() reports.
                episode_id=ep_id,
            )
        finally:
            _RUNNING_INFO.clear()