# Prompt-Injection-RL-environment / pre_submission_check.py
# KChad's picture
# Sanitize result payload score fields for deep validation
# 5dc1f01
from __future__ import annotations
import argparse
import importlib
import json
import os
import site
import socket
import subprocess
import sys
import time
import urllib.error
import urllib.request
from contextlib import redirect_stdout
from io import StringIO
from pathlib import Path
from typing import Any, Dict, List
from unittest.mock import patch
# Absolute directory that contains this script; every check resolves paths
# and runs subprocesses relative to it.
REPO_ROOT = Path(__file__).resolve().parent

# Files that must exist directly under REPO_ROOT for the artifact check to pass.
REQUIRED_FILES = ["inference.py", "openenv.yaml", "Dockerfile", "uv.lock"]
def build_pythonpath() -> str:
    """Assemble the PYTHONPATH value used for child processes.

    The result joins, with ``os.pathsep`` and in this order: the repository
    root, its ``src`` and ``.deps`` sub-directories (each only if it exists),
    the user site-packages directory (if it can be determined and exists),
    and finally any non-blank ``PYTHONPATH`` already set in the environment.

    Returns:
        The joined path string; empty when none of the entries apply.
    """
    entries: List[str] = [
        str(root)
        for root in (REPO_ROOT, REPO_ROOT / "src", REPO_ROOT / ".deps")
        if root.exists()
    ]

    try:
        user_site_dir = Path(site.getusersitepackages())
        if user_site_dir.exists():
            entries.append(str(user_site_dir))
    except Exception:
        # Best effort: the user site directory is optional, so any failure
        # while looking it up simply leaves it out of the path.
        pass

    inherited = os.environ.get("PYTHONPATH", "").strip()
    if inherited:
        entries.append(inherited)

    return os.pathsep.join(entries)
def command_env() -> Dict[str, str]:
    """Return a copy of the current environment with PYTHONPATH replaced.

    The copy is independent of ``os.environ``; its ``PYTHONPATH`` entry is
    set to the value computed by ``build_pythonpath()``.
    """
    return {**os.environ, "PYTHONPATH": build_pythonpath()}
def run_command(args: List[str], timeout: int = 120, check: bool = False) -> subprocess.CompletedProcess[str]:
    """Run ``args`` as a subprocess from the repository root.

    The child runs with the environment from ``command_env()``; stdout and
    stderr are captured and decoded as text.

    Args:
        args: Program and arguments, passed to ``subprocess.run`` as a list.
        timeout: Seconds before ``subprocess.run`` gives up on the child.
        check: Forwarded to ``subprocess.run``; when true a non-zero exit
            status raises ``subprocess.CalledProcessError``.

    Returns:
        The completed process with captured text output.

    Raises:
        subprocess.TimeoutExpired: If the child exceeds ``timeout``.
        subprocess.CalledProcessError: If ``check`` is true and the child fails.
    """
    completed = subprocess.run(
        args,
        cwd=REPO_ROOT,
        env=command_env(),
        capture_output=True,
        text=True,
        timeout=timeout,
        check=check,
    )
    return completed
def file_checks() -> Dict[str, Any]:
    """Check that every entry of ``REQUIRED_FILES`` exists under ``REPO_ROOT``.

    Returns:
        A check result named ``"required_artifacts"`` whose ``"missing"``
        list holds the absent entries and whose ``"passed"`` flag is true
        only when that list is empty.
    """
    absent: List[str] = []
    for relative_name in REQUIRED_FILES:
        if not (REPO_ROOT / relative_name).exists():
            absent.append(relative_name)

    return {
        "name": "required_artifacts",
        "passed": len(absent) == 0,
        "missing": absent,
    }
def inference_contract_validation() -> Dict[str, Any]:
    """Verify the defaults returned by ``inference.validate_submission_env()``.

    The ``inference`` module is imported, its ``OpenAI`` attribute is swapped
    for a stub that records constructor keyword arguments, and the process
    environment is replaced by one containing only ``HF_TOKEN``. The check
    passes when the returned config and the recorded client kwargs match the
    expected base URL, model name and token, and ``LOCAL_IMAGE_NAME`` is None.

    Returns:
        A check result named ``"inference_contract"``. On import or call
        failure it carries an ``"error"`` string; otherwise it carries the
        recorded ``"client_kwargs"``.
    """
    check_name = "inference_contract"
    expected_base_url = "https://generativelanguage.googleapis.com/v1beta/openai/"

    try:
        inference = importlib.import_module("inference")
    except Exception as exc:
        return {
            "name": check_name,
            "passed": False,
            "error": f"Failed to import inference.py: {exc}",
        }

    class DummyOpenAI:
        # Keyword arguments of the most recent construction, shared on the
        # class so they can be read back after the patched call returns.
        last_kwargs: Dict[str, Any] | None = None

        def __init__(self, **kwargs):
            DummyOpenAI.last_kwargs = kwargs

    try:
        # clear=True removes every other variable, so only HF_TOKEN is visible
        # to the code under test while the patch is active.
        with patch.object(inference, "OpenAI", DummyOpenAI), patch.dict(
            "os.environ", {"HF_TOKEN": "test-token"}, clear=True
        ):
            config = inference.validate_submission_env()
    except Exception as exc:
        return {
            "name": check_name,
            "passed": False,
            "error": str(exc),
        }

    expected_client_kwargs = {
        "base_url": expected_base_url,
        "api_key": "test-token",
    }
    conditions = (
        config.get("API_BASE_URL") == expected_base_url,
        config.get("MODEL_NAME") == "gemini-2.5-flash",
        config.get("HF_TOKEN") == "test-token",
        config.get("LOCAL_IMAGE_NAME") is None,
        DummyOpenAI.last_kwargs == expected_client_kwargs,
    )

    return {
        "name": check_name,
        "passed": all(conditions),
        "client_kwargs": DummyOpenAI.last_kwargs,
    }
def structured_log_contract_validation() -> Dict[str, Any]:
    """Check the exact stdout lines emitted by the baseline log helpers.

    ``baseline.run_baseline`` is imported and its ``log_start``, ``log_step``
    and ``log_end`` functions are called once each with fixed arguments while
    stdout is captured. The non-blank, stripped output lines must equal the
    three expected ``[START]`` / ``[STEP]`` / ``[END]`` lines.

    Returns:
        A check result named ``"structured_log_contract"``. On import failure
        it carries an ``"error"`` string; otherwise ``"observed"`` holds the
        captured lines.
    """
    check_name = "structured_log_contract"

    try:
        baseline_runner = importlib.import_module("baseline.run_baseline")
    except Exception as exc:
        return {
            "name": check_name,
            "passed": False,
            "error": f"Failed to import baseline.run_baseline: {exc}",
        }

    captured = StringIO()
    with redirect_stdout(captured):
        baseline_runner.log_start("task-1", "env-1", "model-1")
        baseline_runner.log_step(1, "ReadDoc", 0.5, False, None)
        baseline_runner.log_end(True, 1, 0.5001, [0.5])

    observed: List[str] = []
    for raw_line in captured.getvalue().splitlines():
        stripped = raw_line.strip()
        if stripped:
            observed.append(stripped)

    expected_lines = [
        "[START] task=task-1 env=env-1 model=model-1",
        "[STEP] step=1 action=ReadDoc reward=0.50 done=false error=null",
        "[END] success=true steps=1 score=0.5001 rewards=0.50",
    ]

    return {
        "name": check_name,
        "passed": observed == expected_lines,
        "observed": observed,
    }
def exported_score_payload_validation() -> Dict[str, Any]:
    """Check that exported numeric "score" fields lie strictly between 0 and 1.

    For each difficulty (``easy``, ``medium``, ``hard``) the environment is
    reset with seed 42, a short fixed action sequence is played, and the
    result payload, grader inputs and evaluation facts are scanned
    recursively. Any numeric value stored under a key whose name contains
    "score" (case-insensitive) and that is not in the open interval (0, 1)
    is recorded as a violation.

    Returns:
        A check result named ``"exported_score_payload"``. On import,
        construction or rollout failure it carries an ``"error"`` string;
        otherwise ``"violations"`` lists up to 50 offending ``path=value``
        entries and ``"passed"`` is true only when there are none.
    """
    check_name = "exported_score_payload"

    try:
        from env.environment import IDPIExfilEnv
        from env.models import Action
    except Exception as exc:
        return {
            "name": check_name,
            "passed": False,
            "error": f"Failed to import environment modules: {exc}",
        }

    violations: List[str] = []

    def scan(node: Any, path: str = "") -> None:
        """Recursively collect out-of-range score values found under ``node``."""
        if isinstance(node, dict):
            for key, value in node.items():
                child_path = f"{path}.{key}" if path else str(key)
                # NOTE: bool is a subclass of int, so a boolean stored under a
                # "score" key is treated as numeric here (True -> 1.0).
                if (
                    "score" in str(key).lower()
                    and isinstance(value, (int, float))
                    and not (0.0 < float(value) < 1.0)
                ):
                    violations.append(f"{child_path}={value}")
                scan(value, child_path)
        elif isinstance(node, (list, tuple)):
            # Tuples are walked like lists so nested scores are not missed.
            for index, value in enumerate(node):
                scan(value, f"{path}[{index}]")

    try:
        # Constructed inside the try so a constructor failure is reported as a
        # failed check instead of aborting the whole script.
        env = IDPIExfilEnv()
        for difficulty in ("easy", "medium", "hard"):
            env.reset(difficulty=difficulty, seed=42)
            env.step(Action(type="ReadDoc"))
            env.step(Action(type="FlagInjection", reason="suspicious instruction"))
            if difficulty == "hard":
                env.step(Action(type="WriteDB", content="clean internal summary"))
            result = env.get_result()
            # Use model_dump() when the result object offers it, else dict().
            result_payload = result.model_dump() if hasattr(result, "model_dump") else result.dict()
            scan(result_payload, f"{difficulty}.result")
            scan(env.get_grader_inputs(), f"{difficulty}.grader_inputs")
            scan(env.get_evaluation_facts(), f"{difficulty}.evaluation_facts")
    except Exception as exc:
        return {
            "name": check_name,
            "passed": False,
            "error": str(exc),
        }

    return {
        "name": check_name,
        "passed": not violations,
        "violations": violations[:50],
    }
def local_openenv_validation() -> Dict[str, Any]:
    """Run ``python -m openenv.cli validate --json`` and report its outcome.

    Returns:
        A check result named ``"local_openenv_validate"`` with the process
        ``"returncode"``, the parsed ``"report"`` and stripped ``"stderr"``.
        ``"passed"`` requires a zero exit status and a truthy ``"passed"``
        field in the report. If stdout is not a JSON object, the report is
        ``{"raw_stdout": <stdout>}``. If the command times out, the check
        fails with an ``"error"`` string and a ``None`` return code.
    """
    check_name = "local_openenv_validate"

    try:
        result = run_command([sys.executable, "-m", "openenv.cli", "validate", "--json"])
    except subprocess.TimeoutExpired as exc:
        # Report the timeout as a failed check so the remaining checks and the
        # final summary are still produced.
        return {
            "name": check_name,
            "passed": False,
            "returncode": None,
            "report": {},
            "stderr": "",
            "error": f"openenv validate timed out after {exc.timeout} seconds.",
        }

    try:
        parsed = json.loads(result.stdout.strip() or "{}")
    except json.JSONDecodeError:
        parsed = None
    # Valid JSON that is not an object (list, null, number, ...) cannot carry a
    # "passed" field, so it is preserved as raw output like unparseable text.
    payload: Dict[str, Any] = parsed if isinstance(parsed, dict) else {"raw_stdout": result.stdout}

    return {
        "name": check_name,
        "passed": result.returncode == 0 and bool(payload.get("passed", False)),
        "returncode": result.returncode,
        "report": payload,
        "stderr": result.stderr.strip(),
    }
def pick_port() -> int:
    """Return a TCP port number that the OS reported as free on 127.0.0.1.

    A socket is bound to port 0 so the OS assigns a port, the assigned number
    is read back, and the socket is closed before returning.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.bind(("127.0.0.1", 0))
        _host, assigned_port = probe.getsockname()
    return int(assigned_port)
def wait_for_health(url: str, timeout_s: float = 20.0, poll_interval_s: float = 0.5) -> bool:
    """Poll ``{url}/health`` until it reports a healthy status or time runs out.

    A probe succeeds when the response body is a JSON object whose
    ``"status"`` is ``"healthy"`` or ``"ok"``. Every unsuccessful probe
    (connection error, HTTP error, invalid JSON, or a not-yet-healthy status)
    is followed by a pause of ``poll_interval_s`` seconds.

    Args:
        url: Base URL of the server, without the ``/health`` suffix.
        timeout_s: Overall time budget in seconds for the polling loop.
        poll_interval_s: Seconds to sleep after each unsuccessful probe.

    Returns:
        True as soon as a probe succeeds, False once the budget is exhausted.
    """
    # Local import: http.client is not imported at module level in this file.
    from http.client import HTTPException

    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        payload: Any = None
        try:
            with urllib.request.urlopen(f"{url}/health", timeout=2) as response:
                payload = json.loads(response.read().decode("utf-8"))
        except (OSError, HTTPException, ValueError):
            # OSError covers URLError/HTTPError, timeouts and connection
            # resets; ValueError covers bad URLs, bad JSON and decode errors.
            payload = None

        # Tuple membership avoids a TypeError if "status" is unhashable.
        if isinstance(payload, dict) and payload.get("status") in ("healthy", "ok"):
            return True

        # Sleep after every failed probe, including a 200 response that is not
        # healthy yet, so the server is never hammered in a tight loop.
        time.sleep(poll_interval_s)
    return False
def runtime_openenv_validation() -> Dict[str, Any]:
    """Start the app with uvicorn and validate it through ``openenv.cli``.

    A uvicorn process serving ``server.app:app`` is launched on a free
    loopback port with its output discarded. Once the health endpoint
    responds, ``python -m openenv.cli validate --url <url>`` is run against
    it. The server process is always stopped before returning.

    Returns:
        A check result named ``"runtime_openenv_validate"`` that always holds
        the ``"url"``. If the server never becomes healthy, or the validator
        times out, it carries an ``"error"`` string. Otherwise it carries the
        validator ``"returncode"``, parsed ``"report"`` and stripped
        ``"stderr"``; ``"passed"`` requires a zero exit status and a truthy
        ``"passed"`` field in the report.
    """
    check_name = "runtime_openenv_validate"
    port = pick_port()
    url = f"http://127.0.0.1:{port}"
    server = subprocess.Popen(
        [
            sys.executable,
            "-m",
            "uvicorn",
            "server.app:app",
            "--host",
            "127.0.0.1",
            "--port",
            str(port),
        ],
        cwd=REPO_ROOT,
        env=command_env(),
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
        text=True,
    )
    try:
        if not wait_for_health(url):
            return {
                "name": check_name,
                "passed": False,
                "url": url,
                "error": "Server did not become healthy in time.",
                # Server output is discarded, so the exit code (None while the
                # process is still running) is the only hint of an early crash.
                "server_returncode": server.poll(),
            }

        try:
            result = run_command(
                [sys.executable, "-m", "openenv.cli", "validate", "--url", url],
                timeout=120,
            )
        except subprocess.TimeoutExpired as exc:
            # Report the timeout as a failed check rather than letting it
            # abort the whole script.
            return {
                "name": check_name,
                "passed": False,
                "url": url,
                "returncode": None,
                "report": {},
                "stderr": "",
                "error": f"openenv validate timed out after {exc.timeout} seconds.",
            }

        try:
            parsed = json.loads(result.stdout.strip() or "{}")
        except json.JSONDecodeError:
            parsed = None
        # Valid JSON that is not an object cannot carry a "passed" field, so it
        # is preserved as raw output like unparseable text.
        payload: Dict[str, Any] = parsed if isinstance(parsed, dict) else {"raw_stdout": result.stdout}

        return {
            "name": check_name,
            "passed": result.returncode == 0 and bool(payload.get("passed", False)),
            "url": url,
            "returncode": result.returncode,
            "report": payload,
            "stderr": result.stderr.strip(),
        }
    finally:
        if server.poll() is None:
            server.terminate()
            try:
                server.wait(timeout=10)
            except subprocess.TimeoutExpired:
                server.kill()
                # Reap the killed process so it does not linger as a zombie.
                server.wait()
def summarize(results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Fold individual check results into one overall report.

    Args:
        results: Check result mappings; each must have a ``"name"`` and may
            have a ``"passed"`` flag (treated as false when absent).

    Returns:
        A mapping with ``"passed"`` (true only when no check failed),
        ``"failed_checks"`` (names of failing checks, in input order) and
        ``"checks"`` (the input list itself).
    """
    failing_names: List[str] = []
    for entry in results:
        if entry.get("passed", False):
            continue
        failing_names.append(entry["name"])

    return {
        "passed": len(failing_names) == 0,
        "failed_checks": failing_names,
        "checks": results,
    }
def parse_args() -> argparse.Namespace:
    """Parse the command-line flags of the pre-submission checker.

    Returns:
        A namespace with two boolean attributes, ``skip_runtime`` and
        ``json``, both false unless the matching flag is given.
    """
    cli = argparse.ArgumentParser(description="Run local pre-submission checks for the IDPI benchmark.")
    switches = (
        ("--skip-runtime", "Skip live runtime validation against a local server."),
        ("--json", "Emit a JSON report instead of a human summary."),
    )
    for flag, help_text in switches:
        cli.add_argument(flag, action="store_true", help=help_text)
    return cli.parse_args()
def print_human_report(report: Dict[str, Any]) -> None:
    """Print a line-oriented PASS/FAIL summary of ``report`` to stdout.

    The first line gives the overall status, followed by one line per check.
    Extra detail lines are printed beneath a check when available: missing
    files for ``required_artifacts``, captured stderr for a failed
    ``*openenv_validate`` check, and the ``error`` message of any check that
    has one, so import and runtime failures are visible without ``--json``.

    Args:
        report: Mapping with a boolean ``"passed"`` and a ``"checks"`` list
            of check result mappings, each holding at least ``"name"``.
    """
    overall = "PASS" if report["passed"] else "FAIL"
    print(f"[{overall}] Pre-submission check")

    for check in report["checks"]:
        name = check["name"]
        passed = bool(check.get("passed", False))
        print(f"- {'PASS' if passed else 'FAIL'}: {name}")

        if name == "required_artifacts" and check.get("missing"):
            print(f" missing: {', '.join(check['missing'])}")
        if not passed and name.endswith("openenv_validate") and check.get("stderr"):
            print(f" stderr: {check['stderr']}")
        # Shown for every check, not only the openenv ones, so a failed
        # import or rollout explains itself in the human-readable report.
        if check.get("error"):
            print(f" error: {check['error']}")
def main() -> None:
    """Run all checks, print the report and exit with a matching status.

    The static checks always run, in a fixed order; the live runtime check
    runs last unless ``--skip-runtime`` was given. The report is printed as
    JSON with ``--json`` and as a human summary otherwise.

    Raises:
        SystemExit: Always; code 0 when every check passed, 1 otherwise.
    """
    options = parse_args()

    checks = [
        file_checks,
        inference_contract_validation,
        structured_log_contract_validation,
        exported_score_payload_validation,
        local_openenv_validation,
    ]
    if not options.skip_runtime:
        checks.append(runtime_openenv_validation)

    # Checks are invoked one after another in list order.
    outcomes = [check() for check in checks]
    report = summarize(outcomes)

    if options.json:
        print(json.dumps(report, indent=2))
    else:
        print_human_report(report)

    exit_code = 0 if report["passed"] else 1
    raise SystemExit(exit_code)
# Run the checks only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()