Adhitya-Vardhan commited on
Commit
c2edad1
·
1 Parent(s): f00a888

fix: read API_KEY env var and auto-select openai policy when credentials present

Browse files

- Read API_KEY (validator), HF_TOKEN, OPENAI_API_KEY — all supported
- Auto-switch to openai policy when any API key is detected in env
- Default ENV_BASE_URL to live HF Space (https://adhitya122-vulnops.hf.space)
- Falls back to heuristic locally when no key is set

Files changed (1) hide show
  1. inference.py +15 -7
inference.py CHANGED
@@ -12,7 +12,8 @@ from openenv.core import GenericEnvClient
12
 
13
  API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
14
  MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
15
- HF_TOKEN = os.getenv("HF_TOKEN")
 
16
 
17
  from models import VulnTriageAction
18
  from server.cases import TASK_ORDER, get_case_definition
@@ -47,11 +48,12 @@ Note: You CANNOT inspect "nvd_assessment", "github_commit_diff", or "vendor_stat
47
 
48
 
49
  def get_openai_client() -> OpenAI:
50
- api_key = HF_TOKEN or os.getenv("OPENAI_API_KEY")
51
  if not api_key:
52
- raise RuntimeError("Set HF_TOKEN before running the OpenAI baseline.")
53
-
54
- kwargs = {"api_key": api_key}
 
55
  if API_BASE_URL:
56
  kwargs["base_url"] = API_BASE_URL
57
  return OpenAI(**kwargs)
@@ -293,9 +295,15 @@ def run_remote_episode(base_url: str, task_id: str, policy: str, model_name: str
293
 
294
  def main() -> None:
295
  parser = argparse.ArgumentParser()
296
- parser.add_argument("--policy", choices=["openai", "heuristic"], default="heuristic")
 
 
 
 
297
  parser.add_argument("--model", default=MODEL_NAME)
298
- parser.add_argument("--env-base-url", dest="base_url", default=os.getenv("ENV_BASE_URL"))
 
 
299
  args = parser.parse_args()
300
 
301
  results: List[Dict[str, float]] = []
 
12
 
13
  API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
14
  MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
15
+ # Support all key variants the validator may inject
16
+ _API_KEY = os.getenv("API_KEY") or os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY")
17
 
18
  from models import VulnTriageAction
19
  from server.cases import TASK_ORDER, get_case_definition
 
48
 
49
 
50
  def get_openai_client() -> OpenAI:
51
+ api_key = _API_KEY
52
  if not api_key:
53
+ raise RuntimeError(
54
+ "Set API_KEY, HF_TOKEN, or OPENAI_API_KEY before running the OpenAI baseline."
55
+ )
56
+ kwargs: Dict[str, str] = {"api_key": api_key}
57
  if API_BASE_URL:
58
  kwargs["base_url"] = API_BASE_URL
59
  return OpenAI(**kwargs)
 
295
 
296
  def main() -> None:
297
  parser = argparse.ArgumentParser()
298
+ # Auto-select openai policy when the validator injects API credentials;
299
+ # fall back to heuristic for local smoke-tests with no key.
300
+ _has_credentials = bool(_API_KEY)
301
+ _default_policy = "openai" if _has_credentials else "heuristic"
302
+ parser.add_argument("--policy", choices=["openai", "heuristic"], default=_default_policy)
303
  parser.add_argument("--model", default=MODEL_NAME)
304
+ # Default ENV_BASE_URL to the live HF Space so the validator can reach our environment
305
+ _default_env_url = os.getenv("ENV_BASE_URL", "https://adhitya122-vulnops.hf.space")
306
+ parser.add_argument("--env-base-url", dest="base_url", default=_default_env_url)
307
  args = parser.parse_args()
308
 
309
  results: List[Dict[str, float]] = []