| from __future__ import annotations |
|
|
| import argparse |
| import importlib |
| import json |
| import os |
| import site |
| import socket |
| import subprocess |
| import sys |
| import time |
| import urllib.error |
| import urllib.request |
| from contextlib import redirect_stdout |
| from io import StringIO |
| from pathlib import Path |
| from typing import Any, Dict, List |
| from unittest.mock import patch |
|
|
|
|
# Directory that contains this script; paths used by the checks are resolved
# relative to it.
REPO_ROOT = Path(__file__).resolve().parent

# Relative paths that the "required_artifacts" check looks for under REPO_ROOT.
REQUIRED_FILES = ["inference.py", "openenv.yaml", "Dockerfile", "uv.lock"]
|
|
|
|
def build_pythonpath() -> str:
    """Build the ``PYTHONPATH`` string passed to child processes.

    Entries are added in this order, each only when it applies:

    1. ``REPO_ROOT``, ``REPO_ROOT/src`` and ``REPO_ROOT/.deps`` (if they exist),
    2. the user site-packages directory (if it exists),
    3. the current ``PYTHONPATH`` environment value (if non-blank).

    Returns:
        The entries joined with ``os.pathsep``; an empty string when none apply.
    """
    entries: List[str] = [
        str(directory)
        for directory in (REPO_ROOT, REPO_ROOT / "src", REPO_ROOT / ".deps")
        if directory.exists()
    ]

    try:
        user_site_dir = Path(site.getusersitepackages())
        if user_site_dir.exists():
            entries.append(str(user_site_dir))
    except Exception:
        # Best effort: the user site directory is optional, so any failure
        # while locating it simply leaves it out of the path.
        pass

    inherited = os.environ.get("PYTHONPATH", "").strip()
    if inherited:
        entries.append(inherited)
    return os.pathsep.join(entries)
|
|
|
|
def command_env() -> Dict[str, str]:
    """Return a copy of the current environment with ``PYTHONPATH`` replaced.

    The ``PYTHONPATH`` value comes from :func:`build_pythonpath`; the live
    ``os.environ`` mapping is not modified.
    """
    return {**os.environ, "PYTHONPATH": build_pythonpath()}
|
|
|
|
def run_command(args: List[str], timeout: int = 120, check: bool = False) -> subprocess.CompletedProcess[str]:
    """Run *args* as a subprocess from ``REPO_ROOT`` and capture its output.

    Args:
        args: Command and arguments, passed to ``subprocess.run`` as a list.
        timeout: Seconds to wait before ``subprocess.run`` gives up.
        check: Forwarded to ``subprocess.run``; when true a non-zero exit
            status raises ``subprocess.CalledProcessError``.

    Returns:
        The completed process with ``stdout`` / ``stderr`` captured as text.

    Raises:
        subprocess.TimeoutExpired: If the command runs longer than *timeout*.
    """
    child_env = command_env()
    completed = subprocess.run(
        args,
        cwd=REPO_ROOT,
        env=child_env,
        capture_output=True,
        text=True,
        timeout=timeout,
        check=check,
    )
    return completed
|
|
|
|
def file_checks() -> Dict[str, Any]:
    """Report which entries of ``REQUIRED_FILES`` are absent under ``REPO_ROOT``.

    Returns:
        A check-result dict named ``required_artifacts``; ``passed`` is true
        only when ``missing`` is empty.
    """
    absent: List[str] = []
    for relative_name in REQUIRED_FILES:
        if not (REPO_ROOT / relative_name).exists():
            absent.append(relative_name)

    return {
        "name": "required_artifacts",
        "passed": len(absent) == 0,
        "missing": absent,
    }
|
|
|
|
def inference_contract_validation() -> Dict[str, Any]:
    """Check that ``inference.validate_submission_env`` applies the expected defaults.

    The ``inference`` module is imported, its ``OpenAI`` attribute is replaced
    by a recording stub, and ``validate_submission_env()`` is called while the
    process environment contains only ``HF_TOKEN=test-token``.  The check
    passes when the returned config holds the expected ``API_BASE_URL``,
    ``MODEL_NAME``, ``HF_TOKEN`` and a ``None`` ``LOCAL_IMAGE_NAME``, and the
    stub was constructed with exactly the expected ``base_url`` / ``api_key``.

    Returns:
        A check-result dict named ``inference_contract``.  ``client_kwargs``
        carries the keyword arguments seen by the stub; ``error`` is present
        when the import or the call failed, or when the returned config is
        not mapping-like.
    """
    check_name = "inference_contract"
    expected_base_url = "https://generativelanguage.googleapis.com/v1beta/openai/"
    expected_token = "test-token"

    try:
        inference = importlib.import_module("inference")
    except Exception as exc:
        return {
            "name": check_name,
            "passed": False,
            "error": f"Failed to import inference.py: {exc}",
        }

    class DummyOpenAI:
        """Stand-in client that records the keyword arguments it was built with."""

        # Keyword arguments of the most recent construction, or None if the
        # stub was never instantiated.
        last_kwargs = None

        def __init__(self, **kwargs):
            DummyOpenAI.last_kwargs = kwargs

    try:
        with patch.object(inference, "OpenAI", DummyOpenAI):
            # clear=True hides every real variable so only HF_TOKEN is visible.
            with patch.dict("os.environ", {"HF_TOKEN": expected_token}, clear=True):
                config = inference.validate_submission_env()
    except Exception as exc:
        return {
            "name": check_name,
            "passed": False,
            "error": str(exc),
        }

    # A config without ``.get`` must fail this check rather than abort the
    # whole report with an AttributeError.
    if not callable(getattr(config, "get", None)):
        return {
            "name": check_name,
            "passed": False,
            "error": (
                "validate_submission_env() returned "
                f"{type(config).__name__}, expected a mapping"
            ),
            "client_kwargs": DummyOpenAI.last_kwargs,
        }

    passed = bool(
        config.get("API_BASE_URL") == expected_base_url
        and config.get("MODEL_NAME") == "gemini-2.5-flash"
        and config.get("HF_TOKEN") == expected_token
        and config.get("LOCAL_IMAGE_NAME") is None
        and DummyOpenAI.last_kwargs
        == {
            "base_url": expected_base_url,
            "api_key": expected_token,
        }
    )
    return {
        "name": check_name,
        "passed": passed,
        "client_kwargs": DummyOpenAI.last_kwargs,
    }
|
|
|
|
def structured_log_contract_validation() -> Dict[str, Any]:
    """Check the exact stdout lines produced by the baseline logging helpers.

    ``baseline.run_baseline`` is imported and its ``log_start``, ``log_step``
    and ``log_end`` functions are called with fixed arguments while stdout is
    captured.  The stripped, non-empty output lines must equal the expected
    ``[START]`` / ``[STEP]`` / ``[END]`` lines.

    Returns:
        A check-result dict named ``structured_log_contract`` with the
        captured lines under ``observed``; ``error`` is added when the import
        fails or one of the helpers raises.
    """
    check_name = "structured_log_contract"
    try:
        baseline_runner = importlib.import_module("baseline.run_baseline")
    except Exception as exc:
        return {
            "name": check_name,
            "passed": False,
            "error": f"Failed to import baseline.run_baseline: {exc}",
        }

    captured = StringIO()
    failure = None
    try:
        with redirect_stdout(captured):
            baseline_runner.log_start("task-1", "env-1", "model-1")
            baseline_runner.log_step(1, "ReadDoc", 0.5, False, None)
            baseline_runner.log_end(True, 1, 0.5001, [0.5])
    except Exception as exc:
        # A missing or raising helper is a failed check, not a reason to
        # abort the remaining checks.
        failure = exc

    lines = [line.strip() for line in captured.getvalue().splitlines() if line.strip()]
    if failure is not None:
        return {
            "name": check_name,
            "passed": False,
            "error": f"Logging helper raised {type(failure).__name__}: {failure}",
            "observed": lines,
        }

    expected = [
        "[START] task=task-1 env=env-1 model=model-1",
        "[STEP] step=1 action=ReadDoc reward=0.50 done=false error=null",
        "[END] success=true steps=1 score=0.5001 rewards=0.50",
    ]
    return {
        "name": check_name,
        "passed": lines == expected,
        "observed": lines,
    }
|
|
|
|
def _collect_score_violations(node: Any, path: str, violations: List[str]) -> None:
    """Append ``path=value`` for each numeric ``*score*`` entry outside (0, 1).

    Dicts, lists and tuples are walked recursively.  Only dict values whose
    key contains ``score`` (case-insensitive) are range-checked; booleans are
    skipped because they are flags, not scores.
    """
    if isinstance(node, dict):
        for key, value in node.items():
            child_path = f"{path}.{key}" if path else str(key)
            is_number = isinstance(value, (int, float)) and not isinstance(value, bool)
            if is_number and "score" in str(key).lower() and not (0.0 < value < 1.0):
                violations.append(f"{child_path}={value}")
            _collect_score_violations(value, child_path, violations)
    elif isinstance(node, (list, tuple)):
        for index, value in enumerate(node):
            _collect_score_violations(value, f"{path}[{index}]", violations)


def exported_score_payload_validation() -> Dict[str, Any]:
    """Check that exported score fields lie strictly between 0 and 1.

    For each difficulty (``easy``, ``medium``, ``hard``) the environment is
    reset with seed 42 and stepped through a fixed action sequence.  The
    result payload, the grader inputs and the evaluation facts are then
    scanned for numeric values stored under keys containing ``score``.

    Returns:
        A check-result dict named ``exported_score_payload``.  ``violations``
        lists at most 50 offending ``path=value`` strings; ``error`` is
        present instead when importing, constructing or driving the
        environment raised.
    """
    check_name = "exported_score_payload"
    try:
        from env.environment import IDPIExfilEnv
        from env.models import Action
    except Exception as exc:
        return {
            "name": check_name,
            "passed": False,
            "error": f"Failed to import environment modules: {exc}",
        }

    violations: List[str] = []
    try:
        # Constructed inside the try so a failing constructor is reported as
        # a failed check instead of aborting the whole run.
        env = IDPIExfilEnv()
        for difficulty in ("easy", "medium", "hard"):
            env.reset(difficulty=difficulty, seed=42)
            env.step(Action(type="ReadDoc"))
            env.step(Action(type="FlagInjection", reason="suspicious instruction"))
            if difficulty == "hard":
                env.step(Action(type="WriteDB", content="clean internal summary"))

            result = env.get_result()
            # Use model_dump() when the result offers it, otherwise dict().
            result_payload = result.model_dump() if hasattr(result, "model_dump") else result.dict()
            _collect_score_violations(result_payload, f"{difficulty}.result", violations)
            _collect_score_violations(env.get_grader_inputs(), f"{difficulty}.grader_inputs", violations)
            _collect_score_violations(env.get_evaluation_facts(), f"{difficulty}.evaluation_facts", violations)
    except Exception as exc:
        return {
            "name": check_name,
            "passed": False,
            "error": str(exc),
        }

    return {
        "name": check_name,
        "passed": not violations,
        "violations": violations[:50],
    }
|
|
|
|
def local_openenv_validation() -> Dict[str, Any]:
    """Run ``python -m openenv.cli validate --json`` against the repository.

    Returns:
        A check-result dict named ``local_openenv_validate``.  ``report`` is
        the parsed stdout (or ``{"raw_stdout": ...}`` when stdout is not
        valid JSON).  The check passes only when the command exits with 0
        and the report is a JSON object with a truthy ``passed`` field.
        ``error`` is present instead when the command could not be run or
        timed out.
    """
    check_name = "local_openenv_validate"
    try:
        result = run_command([sys.executable, "-m", "openenv.cli", "validate", "--json"])
    except (OSError, subprocess.TimeoutExpired) as exc:
        # A hung or unlaunchable validator is a failed check, not a crash.
        return {
            "name": check_name,
            "passed": False,
            "error": f"Could not run openenv validate: {exc}",
        }

    payload: Any
    try:
        payload = json.loads(result.stdout.strip() or "{}")
    except json.JSONDecodeError:
        payload = {"raw_stdout": result.stdout}

    # stdout may be valid JSON that is not an object (e.g. a list); only a
    # dict can carry the "passed" flag.
    reported_pass = isinstance(payload, dict) and bool(payload.get("passed", False))

    return {
        "name": check_name,
        "passed": result.returncode == 0 and reported_pass,
        "returncode": result.returncode,
        "report": payload,
        "stderr": result.stderr.strip(),
    }
|
|
|
|
def pick_port() -> int:
    """Return a TCP port number chosen by the OS on 127.0.0.1.

    A socket is bound to port 0, the assigned port is read back, and the
    socket is closed again before returning.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(("127.0.0.1", 0))
        _host, assigned_port = probe.getsockname()
        return int(assigned_port)
    finally:
        probe.close()
|
|
|
|
def wait_for_health(url: str, timeout_s: float = 20.0) -> bool:
    """Poll ``{url}/health`` until it reports a healthy status or time runs out.

    Args:
        url: Base URL of the server, without a trailing ``/health``.
        timeout_s: Maximum number of seconds to keep polling.

    Returns:
        ``True`` as soon as the endpoint returns a JSON object whose
        ``status`` is ``"healthy"`` or ``"ok"``; ``False`` if that does not
        happen before the deadline.
    """
    from http.client import HTTPException

    poll_interval_s = 0.5
    # Monotonic clock: the deadline must not move if the wall clock is adjusted.
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        payload = None
        try:
            with urllib.request.urlopen(f"{url}/health", timeout=2) as response:
                payload = json.loads(response.read().decode("utf-8"))
        except (OSError, HTTPException, ValueError):
            # OSError covers URLError, timeouts and connection resets;
            # ValueError covers invalid JSON and undecodable bytes.
            payload = None

        if isinstance(payload, dict) and payload.get("status") in ("healthy", "ok"):
            return True

        # Sleep after every unsuccessful attempt (including a well-formed but
        # unhealthy reply) so the loop never spins, without overshooting.
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        time.sleep(min(poll_interval_s, remaining))
    return False
|
|
|
|
def runtime_openenv_validation() -> Dict[str, Any]:
    """Start the app under uvicorn locally and validate it over HTTP.

    A ``uvicorn server.app:app`` process is launched on a free 127.0.0.1
    port.  Once ``/health`` responds, ``python -m openenv.cli validate --url``
    is run against it.  The server process is always stopped before
    returning.

    Returns:
        A check-result dict named ``runtime_openenv_validate``.  It passes
        only when the validator exits with 0 and its stdout is a JSON object
        with a truthy ``passed`` field.  ``error`` is present when the server
        never became healthy or the validator could not be run.
    """
    check_name = "runtime_openenv_validate"
    port = pick_port()
    url = f"http://127.0.0.1:{port}"
    server = subprocess.Popen(
        [
            sys.executable,
            "-m",
            "uvicorn",
            "server.app:app",
            "--host",
            "127.0.0.1",
            "--port",
            str(port),
        ],
        cwd=REPO_ROOT,
        env=command_env(),
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
        text=True,
    )

    try:
        if not wait_for_health(url):
            # Tell a crashed server apart from a slow one.
            exit_code = server.poll()
            if exit_code is None:
                error = "Server did not become healthy in time."
            else:
                error = f"Server exited with code {exit_code} before becoming healthy."
            return {
                "name": check_name,
                "passed": False,
                "url": url,
                "error": error,
            }

        try:
            result = run_command(
                [sys.executable, "-m", "openenv.cli", "validate", "--url", url],
                timeout=120,
            )
        except (OSError, subprocess.TimeoutExpired) as exc:
            # A hung or unlaunchable validator is a failed check, not a crash.
            return {
                "name": check_name,
                "passed": False,
                "url": url,
                "error": f"Could not run openenv validate: {exc}",
            }

        payload: Any
        try:
            payload = json.loads(result.stdout.strip() or "{}")
        except json.JSONDecodeError:
            payload = {"raw_stdout": result.stdout}

        # Only a JSON object can carry the "passed" flag.
        reported_pass = isinstance(payload, dict) and bool(payload.get("passed", False))

        return {
            "name": check_name,
            "passed": result.returncode == 0 and reported_pass,
            "url": url,
            "returncode": result.returncode,
            "report": payload,
            "stderr": result.stderr.strip(),
        }
    finally:
        if server.poll() is None:
            server.terminate()
            try:
                server.wait(timeout=10)
            except subprocess.TimeoutExpired:
                server.kill()
                # Reap the killed process so it does not linger as a zombie.
                server.wait()
|
|
|
|
def summarize(results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Combine individual check results into one report.

    Args:
        results: Check-result dicts; an entry counts as failed when its
            ``passed`` value is missing or falsy.

    Returns:
        A dict with the overall ``passed`` flag, the names of the failing
        checks under ``failed_checks``, and the untouched input list under
        ``checks``.
    """
    failing_names: List[str] = []
    for entry in results:
        if entry.get("passed", False):
            continue
        failing_names.append(entry["name"])

    return {
        "passed": len(failing_names) == 0,
        "failed_checks": failing_names,
        "checks": results,
    }
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line flags of this script.

    Returns:
        A namespace with two booleans: ``skip_runtime`` and ``json``.
    """
    cli = argparse.ArgumentParser(description="Run local pre-submission checks for the IDPI benchmark.")
    switches = (
        ("--skip-runtime", "Skip live runtime validation against a local server."),
        ("--json", "Emit a JSON report instead of a human summary."),
    )
    for flag, help_text in switches:
        cli.add_argument(flag, action="store_true", help=help_text)
    return cli.parse_args()
|
|
|
|
def print_human_report(report: Dict[str, Any]) -> None:
    """Print a line-oriented PASS/FAIL summary of *report* to stdout.

    Args:
        report: Dict with a ``passed`` flag and a ``checks`` list of
            check-result dicts, as built by ``summarize``.

    The first line gives the overall status, followed by one line per check.
    Detail lines follow a check when available: the missing files for
    ``required_artifacts``, and for any failing check its ``stderr`` (only
    for the ``*openenv_validate`` checks) and its ``error`` message.
    """
    overall = "PASS" if report["passed"] else "FAIL"
    print(f"[{overall}] Pre-submission check")

    for check in report["checks"]:
        name = check["name"]
        passed = check.get("passed", False)
        label = "PASS" if passed else "FAIL"
        print(f"- {label}: {name}")

        if name == "required_artifacts" and check.get("missing"):
            print(f" missing: {', '.join(check['missing'])}")

        if passed:
            continue

        if name.endswith("openenv_validate") and check.get("stderr"):
            print(f" stderr: {check['stderr']}")
        # Every check can attach an "error"; show it whichever check failed so
        # a FAIL line is never left without its explanation.
        if check.get("error"):
            print(f" error: {check['error']}")
|
|
|
|
def main() -> None:
    """Run the checks, print the report and exit with a matching status.

    The live runtime check is skipped when ``--skip-runtime`` is given.  The
    report is printed as JSON with ``--json``, otherwise as a human-readable
    summary.  The process exits with 0 when every check passed, else 1.
    """
    options = parse_args()

    # Checks run in list order; the runtime check is optional and goes last.
    check_runners = [
        file_checks,
        inference_contract_validation,
        structured_log_contract_validation,
        exported_score_payload_validation,
        local_openenv_validation,
    ]
    if not options.skip_runtime:
        check_runners.append(runtime_openenv_validation)

    outcomes = [runner() for runner in check_runners]
    report = summarize(outcomes)

    if options.json:
        print(json.dumps(report, indent=2))
    else:
        print_human_report(report)

    exit_code = 0 if report["passed"] else 1
    raise SystemExit(exit_code)


if __name__ == "__main__":
    main()
|
|