NeerajCodz committed on
Commit
f594f81
·
1 Parent(s): f6b54cb

fix: inference.py reset

Browse files
Files changed (1) hide show
  1. inference.py +61 -8
inference.py CHANGED
@@ -59,12 +59,13 @@ HF_TOKEN = os.getenv("HF_TOKEN")
59
  ENV_API_BASE_URL = _env_str("ENV_API_BASE_URL", "http://localhost:8000/api")
60
  TASK_NAME_DEFAULT = _env_str("TASK_NAME", "task_001")
61
  BENCHMARK_DEFAULT = _env_str("BENCHMARK", "openenv")
 
62
  MAX_STEPS_DEFAULT = _env_int("MAX_STEPS", 12)
63
  EPISODE_SEED_DEFAULT = _env_int("EPISODE_SEED", 42)
64
  LLM_TEMPERATURE = _env_float("LLM_TEMPERATURE", 0.0)
65
  PROMPT_HTML_LIMIT = _env_int("PROMPT_HTML_LIMIT", 5000)
66
  REQUEST_TIMEOUT_SECONDS = _env_float("REQUEST_TIMEOUT_SECONDS", 30.0)
67
- USE_OPENENV_SDK = _env_bool("USE_OPENENV_SDK", True)
68
 
69
 
70
  @dataclass
@@ -125,9 +126,20 @@ def _emit_step(step_number: int, action: str, reward: float, done: bool, error_v
125
  )
126
 
127
 
128
- def _emit_end(success: bool, steps: int, rewards: list[float]) -> None:
129
  rewards_text = ",".join(_reward_text(reward) for reward in rewards)
130
- print(f"[END] success={_bool_text(success)} steps={steps} rewards={rewards_text}", flush=True)
 
 
 
 
 
 
 
 
 
 
 
131
 
132
 
133
  def _action_to_log_string(action: dict[str, Any]) -> str:
@@ -482,7 +494,31 @@ class OpenEnvSDKAdapter:
482
  raise RuntimeError("Unsupported step() return format from OpenEnv SDK")
483
 
484
 
485
- def _build_adapter(benchmark: str, env_api_base_url: str) -> EpisodeAdapter:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
486
  if USE_OPENENV_SDK:
487
  try:
488
  return OpenEnvSDKAdapter(benchmark)
@@ -491,7 +527,14 @@ def _build_adapter(benchmark: str, env_api_base_url: str) -> EpisodeAdapter:
491
  return ScrapeRLEpisodeAdapter(env_api_base_url)
492
 
493
 
494
- def run_inference(task_name: str, benchmark: str, max_steps: int, seed: int, env_api_base_url: str) -> int:
 
 
 
 
 
 
 
495
  rewards: list[float] = []
496
  steps = 0
497
  success = False
@@ -506,7 +549,11 @@ def run_inference(task_name: str, benchmark: str, max_steps: int, seed: int, env
506
  from openai import OpenAI
507
 
508
  client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
509
- adapter = _build_adapter(benchmark=benchmark, env_api_base_url=env_api_base_url)
 
 
 
 
510
  observation, info = adapter.reset(task_name=task_name, seed=seed)
511
 
512
  for step_number in range(1, max_steps + 1):
@@ -547,7 +594,7 @@ def run_inference(task_name: str, benchmark: str, max_steps: int, seed: int, env
547
  adapter.close()
548
  except Exception:
549
  pass
550
- _emit_end(success=success, steps=steps, rewards=rewards)
551
 
552
  return 0 if success else 1
553
 
@@ -563,6 +610,11 @@ def parse_args() -> argparse.Namespace:
563
  default=ENV_API_BASE_URL,
564
  help="Fallback environment API base URL (used when OpenEnv SDK is unavailable)",
565
  )
 
 
 
 
 
566
  return parser.parse_args()
567
 
568
 
@@ -575,10 +627,11 @@ if __name__ == "__main__":
575
  max_steps=args.max_steps,
576
  seed=args.seed,
577
  env_api_base_url=args.env_api_base_url,
 
578
  )
579
  except Exception:
580
  # Last-resort guard: never allow an unhandled exception to escape.
581
  _emit_start(task_name=TASK_NAME_DEFAULT, benchmark=BENCHMARK_DEFAULT, model_name=MODEL_NAME)
582
- _emit_end(success=False, steps=0, rewards=[])
583
  exit_code = 1
584
  sys.exit(exit_code)
 
59
  ENV_API_BASE_URL = _env_str("ENV_API_BASE_URL", "http://localhost:8000/api")
60
  TASK_NAME_DEFAULT = _env_str("TASK_NAME", "task_001")
61
  BENCHMARK_DEFAULT = _env_str("BENCHMARK", "openenv")
62
+ LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
63
  MAX_STEPS_DEFAULT = _env_int("MAX_STEPS", 12)
64
  EPISODE_SEED_DEFAULT = _env_int("EPISODE_SEED", 42)
65
  LLM_TEMPERATURE = _env_float("LLM_TEMPERATURE", 0.0)
66
  PROMPT_HTML_LIMIT = _env_int("PROMPT_HTML_LIMIT", 5000)
67
  REQUEST_TIMEOUT_SECONDS = _env_float("REQUEST_TIMEOUT_SECONDS", 30.0)
68
+ USE_OPENENV_SDK = _env_bool("USE_OPENENV_SDK", False)
69
 
70
 
71
  @dataclass
 
126
  )
127
 
128
 
129
def _emit_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
    """Print the closing ``[END]`` record for an episode to stdout.

    The record is a single space-separated line of ``key=value`` fields in the
    fixed order ``success``, ``steps``, ``score``, ``rewards``. Rewards are
    rendered with ``_reward_text`` and joined with commas; an empty reward
    list yields an empty ``rewards=`` field. Output is flushed immediately.

    Args:
        success: Whether the episode succeeded (rendered via ``_bool_text``).
        steps: Number of steps taken.
        score: Final episode score (rendered via ``_reward_text``).
        rewards: Per-step rewards, in step order.
    """
    joined_rewards = ",".join(map(_reward_text, rewards))
    fields = (
        f"success={_bool_text(success)}",
        f"steps={steps}",
        f"score={_reward_text(score)}",
        f"rewards={joined_rewards}",
    )
    print("[END] " + " ".join(fields), flush=True)
135
+
136
+
137
+ def _compute_score(success: bool, rewards: list[float]) -> float:
138
+ if success:
139
+ return 1.0
140
+ if not rewards:
141
+ return 0.0
142
+ return max(0.0, min(1.0, max(float(value) for value in rewards)))
143
 
144
 
145
  def _action_to_log_string(action: dict[str, Any]) -> str:
 
494
  raise RuntimeError("Unsupported step() return format from OpenEnv SDK")
495
 
496
 
497
class OpenEnvDockerImageAdapter:
    """Episode adapter backed by an environment built from a Docker image.

    The environment is created with ``openenv.from_docker_image``. Raw reset
    and step results are converted with the parsing helpers on
    ``OpenEnvSDKAdapter``.
    """

    def __init__(self, image_name: str) -> None:
        """Create the environment for ``image_name``.

        Raises:
            ImportError: If the ``openenv`` package cannot be imported.
            RuntimeError: If ``openenv`` does not expose ``from_docker_image``.
        """
        import openenv  # type: ignore

        factory_available = hasattr(openenv, "from_docker_image")
        if not factory_available:
            raise RuntimeError("openenv.from_docker_image is not available")
        self.env = openenv.from_docker_image(image_name)

    def reset(self, task_name: str, seed: int) -> tuple[dict[str, Any], dict[str, Any]]:
        """Reset the environment and return the parsed reset result."""
        raw_result = self.env.reset(task_name=task_name, seed=seed)
        return OpenEnvSDKAdapter._parse_reset(raw_result)

    def step(self, action: dict[str, Any]) -> StepOutcome:
        """Apply ``action`` to the environment and return the parsed outcome."""
        raw_result = self.env.step(action)
        return OpenEnvSDKAdapter._parse_step(raw_result)

    def close(self) -> None:
        """Close the environment if it provides a ``close`` attribute."""
        if not hasattr(self.env, "close"):
            return
        self.env.close()
514
+
515
+
516
+ def _build_adapter(benchmark: str, env_api_base_url: str, local_image_name: str | None) -> EpisodeAdapter:
517
+ if isinstance(local_image_name, str) and local_image_name.strip():
518
+ try:
519
+ return OpenEnvDockerImageAdapter(local_image_name.strip())
520
+ except Exception:
521
+ pass
522
  if USE_OPENENV_SDK:
523
  try:
524
  return OpenEnvSDKAdapter(benchmark)
 
527
  return ScrapeRLEpisodeAdapter(env_api_base_url)
528
 
529
 
530
+ def run_inference(
531
+ task_name: str,
532
+ benchmark: str,
533
+ max_steps: int,
534
+ seed: int,
535
+ env_api_base_url: str,
536
+ local_image_name: str | None,
537
+ ) -> int:
538
  rewards: list[float] = []
539
  steps = 0
540
  success = False
 
549
  from openai import OpenAI
550
 
551
  client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
552
+ adapter = _build_adapter(
553
+ benchmark=benchmark,
554
+ env_api_base_url=env_api_base_url,
555
+ local_image_name=local_image_name,
556
+ )
557
  observation, info = adapter.reset(task_name=task_name, seed=seed)
558
 
559
  for step_number in range(1, max_steps + 1):
 
594
  adapter.close()
595
  except Exception:
596
  pass
597
+ _emit_end(success=success, steps=steps, score=_compute_score(success, rewards), rewards=rewards)
598
 
599
  return 0 if success else 1
600
 
 
610
  default=ENV_API_BASE_URL,
611
  help="Fallback environment API base URL (used when OpenEnv SDK is unavailable)",
612
  )
613
+ parser.add_argument(
614
+ "--local-image-name",
615
+ default=LOCAL_IMAGE_NAME,
616
+ help="Docker image name for OpenEnv from_docker_image bridge (optional)",
617
+ )
618
  return parser.parse_args()
619
 
620
 
 
627
  max_steps=args.max_steps,
628
  seed=args.seed,
629
  env_api_base_url=args.env_api_base_url,
630
+ local_image_name=args.local_image_name,
631
  )
632
  except Exception:
633
  # Last-resort guard: never allow an unhandled exception to escape.
634
  _emit_start(task_name=TASK_NAME_DEFAULT, benchmark=BENCHMARK_DEFAULT, model_name=MODEL_NAME)
635
+ _emit_end(success=False, steps=0, score=0.0, rewards=[])
636
  exit_code = 1
637
  sys.exit(exit_code)