pratinavseth commited on
Commit
4d4439f
·
1 Parent(s): 2f787f1

fix: wire reset options for max_overs

Browse files

Pass reset parameters via the OpenEnv `options` dict — the server routes reset parameters through `options`, so the previous top-level keyword arguments to `reset()` were silently ignored — ensuring `max_overs` and other env options take effect in inference/eval runs.

Made-with: Cursor

Files changed (2) hide show
  1. eval.py +12 -7
  2. inference.py +12 -6
eval.py CHANGED
@@ -52,17 +52,20 @@ async def collect_eval_episodes(
52
  task: str,
53
  eval_pack_id: str = "default",
54
  opponent_mode: str = "heuristic",
 
55
  ) -> list[dict[str, Any]]:
56
  """Run n_episodes and return raw episode data for visualisation."""
57
  episodes = []
58
  async with CricketCaptainEnv(env_url) as env:
59
  for ep in range(n_episodes):
60
- result = await env.reset(
61
- task=task,
62
- random_start=False,
63
- eval_pack_id=eval_pack_id,
64
- opponent_mode=opponent_mode,
65
- )
 
 
66
  obs = result.observation
67
  history = []
68
  step_data = []
@@ -291,7 +294,7 @@ async def _run_eval(args):
291
 
292
  print(f"Collecting {args.episodes} evaluation episodes...")
293
  episodes = await collect_eval_episodes(
294
- args.env_url, agent, args.episodes, args.task, args.eval_pack_id, args.opponent_mode
295
  )
296
 
297
  print_summary(episodes)
@@ -322,6 +325,8 @@ def main():
322
  parser.add_argument("--eval-pack-id", default=os.environ.get("CRICKET_EVAL_PACK_ID", "default"))
323
  parser.add_argument("--opponent-mode", default=os.environ.get("CRICKET_OPPONENT_MODE", "heuristic"),
324
  choices=["heuristic", "llm_live", "llm_cached"])
 
 
325
  parser.add_argument("--out-dir", default="./eval_output")
326
  parser.add_argument("--log-file", default=None,
327
  help="Path to JSONL training log for reward curves")
 
52
  task: str,
53
  eval_pack_id: str = "default",
54
  opponent_mode: str = "heuristic",
55
+ max_overs: int | None = None,
56
  ) -> list[dict[str, Any]]:
57
  """Run n_episodes and return raw episode data for visualisation."""
58
  episodes = []
59
  async with CricketCaptainEnv(env_url) as env:
60
  for ep in range(n_episodes):
61
+ # OpenEnv server routes reset params via `options`.
62
+ result = await env.reset(options={
63
+ "task": task,
64
+ "random_start": False,
65
+ "eval_pack_id": eval_pack_id,
66
+ "opponent_mode": opponent_mode,
67
+ "max_overs": max_overs,
68
+ })
69
  obs = result.observation
70
  history = []
71
  step_data = []
 
294
 
295
  print(f"Collecting {args.episodes} evaluation episodes...")
296
  episodes = await collect_eval_episodes(
297
+ args.env_url, agent, args.episodes, args.task, args.eval_pack_id, args.opponent_mode, args.max_overs
298
  )
299
 
300
  print_summary(episodes)
 
325
  parser.add_argument("--eval-pack-id", default=os.environ.get("CRICKET_EVAL_PACK_ID", "default"))
326
  parser.add_argument("--opponent-mode", default=os.environ.get("CRICKET_OPPONENT_MODE", "heuristic"),
327
  choices=["heuristic", "llm_live", "llm_cached"])
328
+ parser.add_argument("--max-overs", type=int, default=None,
329
+ help="Limit innings length for fast experiments (e.g. 5).")
330
  parser.add_argument("--out-dir", default="./eval_output")
331
  parser.add_argument("--log-file", default=None,
332
  help="Path to JSONL training log for reward curves")
inference.py CHANGED
@@ -225,13 +225,16 @@ async def run_episode(
225
  verbose: bool = False,
226
  eval_pack_id: str = "default",
227
  opponent_mode: str = "heuristic",
 
228
  ) -> dict[str, Any]:
229
- result = await env.reset(
230
- task=task,
231
- random_start=False,
232
- eval_pack_id=eval_pack_id,
233
- opponent_mode=opponent_mode,
234
- )
 
 
235
  obs = result.observation
236
 
237
  history: list[dict] = []
@@ -308,6 +311,7 @@ async def evaluate(args):
308
  verbose=args.verbose,
309
  eval_pack_id=args.eval_pack_id,
310
  opponent_mode=args.opponent_mode,
 
311
  )
312
  results.append(ep_result)
313
  print(
@@ -332,6 +336,8 @@ def main():
332
  parser.add_argument("--episodes", type=int, default=5)
333
  parser.add_argument("--task", default="stage2_full",
334
  choices=["stage1_format", "stage2_full", "eval_50over"])
 
 
335
  parser.add_argument("--env-url", default=os.environ.get("CRICKET_CAPTAIN_ENV_URL", "ws://localhost:8000"))
336
  parser.add_argument("--eval-pack-id", default=os.environ.get("CRICKET_EVAL_PACK_ID", "default"))
337
  parser.add_argument("--opponent-mode", default=os.environ.get("CRICKET_OPPONENT_MODE", "heuristic"),
 
225
  verbose: bool = False,
226
  eval_pack_id: str = "default",
227
  opponent_mode: str = "heuristic",
228
+ max_overs: int | None = None,
229
  ) -> dict[str, Any]:
230
+ # OpenEnv server routes reset params via `options`.
231
+ result = await env.reset(options={
232
+ "task": task,
233
+ "random_start": False,
234
+ "eval_pack_id": eval_pack_id,
235
+ "opponent_mode": opponent_mode,
236
+ "max_overs": max_overs,
237
+ })
238
  obs = result.observation
239
 
240
  history: list[dict] = []
 
311
  verbose=args.verbose,
312
  eval_pack_id=args.eval_pack_id,
313
  opponent_mode=args.opponent_mode,
314
+ max_overs=args.max_overs,
315
  )
316
  results.append(ep_result)
317
  print(
 
336
  parser.add_argument("--episodes", type=int, default=5)
337
  parser.add_argument("--task", default="stage2_full",
338
  choices=["stage1_format", "stage2_full", "eval_50over"])
339
+ parser.add_argument("--max-overs", type=int, default=None,
340
+ help="Limit innings length for fast experiments (e.g. 5).")
341
  parser.add_argument("--env-url", default=os.environ.get("CRICKET_CAPTAIN_ENV_URL", "ws://localhost:8000"))
342
  parser.add_argument("--eval-pack-id", default=os.environ.get("CRICKET_EVAL_PACK_ID", "default"))
343
  parser.add_argument("--opponent-mode", default=os.environ.get("CRICKET_OPPONENT_MODE", "heuristic"),