Commit ·
4d4439f
1
Parent(s): 2f787f1
fix: wire reset options for max_overs
Browse files
Pass reset parameters via the OpenEnv `options` argument so that `max_overs` and other environment options take effect in inference/eval runs.
Made-with: Cursor
- eval.py +12 -7
- inference.py +12 -6
eval.py
CHANGED
|
@@ -52,17 +52,20 @@ async def collect_eval_episodes(
|
|
| 52 |
task: str,
|
| 53 |
eval_pack_id: str = "default",
|
| 54 |
opponent_mode: str = "heuristic",
|
|
|
|
| 55 |
) -> list[dict[str, Any]]:
|
| 56 |
"""Run n_episodes and return raw episode data for visualisation."""
|
| 57 |
episodes = []
|
| 58 |
async with CricketCaptainEnv(env_url) as env:
|
| 59 |
for ep in range(n_episodes):
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
|
|
|
|
|
|
| 66 |
obs = result.observation
|
| 67 |
history = []
|
| 68 |
step_data = []
|
|
@@ -291,7 +294,7 @@ async def _run_eval(args):
|
|
| 291 |
|
| 292 |
print(f"Collecting {args.episodes} evaluation episodes...")
|
| 293 |
episodes = await collect_eval_episodes(
|
| 294 |
-
args.env_url, agent, args.episodes, args.task, args.eval_pack_id, args.opponent_mode
|
| 295 |
)
|
| 296 |
|
| 297 |
print_summary(episodes)
|
|
@@ -322,6 +325,8 @@ def main():
|
|
| 322 |
parser.add_argument("--eval-pack-id", default=os.environ.get("CRICKET_EVAL_PACK_ID", "default"))
|
| 323 |
parser.add_argument("--opponent-mode", default=os.environ.get("CRICKET_OPPONENT_MODE", "heuristic"),
|
| 324 |
choices=["heuristic", "llm_live", "llm_cached"])
|
|
|
|
|
|
|
| 325 |
parser.add_argument("--out-dir", default="./eval_output")
|
| 326 |
parser.add_argument("--log-file", default=None,
|
| 327 |
help="Path to JSONL training log for reward curves")
|
|
|
|
| 52 |
task: str,
|
| 53 |
eval_pack_id: str = "default",
|
| 54 |
opponent_mode: str = "heuristic",
|
| 55 |
+
max_overs: int | None = None,
|
| 56 |
) -> list[dict[str, Any]]:
|
| 57 |
"""Run n_episodes and return raw episode data for visualisation."""
|
| 58 |
episodes = []
|
| 59 |
async with CricketCaptainEnv(env_url) as env:
|
| 60 |
for ep in range(n_episodes):
|
| 61 |
+
# OpenEnv server routes reset params via `options`.
|
| 62 |
+
result = await env.reset(options={
|
| 63 |
+
"task": task,
|
| 64 |
+
"random_start": False,
|
| 65 |
+
"eval_pack_id": eval_pack_id,
|
| 66 |
+
"opponent_mode": opponent_mode,
|
| 67 |
+
"max_overs": max_overs,
|
| 68 |
+
})
|
| 69 |
obs = result.observation
|
| 70 |
history = []
|
| 71 |
step_data = []
|
|
|
|
| 294 |
|
| 295 |
print(f"Collecting {args.episodes} evaluation episodes...")
|
| 296 |
episodes = await collect_eval_episodes(
|
| 297 |
+
args.env_url, agent, args.episodes, args.task, args.eval_pack_id, args.opponent_mode, args.max_overs
|
| 298 |
)
|
| 299 |
|
| 300 |
print_summary(episodes)
|
|
|
|
| 325 |
parser.add_argument("--eval-pack-id", default=os.environ.get("CRICKET_EVAL_PACK_ID", "default"))
|
| 326 |
parser.add_argument("--opponent-mode", default=os.environ.get("CRICKET_OPPONENT_MODE", "heuristic"),
|
| 327 |
choices=["heuristic", "llm_live", "llm_cached"])
|
| 328 |
+
parser.add_argument("--max-overs", type=int, default=None,
|
| 329 |
+
help="Limit innings length for fast experiments (e.g. 5).")
|
| 330 |
parser.add_argument("--out-dir", default="./eval_output")
|
| 331 |
parser.add_argument("--log-file", default=None,
|
| 332 |
help="Path to JSONL training log for reward curves")
|
inference.py
CHANGED
|
@@ -225,13 +225,16 @@ async def run_episode(
|
|
| 225 |
verbose: bool = False,
|
| 226 |
eval_pack_id: str = "default",
|
| 227 |
opponent_mode: str = "heuristic",
|
|
|
|
| 228 |
) -> dict[str, Any]:
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
|
|
|
|
|
|
| 235 |
obs = result.observation
|
| 236 |
|
| 237 |
history: list[dict] = []
|
|
@@ -308,6 +311,7 @@ async def evaluate(args):
|
|
| 308 |
verbose=args.verbose,
|
| 309 |
eval_pack_id=args.eval_pack_id,
|
| 310 |
opponent_mode=args.opponent_mode,
|
|
|
|
| 311 |
)
|
| 312 |
results.append(ep_result)
|
| 313 |
print(
|
|
@@ -332,6 +336,8 @@ def main():
|
|
| 332 |
parser.add_argument("--episodes", type=int, default=5)
|
| 333 |
parser.add_argument("--task", default="stage2_full",
|
| 334 |
choices=["stage1_format", "stage2_full", "eval_50over"])
|
|
|
|
|
|
|
| 335 |
parser.add_argument("--env-url", default=os.environ.get("CRICKET_CAPTAIN_ENV_URL", "ws://localhost:8000"))
|
| 336 |
parser.add_argument("--eval-pack-id", default=os.environ.get("CRICKET_EVAL_PACK_ID", "default"))
|
| 337 |
parser.add_argument("--opponent-mode", default=os.environ.get("CRICKET_OPPONENT_MODE", "heuristic"),
|
|
|
|
| 225 |
verbose: bool = False,
|
| 226 |
eval_pack_id: str = "default",
|
| 227 |
opponent_mode: str = "heuristic",
|
| 228 |
+
max_overs: int | None = None,
|
| 229 |
) -> dict[str, Any]:
|
| 230 |
+
# OpenEnv server routes reset params via `options`.
|
| 231 |
+
result = await env.reset(options={
|
| 232 |
+
"task": task,
|
| 233 |
+
"random_start": False,
|
| 234 |
+
"eval_pack_id": eval_pack_id,
|
| 235 |
+
"opponent_mode": opponent_mode,
|
| 236 |
+
"max_overs": max_overs,
|
| 237 |
+
})
|
| 238 |
obs = result.observation
|
| 239 |
|
| 240 |
history: list[dict] = []
|
|
|
|
| 311 |
verbose=args.verbose,
|
| 312 |
eval_pack_id=args.eval_pack_id,
|
| 313 |
opponent_mode=args.opponent_mode,
|
| 314 |
+
max_overs=args.max_overs,
|
| 315 |
)
|
| 316 |
results.append(ep_result)
|
| 317 |
print(
|
|
|
|
| 336 |
parser.add_argument("--episodes", type=int, default=5)
|
| 337 |
parser.add_argument("--task", default="stage2_full",
|
| 338 |
choices=["stage1_format", "stage2_full", "eval_50over"])
|
| 339 |
+
parser.add_argument("--max-overs", type=int, default=None,
|
| 340 |
+
help="Limit innings length for fast experiments (e.g. 5).")
|
| 341 |
parser.add_argument("--env-url", default=os.environ.get("CRICKET_CAPTAIN_ENV_URL", "ws://localhost:8000"))
|
| 342 |
parser.add_argument("--eval-pack-id", default=os.environ.get("CRICKET_EVAL_PACK_ID", "default"))
|
| 343 |
parser.add_argument("--opponent-mode", default=os.environ.get("CRICKET_OPPONENT_MODE", "heuristic"),
|