anugrah55's picture
verifier: fix SIGALRM-in-worker-thread bug that scored every well-formed submission 100/100 under uvicorn (fall back to no-timeout call when signal.signal raises). Trainer was training on a saturated reward landscape; this restores real per-submission scoring.
e7fc062 verified
"""Verifier: scores a submitted Python source against a hidden reference
function by domain-aware fuzzing, with sandboxed execution and a complexity
penalty.
Reward design (v0.3, paper-driven update):
* ``execution_reward`` in ``[0, 100]`` is the fraction of fuzz inputs whose
outputs match the reference, scaled to 100. Inputs are drawn from two
categories that are scored separately so the trainer can see *which*
regime the agent fails on (Masud et al., 2026 §P3 "reward granularity"):
- ``"edge"`` -- spec-defined must-pass cases (anti-deception, paper
§C1 of Ibrahim et al., 2024).
- ``"random"`` -- the original sampler.
* ``complexity_penalty`` in ``[0, 50]`` is a bounded log-scaled cyclomatic
complexity, or 50 on syntax error.
* ``reward_hack_penalty`` is a soft anti-hacking signal that fires when the
submission is a "constant function" (single distinct output / single
exception type) while the reference is genuinely diverse, OR the agent
attempts to import the reference module (we block this at sandbox-level
too, but we surface the attempt so the trainer can punish it).
* ``floor_penalty`` adds a hard ``-25`` floor for sub-50% submissions
(Vul-R2 style; Wen et al. 2025 in Masud et al. 2026 §3.4.2). This stops
agents from learning that emitting *any* syntactically-valid function
pays positive reward.
The headline ``total_reward`` returned in ``info`` is the *recommended*
total the env should hand back; the env is free to add a perfect-bonus on
top.
"""
from __future__ import annotations
import ast
import math
import multiprocessing as mp
import random
import signal
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional
# ----- AST complexity ------------------------------------------------------
class _CCVisitor(ast.NodeVisitor):
def __init__(self):
self.cc = 1
def _bump(self, node):
self.cc += 1
self.generic_visit(node)
visit_If = _bump
visit_For = _bump
visit_While = _bump
visit_AsyncFor = _bump
visit_ExceptHandler = _bump
visit_With = _bump
visit_IfExp = _bump
def visit_BoolOp(self, node):
self.cc += max(0, len(node.values) - 1)
self.generic_visit(node)
def calculate_complexity_penalty(code: str) -> float:
    """Return a penalty in [0, 50]: log2 of the cyclomatic complexity,
    clamped at 50, or the full 50 when the code does not parse."""
    try:
        tree = ast.parse(code)
    except SyntaxError:
        # Unparseable submissions get the maximum penalty outright.
        return 50.0
    visitor = _CCVisitor()
    visitor.visit(tree)
    # log2 keeps small functions at ~0..2 and aggressive 100-branch lookups
    # up around log2(100) ≈ 6.6; the clamp is a safety bound.
    return min(50.0, math.log2(visitor.cc))
# ----- Hardened sandbox ----------------------------------------------------
#
# Previous version exposed the real ``__builtins__`` to submitted code,
# which let an agent reward-hack with::
#
# def fibonacci(n):
# from opensleuth_env.black_box import _fibonacci
# return _fibonacci(n)
#
# We now restrict builtins to a hand-picked safe subset and hand-import the
# whitelisted helper modules so the agent doesn't need ``import`` at all.
# This is cheap defence-in-depth; the multiprocessing wall-clock timeout
# below handles infinite loops independently.
# Builtins safe to expose. Notably *no* ``__import__``, ``open``, ``exec``,
# ``eval``, ``compile``, ``input``, ``__build_class__``-via-import, etc.
_SAFE_BUILTINS_NAMES = (
"abs all any ascii bin bool bytes bytearray callable chr complex dict "
"divmod enumerate filter float format frozenset getattr hasattr hash "
"hex id int isinstance issubclass iter len list map max min next object "
"oct ord pow print property range repr reversed round set slice sorted "
"str sum tuple type zip True False None NotImplemented Ellipsis "
"ArithmeticError AssertionError AttributeError BaseException "
"BufferError BytesWarning DeprecationWarning EOFError Exception "
"FloatingPointError IndexError KeyError LookupError MemoryError "
"NameError NotImplementedError OverflowError RecursionError "
"ReferenceError RuntimeError StopAsyncIteration StopIteration "
"SyntaxError TypeError UnboundLocalError UnicodeError ValueError "
"ZeroDivisionError __build_class__"
).split()
def _make_safe_builtins() -> Dict[str, Any]:
import builtins as _b
out: Dict[str, Any] = {}
for n in _SAFE_BUILTINS_NAMES:
if hasattr(_b, n):
out[n] = getattr(_b, n)
return out
_SAFE_BUILTINS = _make_safe_builtins()
_PREIMPORTED_MODULES = ("math", "string", "itertools", "functools", "collections", "re")
def _make_safe_globals() -> Dict[str, Any]:
    """Build the restricted global namespace handed to submitted code:
    whitelisted builtins plus the pre-imported helper modules."""
    sandbox_globals: Dict[str, Any] = {
        "__builtins__": _SAFE_BUILTINS,
        "__name__": "__opensleuth_submission__",
    }
    sandbox_globals.update((name, __import__(name)) for name in _PREIMPORTED_MODULES)
    return sandbox_globals
def _exec_target_in_sandbox(code: str, target_name: str, queue: mp.Queue) -> None:
    """Child-process entry point: exec ``code`` under the restricted globals
    and report ``('ok', None)`` when ``target_name`` ends up bound to a
    callable, or ``('err', message)`` otherwise. Runs in a subprocess so the
    parent can hard-kill it on timeout."""
    try:
        sandbox_globals = _make_safe_globals()
        scope: dict = {}
        exec(code, sandbox_globals, scope)
        candidate = scope.get(target_name) or sandbox_globals.get(target_name)
        if callable(candidate):
            queue.put(("ok", None))
        else:
            queue.put(("err", f"No callable named {target_name!r} defined."))
    except Exception as e:  # noqa: BLE001
        queue.put(("err", f"{type(e).__name__}: {e}"))
def _can_define(code: str, target_name: str, timeout_s: float) -> Optional[str]:
    """Return None if the submitted code defines the target callable, else an
    error string. Uses a child process with a wall-clock timeout.

    Prefers the cheap ``fork`` start method when the platform offers it and
    no global ``spawn`` policy is in force; otherwise falls back to ``spawn``.
    (Previously this called ``mp.get_context("fork")`` unconditionally, which
    raises ``ValueError`` on platforms without fork, e.g. Windows.)
    """
    use_fork = (
        "fork" in mp.get_all_start_methods()
        and mp.get_start_method(allow_none=True) != "spawn"
    )
    ctx = mp.get_context("fork") if use_fork else mp.get_context("spawn")
    q: mp.Queue = ctx.Queue()
    p = ctx.Process(target=_exec_target_in_sandbox, args=(code, target_name, q))
    p.start()
    p.join(timeout=timeout_s)
    if p.is_alive():
        p.terminate()
        p.join(1.0)
        if p.is_alive():
            p.kill()
        return f"Definition timed out after {timeout_s}s."
    # NOTE: ``Queue.empty()`` is documented as unreliable -- the child's
    # feeder thread may not have flushed its item to the parent-side pipe by
    # the time ``join`` returns, so the old ``if q.empty()`` check could
    # misreport a *successful* sandbox run as "no result". A short blocking
    # get closes that race.
    import queue as _stdlib_queue

    try:
        status, payload = q.get(timeout=1.0)
    except _stdlib_queue.Empty:
        return "Sandbox produced no result."
    return None if status == "ok" else payload
# Per-call (per-input) sandboxing is too slow for 100 fuzz inputs, so we
# accept the trade-off of running the submitted callable in-process for
# fuzzing, but we wrap each call in a SIGALRM-based timeout and we already
# proved at definition-time that the import didn't blow up.
class _CallTimeout(Exception):
pass
def _call_with_timeout(fn: Callable, arg: Any, timeout_s: float, *, unpack: bool = False):
"""Call ``fn(arg)`` (or ``fn(*arg)`` if ``unpack``) with a wall-clock
timeout when possible.
SIGALRM only works in the main thread of the main interpreter. When
the verifier runs inside a uvicorn worker thread (FastAPI request
handler), ``signal.signal`` raises ``ValueError`` and -- prior to this
fix -- both ref and submission calls would short-circuit through the
``except Exception`` branch in ``_safe_call`` with the SAME ValueError,
which ``_outputs_equivalent`` then read as a "match", silently
awarding 100/100 to *every* submission regardless of correctness.
Fix: per-call probe. If SIGALRM isn't installable from the current
thread, fall back to a direct call with no in-thread timeout. The
definition timeout is still enforced by the multiprocessing-based
``_can_define`` ahead of fuzz scoring, so a malformed submission
that hangs at import time is caught there. For the OpenSleuth
trainer/eval use case (cooperative, not adversarial), letting a
pathological while-True submission stall a single request worker
is an acceptable trade-off relative to the current "all submissions
are perfect" failure mode.
"""
def _do_call():
if unpack:
if not isinstance(arg, tuple):
# Defensive: a multi-param target should always receive a
# tuple, but if the agent's probe input_repr happens to
# parse to a single value, treat it as a 1-tuple so we get
# a clear TypeError rather than a confusing call shape.
a = (arg,)
else:
a = arg
return fn(*a)
return fn(arg)
def _handler(signum, frame): # noqa: ARG001
raise _CallTimeout()
try:
old = signal.signal(signal.SIGALRM, _handler)
except (ValueError, OSError):
# Not in the main thread (uvicorn worker, threadpool, ...).
# SIGALRM isn't available; do the unsafe-but-correct thing.
return _do_call()
signal.setitimer(signal.ITIMER_REAL, timeout_s)
try:
return _do_call()
finally:
signal.setitimer(signal.ITIMER_REAL, 0)
signal.signal(signal.SIGALRM, old)
def _safe_call(fn: Callable, arg: Any, timeout_s: float, *, unpack: bool = False):
    """Returns (kind, value): kind in {'val', 'err', 'timeout'}.

    When ``unpack`` is True the input ``arg`` is expected to be an args
    tuple and ``fn`` is invoked as ``fn(*arg)``. This is how multi-parameter
    auto-fuzzer-driven targets are scored.
    """
    try:
        result = _call_with_timeout(fn, arg, timeout_s, unpack=unpack)
    except _CallTimeout:
        return ("timeout", f"timed out after {timeout_s}s")
    except Exception as e:  # noqa: BLE001
        return ("err", f"{type(e).__name__}: {e}")
    return ("val", result)
# ----- Public scoring ------------------------------------------------------
@dataclass
class VerificationResult:
    """Structured score for one submission, produced by ``verify_submission``."""

    # Fraction of fuzz inputs whose output matched the reference, scaled to [0, 100].
    execution_reward: float
    # Bounded log-scaled cyclomatic complexity in [0, 50]; 50 on syntax error.
    complexity_penalty: float
    # None when the target callable was defined successfully; else the error text.
    define_error: Optional[str]
    # Total matching inputs across both categories.
    matches: int
    # Total inputs scored (edge + random); clamped to 1 on the success path.
    fuzz_count: int
    # New, additive fields (do not change existing field meanings).
    # Per-category ("edge" / "random") match and input counts.
    matches_by_category: Dict[str, int] = field(default_factory=dict)
    counts_by_category: Dict[str, int] = field(default_factory=dict)
    # edge matches / edge count; 0.0 when no edge inputs were supplied.
    edge_pass_rate: float = 0.0
    # Soft anti-hacking signal: +25 for a reference-import attempt,
    # +15 for constant-output collapse.
    reward_hack_penalty: float = 0.0
    # Hard -25 floor applied when execution_reward < 50 (or on define error).
    floor_penalty: float = 0.0
def _detect_constant_collapse(
sub_outputs: List[Any], ref_outputs: List[Any], min_inputs: int = 6
) -> bool:
"""Return True if the submission collapsed to a single output / error type
while the reference produced genuine diversity. This catches the
'always return 0' / 'always raise' reward-hacking pattern.
"""
if len(sub_outputs) < min_inputs:
return False
def _signature(call_result):
kind, val = call_result
if kind == "val":
try:
return ("val", repr(val))
except Exception: # noqa: BLE001
return ("val", id(val))
if kind == "err":
return ("err", val.split(":", 1)[0])
return ("timeout", "")
sub_sig = {_signature(o) for o in sub_outputs}
ref_sig = {_signature(o) for o in ref_outputs}
return len(sub_sig) == 1 and len(ref_sig) >= 3
def _looks_like_reference_import(code: str) -> bool:
"""Static check for the most obvious reward-hacking pattern: importing
the reference function out of opensleuth_env. The sandbox already blocks
actual imports, but flagging them lets the env feed back a clear penalty
instead of a silent zero.
"""
try:
tree = ast.parse(code)
except SyntaxError:
return False
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
if alias.name.startswith("opensleuth"):
return True
elif isinstance(node, ast.ImportFrom):
if node.module and node.module.startswith("opensleuth"):
return True
return False
def verify_submission(
    submitted_code: str,
    target_function: Callable[..., Any],
    fuzz_inputs: List[Any],
    *,
    target_name: Optional[str] = None,
    define_timeout_s: float = 5.0,
    call_timeout_s: float = 1.0,
    edge_inputs: Optional[List[Any]] = None,
    unpack_args: bool = False,
) -> VerificationResult:
    """Score ``submitted_code`` against ``target_function`` over the supplied
    ``fuzz_inputs`` (random regime) and ``edge_inputs`` (must-pass regime).

    The agent is expected to define a top-level function with the same name
    as ``target_function`` (overridable via ``target_name``). On any
    definition failure the result carries ``execution_reward=0``, the error
    text in ``define_error`` and the full floor penalty.
    """
    name = target_name or target_function.__name__
    edge_inputs = list(edge_inputs or [])
    # Static reward-hack flag: import-of-reference is always a -25 hit on top
    # of whatever score the rest of the rubric assigns. Even if the sandbox
    # successfully blocks the import (it will), we want to *teach* the agent
    # not to try.
    hack_penalty = 25.0 if _looks_like_reference_import(submitted_code) else 0.0
    complexity = calculate_complexity_penalty(submitted_code)

    def _define_failure(err: str) -> VerificationResult:
        # Zero-reward result shared by the sandbox failure and the (rare)
        # in-process redefinition failure below.
        return VerificationResult(
            execution_reward=0.0,
            complexity_penalty=complexity,
            define_error=err,
            matches=0,
            fuzz_count=len(fuzz_inputs) + len(edge_inputs),
            matches_by_category={"edge": 0, "random": 0},
            counts_by_category={"edge": len(edge_inputs), "random": len(fuzz_inputs)},
            edge_pass_rate=0.0,
            reward_hack_penalty=hack_penalty,
            floor_penalty=25.0,
        )

    define_err = _can_define(submitted_code, name, define_timeout_s)
    if define_err is not None:
        return _define_failure(define_err)
    # Re-define in-process for fast fuzzing, under the same restricted
    # globals. The sandbox run above makes a crash here unlikely but not
    # impossible: a spawn-started child has its own hash seed and process
    # state, so code that executed cleanly there can still raise (or fail
    # to bind the target) in this process. Previously an exception here
    # propagated out of the verifier; route it through the define-error
    # path instead.
    safe_globals = _make_safe_globals()
    local_scope: dict = {}
    try:
        exec(submitted_code, safe_globals, local_scope)
    except Exception as e:  # noqa: BLE001
        return _define_failure(f"In-process redefinition failed: {type(e).__name__}: {e}")
    submitted_fn = local_scope.get(name) or safe_globals.get(name)
    if not callable(submitted_fn):
        return _define_failure(f"No callable named {name!r} defined.")
    matches_by_cat: Dict[str, int] = {"edge": 0, "random": 0}
    counts_by_cat: Dict[str, int] = {"edge": len(edge_inputs), "random": len(fuzz_inputs)}
    sub_results: List[Any] = []
    ref_results: List[Any] = []

    def _score(inputs: List[Any], category: str) -> None:
        # Run reference and submission on identical inputs; equivalence is
        # value ==, same exception type, or both timing out.
        for inp in inputs:
            ref = _safe_call(target_function, inp, call_timeout_s, unpack=unpack_args)
            sub = _safe_call(submitted_fn, inp, call_timeout_s, unpack=unpack_args)
            sub_results.append(sub)
            ref_results.append(ref)
            if _outputs_equivalent(ref, sub):
                matches_by_cat[category] += 1

    _score(edge_inputs, "edge")
    _score(fuzz_inputs, "random")
    matches = matches_by_cat["edge"] + matches_by_cat["random"]
    # Clamp to 1 so an empty input set cannot divide by zero.
    fuzz_count = len(fuzz_inputs) + len(edge_inputs) or 1
    exec_reward = 100.0 * (matches / fuzz_count)
    edge_pass_rate = (
        matches_by_cat["edge"] / counts_by_cat["edge"] if counts_by_cat["edge"] else 0.0
    )
    # Anti-hacking: constant collapse penalty.
    if _detect_constant_collapse(sub_results, ref_results):
        hack_penalty += 15.0
    # Hard floor for sub-50% match rate. Vul-R2 style: a wrong patch deserves
    # a clearly negative signal so the agent doesn't learn that 'any defined
    # function' pays out via the small complexity-bonus / step structure.
    floor_penalty = 25.0 if exec_reward < 50.0 else 0.0
    return VerificationResult(
        execution_reward=exec_reward,
        complexity_penalty=complexity,
        define_error=None,
        matches=matches,
        fuzz_count=fuzz_count,
        matches_by_category=matches_by_cat,
        counts_by_category=counts_by_cat,
        edge_pass_rate=edge_pass_rate,
        reward_hack_penalty=hack_penalty,
        floor_penalty=floor_penalty,
    )
def _outputs_equivalent(ref, sub) -> bool:
"""Ref and sub are (kind, value) tuples from `_safe_call`. They count as
equivalent if both raised the same exception type, or both returned values
that are == equal."""
rkind, rval = ref
skind, sval = sub
if rkind == "val" and skind == "val":
try:
return rval == sval
except Exception: # noqa: BLE001
return False
if rkind == "err" and skind == "err":
return rval.split(":", 1)[0] == sval.split(":", 1)[0]
if rkind == "timeout" and skind == "timeout":
return True
return False
def generate_fuzz_inputs(
    spec, count: int = 100, seed: Optional[int] = None
) -> List[Any]:
    """Public helper: pull ``count`` fuzz inputs from a FunctionSpec, optionally
    seeded for reproducibility."""
    return spec.fuzzer(random.Random(seed), count)
def get_edge_inputs(spec) -> List[Any]:
    """Return the spec's must-pass edge inputs (empty list if the spec
    predates the v0.3 schema)."""
    edge_cases = getattr(spec, "edge_cases", None)
    return list(edge_cases) if edge_cases else []