cricket-captain-llm / server /episode_runner.py
pratinavseth's picture
custom: in-process episode driver + CaptainRL rename (mirrors github c17c1ba)
8a86db4 verified
"""
In-process episode driver for the /custom cockpit.
Spawned as an asyncio background task by `POST /custom/start_episode`. Drives
a fresh CricketEnvironment locally (no HTTP) and publishes events through the
spectator BusObserver against the loopback URL — that way the cockpit's
WebSocket subscribers see the same canonical event schema as the standalone
`spectator/run_with_ui.py` driver.
Captain choice comes from `server.captain_policy.captain_presets()`. Opponent
choice maps through `OPPONENT_PRESETS`.
"""
from __future__ import annotations
import asyncio
import logging
import os
import time
import uuid
from typing import Any
import httpx
from server.cricket_environment import CricketEnvironment
from server.captain_policy import (
OPPONENT_PRESETS,
captain_presets,
pick_action,
)
from spectator.heuristic_captain import HeuristicCaptain
from spectator.publisher import BusObserver
from spectator.tts import CartesiaTTS
from spectator.ui_frame import to_ui_frame
from models import CricketAction
# Module logger with an explicit name (not derived from __name__).
log = logging.getLogger("server.episode_runner")
# Configure only when this logger has no handlers attached yet.
# NOTE(review): presumably this guards against stacking duplicate handlers if
# the module is imported more than once — confirm.
if not log.handlers:
# StreamHandler() with no argument writes to stderr; records render as
# "HH:MM:SS [episode] <message>".
h = logging.StreamHandler()
h.setFormatter(logging.Formatter("%(asctime)s [episode] %(message)s", datefmt="%H:%M:%S"))
log.addHandler(h)
# Level name comes from EPISODE_LOG_LEVEL (default "INFO") and is upper-cased
# before being handed to setLevel.
log.setLevel(os.environ.get("EPISODE_LOG_LEVEL", "INFO").upper())
# Do not forward records to handlers of ancestor loggers.
log.propagate = False
# Base URL the BusObserver POSTs events to over loopback. The default port
# matches app_port=8000 from the Space's openenv.yaml; set
# CRICKET_LOCAL_BASE_URL to point somewhere else.
_DEFAULT_BASE_URL = os.getenv("CRICKET_LOCAL_BASE_URL", "http://127.0.0.1:8000")

# Single-episode guard: the metadata dict describes the episode in flight
# (empty when idle) and the lock keeps a double-click from spawning parallel
# matches that would compete for the bus and the Cartesia quota.
_RUNNING_INFO: dict[str, Any] = {}
_RUNNING_LOCK = asyncio.Lock()
def is_running() -> dict[str, Any] | None:
    """Report the episode currently in flight.

    Returns:
        A shallow copy of the running-episode metadata, or ``None`` when the
        metadata dict is empty (nothing running).
    """
    if not _RUNNING_INFO:
        return None
    # Return a copy so callers cannot mutate the module-level record.
    return _RUNNING_INFO.copy()
def _ep_id() -> str:
    """Build a fresh episode id of the form ``custom-ep-<unix-seconds>-<6 hex>``."""
    stamp = int(time.time())
    # Six hex characters of a random UUID keep ids distinct within one second.
    suffix = uuid.uuid4().hex[:6]
    return f"custom-ep-{stamp}-{suffix}"
def _heuristic_action(captain: HeuristicCaptain):
    """Wrap *captain* as an ``obs -> CricketAction`` fallback callable.

    The returned closure turns an observation into a UI frame, asks the
    heuristic captain to act on that frame, and repackages the resulting dict
    as a ``CricketAction`` (the shape ``captain_policy.pick_action`` expects
    from its fallback).
    """

    def fallback(obs):
        decision = captain.act(to_ui_frame(obs, episode_id=""))
        tool_name = decision.get("tool", "analyze_situation")
        # A missing or falsy "arguments" entry collapses to an empty dict.
        tool_args = decision.get("arguments", {}) or {}
        return CricketAction(tool=tool_name, arguments=tool_args)

    return fallback
async def _drive(
    captain_preset: str,
    opponent_choice: str,
    max_overs: int,
    eval_pack_id: str,
    enable_tts: bool,
    bus,
    episode_id: str | None = None,
) -> dict[str, Any]:
    """Run one episode end-to-end and return a small summary dict.

    Args:
        captain_preset: Key into ``captain_presets()``; also reported to the
            observer as ``mode``.
        opponent_choice: Looked up in ``OPPONENT_PRESETS``; unknown values
            fall back to ``"heuristic"``.
        max_overs: Forwarded to ``env.reset`` as the ``max_overs`` option.
        eval_pack_id: Forwarded to ``env.reset`` and to the observer.
        enable_tts: When true, try to start Cartesia TTS; if the client
            reports itself disabled the episode runs silently.
        bus: Not used by this function (events are published through
            ``BusObserver`` over HTTP); kept so existing callers keep working.
        episode_id: Id to publish the episode under. A fresh one is generated
            when omitted, so older call sites behave as before.

    Returns:
        ``{"episode_id": ..., "turns": ..., "deliveries": ...}``.

    Raises:
        ValueError: If ``captain_preset`` is not a known preset.
    """
    presets = captain_presets()
    if captain_preset not in presets:
        raise ValueError(f"unknown captain preset: {captain_preset!r} (have {sorted(presets)})")
    opponent_mode = OPPONENT_PRESETS.get(opponent_choice, "heuristic")
    # Reuse the caller's id when given so status reporting and the bus agree.
    ep_id = episode_id or _ep_id()

    # The observer expects a real httpx client, even if we're hitting loopback.
    async with httpx.AsyncClient(timeout=15.0) as http:
        tts: CartesiaTTS | None = None
        if enable_tts:
            candidate = CartesiaTTS()
            if candidate.enabled:
                await candidate.start()
                tts = candidate
            else:
                log.info("TTS requested but no CARTESIA_API_KEY; commentary will be silent")

        turn = 0
        deliveries = 0
        max_steps = 800
        try:
            observer = BusObserver(
                http_client=http,
                base_url=_DEFAULT_BASE_URL,
                episode_id=ep_id,
                mode=captain_preset,
                task="custom_episode",
                opponent_mode=opponent_mode,
                eval_pack_id=eval_pack_id,
                tts=tts,
            )
            env = CricketEnvironment()
            obs = env.reset(options={
                "random_start": False,
                "opponent_mode": opponent_mode,
                "eval_pack_id": eval_pack_id,
                "max_overs": max_overs,
            })
            await observer.start(obs)

            # Captain policy: tries the LLM, falls back to HeuristicCaptain on failure.
            fallback = _heuristic_action(HeuristicCaptain(seed=None))
            loop = asyncio.get_running_loop()
            log.info("started ep=%s captain=%s opponent=%s", ep_id, captain_preset, opponent_choice)

            while turn < max_steps:
                try:
                    action = pick_action(obs, captain_preset, fallback)
                except Exception as exc:
                    log.warning("captain pick failed at turn %s: %r; using heuristic", turn, exc)
                    action = fallback(obs)
                action_dict = {"tool": action.tool, "arguments": action.arguments or {}}
                pre = await observer.before_step(action_dict, obs, turn=turn)
                try:
                    # CricketEnvironment.step is sync — keep the loop responsive
                    # by offloading. Most tools are fast, but the LLM opponent
                    # for a single ball can take a couple of seconds.
                    new_obs = await loop.run_in_executor(None, env.step, action)
                except Exception as exc:
                    log.error("env.step failed at turn %s: %r; ending episode", turn, exc)
                    break
                # Read the private state defensively, same as the end-of-episode call.
                state = getattr(env, "_state", None)
                done = bool(getattr(state, "match_over", False)) or getattr(new_obs, "done", False)
                # BusObserver.after_step expects (action, pre, next_obs, reward, *, turn, ...).
                # The local CricketEnvironment doesn't surface a per-step reward; pass None.
                await observer.after_step(action_dict, pre, new_obs, None, turn=turn)
                if action.tool in ("play_delivery", "bowl_delivery"):
                    deliveries += 1
                obs = new_obs
                turn += 1
                if done:
                    break
                # Tiny breather so the WebSocket fanout can flush.
                await asyncio.sleep(0.01)
            else:
                # Loop condition went false without a break: the step budget
                # ran out before the match reported completion.
                log.warning("ep=%s hit max_steps=%s before the match finished", ep_id, max_steps)

            await observer.end(getattr(env, "_state", None))
        finally:
            # Close the TTS session even if an observer call raised mid-episode.
            if tts is not None:
                await tts.aclose()

        log.info("finished ep=%s turns=%s deliveries=%s", ep_id, turn, deliveries)
        return {"episode_id": ep_id, "turns": turn, "deliveries": deliveries}


async def run_episode_in_proc(
    bus,
    captain_preset: str,
    opponent_choice: str,
    max_overs: int = 5,
    eval_pack_id: str = "default",
    enable_tts: bool = True,
) -> dict[str, Any]:
    """Background entry point for one in-process episode.

    Holds the at-most-one lock for the episode's lifetime. A call that arrives
    while another episode holds the lock is rejected immediately rather than
    queued.

    Returns:
        The summary dict from ``_drive`` on success, or
        ``{"error": ..., "current": <running episode metadata>}`` when an
        episode is already in progress.
    """
    if _RUNNING_LOCK.locked():
        return {"error": "an episode is already running", "current": dict(_RUNNING_INFO)}
    async with _RUNNING_LOCK:
        ep_id = _ep_id()
        _RUNNING_INFO.update({
            "episode_id": ep_id,
            "captain": captain_preset,
            "opponent": opponent_choice,
            "started_at": time.time(),
        })
        try:
            # Pass the id through so is_running() reports the same episode id
            # that is published on the bus and returned in the summary.
            return await _drive(
                captain_preset=captain_preset,
                opponent_choice=opponent_choice,
                max_overs=max_overs,
                eval_pack_id=eval_pack_id,
                enable_tts=enable_tts,
                bus=bus,
                episode_id=ep_id,
            )
        finally:
            _RUNNING_INFO.clear()