File size: 16,848 Bytes
d3cd20c
 
 
 
536dda7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d3cd20c
 
 
 
ee14542
 
d3cd20c
 
 
536dda7
 
d3cd20c
ee14542
d3cd20c
 
 
 
ee14542
d3cd20c
 
 
 
ee14542
d3cd20c
 
 
 
 
 
 
 
 
 
 
ee14542
 
d3cd20c
 
 
ee14542
 
 
d3cd20c
 
 
 
 
 
 
 
536dda7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee14542
 
d3cd20c
 
ee14542
536dda7
d3cd20c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77e65fb
e7fc062
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77e65fb
 
 
 
 
 
e7fc062
 
 
 
d3cd20c
e7fc062
 
 
 
 
 
 
 
 
 
 
 
 
 
d3cd20c
 
 
 
 
77e65fb
 
 
 
 
 
 
d3cd20c
77e65fb
d3cd20c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
536dda7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d3cd20c
 
 
 
77e65fb
d3cd20c
 
 
 
 
536dda7
77e65fb
d3cd20c
536dda7
 
 
 
d3cd20c
536dda7
 
 
 
 
 
 
d3cd20c
 
 
 
536dda7
d3cd20c
 
 
 
 
536dda7
 
 
 
 
 
d3cd20c
 
 
536dda7
 
 
d3cd20c
 
 
 
536dda7
 
d3cd20c
536dda7
 
 
 
 
77e65fb
 
536dda7
 
 
 
 
 
 
 
 
 
d3cd20c
536dda7
 
 
 
 
 
 
 
 
 
 
 
 
d3cd20c
 
 
 
 
 
536dda7
 
 
 
 
d3cd20c
 
 
 
 
 
 
 
 
 
ee14542
d3cd20c
 
 
 
 
 
 
 
 
 
 
 
 
536dda7
d3cd20c
 
 
536dda7
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
"""Verifier: scores a submitted Python source against a hidden reference
function by domain-aware fuzzing, with sandboxed execution and a complexity
penalty.

Reward design (v0.3, paper-driven update):

* ``execution_reward`` in ``[0, 100]`` is the fraction of fuzz inputs whose
  outputs match the reference, scaled to 100. Inputs are drawn from two
  categories that are scored separately so the trainer can see *which*
  regime the agent fails on (Masud et al., 2026 §P3 "reward granularity"):

    - ``"edge"``   -- spec-defined must-pass cases (anti-deception, paper
      §C1 of Ibrahim et al., 2024).
    - ``"random"`` -- the original sampler.

* ``complexity_penalty`` in ``[0, 50]`` is a bounded log-scaled cyclomatic
  complexity, or 50 on syntax error.

* ``reward_hack_penalty`` is a soft anti-hacking signal that fires when the
  submission is a "constant function" (single distinct output / single
  exception type) while the reference is genuinely diverse, OR the agent
  attempts to import the reference module (we block this at sandbox-level
  too, but we surface the attempt so the trainer can punish it).

* ``floor_penalty`` adds a hard ``-25`` floor for sub-50% submissions
  (Vul-R2 style; Wen et al. 2025 in Masud et al. 2026 §3.4.2). This stops
  agents from learning that emitting *any* syntactically-valid function
  pays positive reward.

The headline ``total_reward`` returned in ``info`` is the *recommended*
total the env should hand back; the env is free to add a perfect-bonus on
top.
"""

from __future__ import annotations

import ast
import math
import multiprocessing as mp
import random
import signal
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional


# ----- AST complexity ------------------------------------------------------


class _CCVisitor(ast.NodeVisitor):
    """AST walker that accumulates a cyclomatic-complexity-style count.

    ``cc`` starts at 1 and grows by one for every branching construct
    visited, plus one for each extra operand of a boolean operator chain.
    """

    def __init__(self):
        # Straight-line code counts as a single path.
        self.cc = 1

    def _bump(self, node):
        """Record one decision point, then keep walking the children."""
        self.cc = self.cc + 1
        self.generic_visit(node)

    # Node types that each contribute exactly one decision point.
    visit_If = visit_IfExp = _bump
    visit_For = visit_AsyncFor = visit_While = _bump
    visit_ExceptHandler = _bump
    visit_With = _bump

    def visit_BoolOp(self, node):
        """Count the extra operands of an ``and`` / ``or`` chain."""
        extra_operands = len(node.values) - 1
        if extra_operands > 0:
            self.cc += extra_operands
        self.generic_visit(node)


def calculate_complexity_penalty(code: str) -> float:
    """Return the complexity penalty for ``code``.

    The penalty is ``log2`` of the decision-point count gathered by
    ``_CCVisitor``, capped at 50. Source that cannot be parsed or analysed
    is charged the maximum of 50.

    Args:
        code: Submitted Python source text.

    Returns:
        A float in ``[0, 50]``.
    """
    try:
        tree = ast.parse(code)
        counter = _CCVisitor()
        counter.visit(tree)
    except (SyntaxError, ValueError, RecursionError):
        # ValueError: source containing NUL bytes on Python < 3.12.
        # RecursionError: pathologically deep nesting, raised either by the
        # parser or by the recursive visitor. Both are treated exactly like
        # a syntax error so untrusted input can never crash the verifier.
        return 50.0
    # log2 keeps small functions at ~0..2 and aggressive 100-branch lookups
    # up around log2(100) ≈ 6.6, then we clamp.
    return min(50.0, math.log2(counter.cc))


# ----- Hardened sandbox ----------------------------------------------------
#
# Previous version exposed the real ``__builtins__`` to submitted code,
# which let an agent reward-hack with::
#
#     def fibonacci(n):
#         from opensleuth_env.black_box import _fibonacci
#         return _fibonacci(n)
#
# We now restrict builtins to a hand-picked safe subset and hand-import the
# whitelisted helper modules so the agent doesn't need ``import`` at all.
# This is cheap defence-in-depth; the multiprocessing wall-clock timeout
# below handles infinite loops independently.


# Names of the builtins exposed to submitted code. Notably *no*
# ``__import__``, ``open``, ``exec``, ``eval``, ``compile`` or ``input``.
# ``__build_class__`` is whitelisted because ``class`` statements look it up
# in ``__builtins__``. Names missing from the running interpreter are
# skipped by ``_make_safe_builtins``.
# NOTE(review): ``getattr``, ``type`` and ``object`` remain reachable, so
# attribute-walking escapes may still be possible -- confirm this matches
# the intended threat model.
_SAFE_BUILTINS_NAMES = (
    "abs all any ascii bin bool bytes bytearray callable chr complex dict "
    "divmod enumerate filter float format frozenset getattr hasattr hash "
    "hex id int isinstance issubclass iter len list map max min next object "
    "oct ord pow print property range repr reversed round set slice sorted "
    "str sum tuple type zip True False None NotImplemented Ellipsis "
    "ArithmeticError AssertionError AttributeError BaseException "
    "BufferError BytesWarning DeprecationWarning EOFError Exception "
    "FloatingPointError IndexError KeyError LookupError MemoryError "
    "NameError NotImplementedError OverflowError RecursionError "
    "ReferenceError RuntimeError StopAsyncIteration StopIteration "
    "SyntaxError TypeError UnboundLocalError UnicodeError ValueError "
    "ZeroDivisionError __build_class__"
).split()


def _make_safe_builtins() -> Dict[str, Any]:
    """Resolve the whitelisted builtin names to their actual objects.

    Names in ``_SAFE_BUILTINS_NAMES`` that the running interpreter does not
    provide are silently left out.
    """
    import builtins as builtins_mod

    return {
        name: getattr(builtins_mod, name)
        for name in _SAFE_BUILTINS_NAMES
        if hasattr(builtins_mod, name)
    }


# Whitelisted builtins, resolved once at import time.
_SAFE_BUILTINS = _make_safe_builtins()
# Modules bound by name into the submission's globals so that submitted code
# can use them without an ``import`` statement.
_PREIMPORTED_MODULES = ("math", "string", "itertools", "functools", "collections", "re")


def _make_safe_globals() -> Dict[str, Any]:
    """Build a fresh globals namespace for executing submitted code.

    The namespace carries the restricted builtins, a synthetic ``__name__``
    and the pre-imported helper modules listed in ``_PREIMPORTED_MODULES``.

    Returns:
        A new dict, safe to pass as the globals argument of ``exec``.
    """
    namespace: Dict[str, Any] = {
        # Hand out a *copy* of the whitelist: submitted code can reach and
        # mutate ``__builtins__``, and a shared dict would let one
        # submission poison every later one scored in this process.
        "__builtins__": dict(_SAFE_BUILTINS),
        "__name__": "__opensleuth_submission__",
    }
    for mod_name in _PREIMPORTED_MODULES:
        namespace[mod_name] = __import__(mod_name)
    return namespace


def _exec_target_in_sandbox(code: str, target_name: str, queue: mp.Queue) -> None:
    """Execute ``code`` in the restricted namespace and report the outcome.

    Runs inside a child process so the parent can hard-kill it on timeout.
    Exactly one ``(status, payload)`` tuple is put on ``queue``:
    ``("ok", None)`` when a callable named ``target_name`` exists after
    execution, otherwise ``("err", message)``.

    Args:
        code: Submitted Python source text.
        target_name: Name the submission is expected to define.
        queue: Queue used to send the result back to the parent.
    """
    try:
        namespace = _make_safe_globals()
        # Use ONE namespace. With separate globals and locals, ``exec`` runs
        # the code as if it were a class body: top-level functions land in
        # the locals dict but resolve free names through globals, so
        # module-level code calling a recursive or helper-using function
        # dies with NameError.
        exec(code, namespace)
        fn = namespace.get(target_name)
        if not callable(fn):
            queue.put(("err", f"No callable named {target_name!r} defined."))
            return
        queue.put(("ok", None))
    except Exception as e:  # noqa: BLE001
        queue.put(("err", f"{type(e).__name__}: {e}"))


def _can_define(code: str, target_name: str, timeout_s: float) -> Optional[str]:
    """Check in a child process that ``code`` defines ``target_name``.

    Args:
        code: Submitted Python source text.
        target_name: Name of the callable the submission must define.
        timeout_s: Wall-clock budget, in seconds, for executing the code.

    Returns:
        ``None`` if the submitted code defines the target callable, else an
        error string describing the failure (error, timeout, or no result).
    """
    # Prefer "fork" (cheap start-up) unless the application explicitly chose
    # "spawn" or the platform does not offer fork at all; asking for an
    # unavailable start method raises ValueError.
    use_spawn = (
        mp.get_start_method(allow_none=True) == "spawn"
        or "fork" not in mp.get_all_start_methods()
    )
    ctx = mp.get_context("spawn" if use_spawn else "fork")
    q: mp.Queue = ctx.Queue()
    p = ctx.Process(target=_exec_target_in_sandbox, args=(code, target_name, q))
    p.start()
    p.join(timeout=timeout_s)
    if p.is_alive():
        # Escalate: polite SIGTERM first, SIGKILL if the child ignores it.
        p.terminate()
        p.join(1.0)
        if p.is_alive():
            p.kill()
            # Reap the killed child so it does not linger as a zombie.
            p.join(1.0)
        return f"Definition timed out after {timeout_s}s."
    if q.empty():
        # The child exited without reporting, e.g. it was killed or exited
        # through something other than an ``Exception``.
        return "Sandbox produced no result."
    status, payload = q.get()
    return None if status == "ok" else payload


# Per-call (per-input) process sandboxing is too slow for 100 fuzz inputs,
# so we accept the trade-off of running the submitted callable in-process
# for fuzzing. Each call is wrapped in a SIGALRM-based timeout where one can
# be installed (main thread only), and ``_can_define`` has already shown in
# a child process that executing the module body does not hang or raise.

class _CallTimeout(Exception):
    """Raised by the SIGALRM handler when a timed call overruns its budget."""


def _call_with_timeout(fn: Callable, arg: Any, timeout_s: float, *, unpack: bool = False):
    """Call ``fn(arg)`` (or ``fn(*arg)`` if ``unpack``) with a wall-clock
    timeout when possible.

    SIGALRM only works in the main thread of the main interpreter, and does
    not exist at all on some platforms (e.g. Windows). When the verifier
    runs inside a uvicorn worker thread (FastAPI request handler),
    ``signal.signal`` raises ``ValueError``; where SIGALRM is missing the
    attribute lookup raises ``AttributeError``. If either error escaped,
    both ref and submission calls would short-circuit through the
    ``except Exception`` branch in ``_safe_call`` with the SAME error,
    which ``_outputs_equivalent`` then reads as a "match", silently
    awarding 100/100 to *every* submission regardless of correctness.

    Therefore: per-call probe. If SIGALRM isn't installable from the
    current thread/platform, fall back to a direct call with no in-thread
    timeout. The definition timeout is still enforced by the
    multiprocessing-based ``_can_define`` ahead of fuzz scoring, so a
    malformed submission that hangs at import time is caught there. For the
    OpenSleuth trainer/eval use case (cooperative, not adversarial),
    letting a pathological while-True submission stall a single request
    worker is an acceptable trade-off relative to the "all submissions are
    perfect" failure mode.

    Args:
        fn: Callable to invoke.
        arg: Single argument, or an args tuple when ``unpack`` is true.
        timeout_s: Wall-clock budget in seconds for the call.
        unpack: Invoke as ``fn(*arg)`` instead of ``fn(arg)``.

    Returns:
        Whatever ``fn`` returns.

    Raises:
        _CallTimeout: If the timer fires before ``fn`` returns.
        Exception: Anything ``fn`` itself raises propagates unchanged.
    """
    def _do_call():
        if not unpack:
            return fn(arg)
        # Defensive: a multi-param target should always receive a tuple,
        # but if the agent's probe input_repr happens to parse to a single
        # value, treat it as a 1-tuple so we get a clear TypeError rather
        # than a confusing call shape.
        args = arg if isinstance(arg, tuple) else (arg,)
        return fn(*args)

    def _handler(signum, frame):  # noqa: ARG001
        raise _CallTimeout()

    try:
        old = signal.signal(signal.SIGALRM, _handler)
    except (AttributeError, ValueError, OSError):
        # AttributeError: platform has no SIGALRM.
        # ValueError/OSError: not in the main thread (uvicorn worker,
        # threadpool, ...). Do the unsafe-but-correct thing.
        return _do_call()

    try:
        # Arm the timer inside the ``try`` so the previous handler is
        # restored even if arming itself fails.
        signal.setitimer(signal.ITIMER_REAL, timeout_s)
        return _do_call()
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)
        signal.signal(signal.SIGALRM, old)


def _safe_call(fn: Callable, arg: Any, timeout_s: float, *, unpack: bool = False):
    """Invoke ``fn`` and fold any outcome into a ``(kind, value)`` pair.

    When ``unpack`` is True the input ``arg`` is expected to be an args
    tuple and ``fn`` is invoked as ``fn(*arg)``. This is how multi-parameter
    auto-fuzzer-driven targets are scored.

    Args:
        fn: Callable under test (reference or submission).
        arg: Single argument, or an args tuple when ``unpack`` is true.
        timeout_s: Per-call wall-clock budget in seconds.
        unpack: Invoke as ``fn(*arg)`` instead of ``fn(arg)``.

    Returns:
        ``("val", result)`` on success, ``("timeout", message)`` when the
        call overran its budget, or ``("err", "<ExcType>: <message>")`` when
        it raised.
    """
    try:
        return ("val", _call_with_timeout(fn, arg, timeout_s, unpack=unpack))
    except _CallTimeout:
        return ("timeout", f"timed out after {timeout_s}s")
    except (KeyboardInterrupt, SystemExit):
        # Genuine shutdown requests from the host must not be swallowed.
        raise
    except BaseException as e:  # noqa: BLE001
        # ``BaseException`` is whitelisted in the sandbox builtins, so a
        # submission can ``raise BaseException(...)``; catching only
        # ``Exception`` would let that crash the whole scoring run.
        return ("err", f"{type(e).__name__}: {e}")


# ----- Public scoring ------------------------------------------------------


@dataclass
class VerificationResult:
    """Outcome of scoring one submission with ``verify_submission``."""

    # 100 * matches / fuzz_count; 0.0 when the submission failed to define.
    execution_reward: float
    # Value returned by ``calculate_complexity_penalty`` for the submission.
    complexity_penalty: float
    # Error message when the target callable could not be defined, else None.
    define_error: Optional[str]
    # Number of inputs (edge + random) whose outcome matched the reference.
    matches: int
    # Number of inputs scored (edge + random).
    fuzz_count: int
    # New, additive fields (do not change existing field meanings).
    # Per-category match and input counts, keyed by "edge" / "random".
    matches_by_category: Dict[str, int] = field(default_factory=dict)
    counts_by_category: Dict[str, int] = field(default_factory=dict)
    # Fraction of edge inputs matched; 0.0 when there are no edge inputs.
    edge_pass_rate: float = 0.0
    # Accumulated anti-hacking penalty (reference import, constant collapse).
    reward_hack_penalty: float = 0.0
    # Penalty applied on a define error or a sub-50 execution reward.
    floor_penalty: float = 0.0


def _detect_constant_collapse(
    sub_outputs: List[Any], ref_outputs: List[Any], min_inputs: int = 6
) -> bool:
    """Flag the 'always return 0' / 'always raise' reward-hacking pattern.

    Both lists hold ``(kind, value)`` outcomes. The result is True when at
    least ``min_inputs`` submission outcomes are present, they all reduce
    to one signature, and the reference outcomes show three or more
    distinct signatures.
    """
    def _fingerprint(outcome):
        # Returned values are compared by repr, errors by exception type
        # name only, and every timeout looks the same.
        tag, payload = outcome
        if tag == "err":
            return ("err", payload.split(":", 1)[0])
        if tag != "val":
            return ("timeout", "")
        try:
            return ("val", repr(payload))
        except Exception:  # noqa: BLE001
            return ("val", id(payload))

    if len(sub_outputs) < min_inputs:
        return False

    distinct_sub = set(map(_fingerprint, sub_outputs))
    distinct_ref = set(map(_fingerprint, ref_outputs))
    return len(distinct_sub) == 1 and len(distinct_ref) >= 3


def _looks_like_reference_import(code: str) -> bool:
    """Static check for the most obvious reward-hacking pattern: importing
    the reference function out of opensleuth_env. The sandbox already blocks
    actual imports, but flagging them lets the env feed back a clear penalty
    instead of a silent zero.

    Args:
        code: Submitted Python source text.

    Returns:
        True if any ``import`` / ``from ... import`` statement anywhere in
        the source names a module starting with ``opensleuth``. Source that
        cannot be parsed yields False.
    """
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError, RecursionError):
        # ValueError: NUL bytes on Python < 3.12. RecursionError:
        # pathologically nested source. This check runs before any other
        # validation, so it must never raise on untrusted input.
        return False
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            if any(alias.name.startswith("opensleuth") for alias in node.names):
                return True
        elif isinstance(node, ast.ImportFrom):
            # ``node.module`` is None for ``from . import x``.
            if node.module and node.module.startswith("opensleuth"):
                return True
    return False


def verify_submission(
    submitted_code: str,
    target_function: Callable[..., Any],
    fuzz_inputs: List[Any],
    *,
    target_name: Optional[str] = None,
    define_timeout_s: float = 5.0,
    call_timeout_s: float = 1.0,
    edge_inputs: Optional[List[Any]] = None,
    unpack_args: bool = False,
) -> VerificationResult:
    """Score ``submitted_code`` against ``target_function`` over the supplied
    ``fuzz_inputs`` (random regime) and ``edge_inputs`` (must-pass regime).
    The agent is expected to define a top-level function with the same name as
    ``target_function`` (overridable via ``target_name``).

    Args:
        submitted_code: Python source text submitted by the agent.
        target_function: Reference implementation to compare against.
        fuzz_inputs: Inputs scored under the ``"random"`` category.
        target_name: Name the submission must define; defaults to
            ``target_function.__name__``.
        define_timeout_s: Wall-clock budget for the child-process
            definition check.
        call_timeout_s: Per-call budget for each reference/submission call.
        edge_inputs: Inputs scored under the ``"edge"`` category.
        unpack_args: Treat each input as an args tuple and call ``fn(*inp)``.

    Returns:
        A ``VerificationResult``. When the submission cannot be defined,
        ``define_error`` is set, ``execution_reward`` is 0 and the floor
        penalty applies.
    """
    name = target_name or target_function.__name__
    edge_inputs = list(edge_inputs or [])

    # Static reward-hack flag: import-of-reference is always a -25 hit on top
    # of whatever score the rest of the rubric assigns. Even if the sandbox
    # successfully blocks the import (it will), we want to *teach* the agent
    # not to try.
    hack_penalty = 25.0 if _looks_like_reference_import(submitted_code) else 0.0

    define_err = _can_define(submitted_code, name, define_timeout_s)
    complexity = calculate_complexity_penalty(submitted_code)

    submitted_fn = None
    if define_err is None:
        # Re-define in-process for fast fuzzing. The child process already
        # showed the module body finishes; we still time-bound each call.
        # The restricted globals mean e.g. `__import__` is unavailable here
        # too.
        namespace = _make_safe_globals()
        try:
            # ONE namespace: with separate globals/locals ``exec`` behaves
            # like a class body, so top-level functions cannot see
            # themselves or sibling helpers and every recursive or
            # helper-using submission would raise NameError when called.
            exec(submitted_code, namespace)
        except Exception as e:  # noqa: BLE001
            # Defensive: never let untrusted module-level code crash the
            # scorer, even though the child-process check passed.
            define_err = f"{type(e).__name__}: {e}"
        else:
            submitted_fn = namespace.get(name)
            if not callable(submitted_fn):
                define_err = f"No callable named {name!r} defined."

    if define_err is not None:
        total = len(fuzz_inputs) + len(edge_inputs)
        return VerificationResult(
            execution_reward=0.0,
            complexity_penalty=complexity,
            define_error=define_err,
            matches=0,
            fuzz_count=total,
            matches_by_category={"edge": 0, "random": 0},
            counts_by_category={"edge": len(edge_inputs), "random": len(fuzz_inputs)},
            edge_pass_rate=0.0,
            reward_hack_penalty=hack_penalty,
            floor_penalty=25.0,
        )

    matches_by_cat: Dict[str, int] = {"edge": 0, "random": 0}
    counts_by_cat: Dict[str, int] = {"edge": len(edge_inputs), "random": len(fuzz_inputs)}

    sub_results: List[Any] = []
    ref_results: List[Any] = []

    def _score(inputs: List[Any], category: str) -> None:
        """Run reference and submission on each input and tally matches."""
        for inp in inputs:
            ref = _safe_call(target_function, inp, call_timeout_s, unpack=unpack_args)
            sub = _safe_call(submitted_fn, inp, call_timeout_s, unpack=unpack_args)
            sub_results.append(sub)
            ref_results.append(ref)
            if _outputs_equivalent(ref, sub):
                matches_by_cat[category] += 1

    _score(edge_inputs, "edge")
    _score(fuzz_inputs, "random")

    matches = matches_by_cat["edge"] + matches_by_cat["random"]
    # Report at least 1 so the ratio below never divides by zero.
    fuzz_count = len(fuzz_inputs) + len(edge_inputs) or 1
    exec_reward = 100.0 * (matches / fuzz_count)
    edge_pass_rate = (
        matches_by_cat["edge"] / counts_by_cat["edge"] if counts_by_cat["edge"] else 0.0
    )

    # Anti-hacking: constant collapse penalty.
    if _detect_constant_collapse(sub_results, ref_results):
        hack_penalty += 15.0

    # Hard floor for sub-50% match rate. Vul-R2 style: a wrong patch deserves
    # a clearly negative signal so the agent doesn't learn that 'any defined
    # function' pays out via the small complexity-bonus / step structure.
    floor_penalty = 25.0 if exec_reward < 50.0 else 0.0

    return VerificationResult(
        execution_reward=exec_reward,
        complexity_penalty=complexity,
        define_error=None,
        matches=matches,
        fuzz_count=fuzz_count,
        matches_by_category=matches_by_cat,
        counts_by_category=counts_by_cat,
        edge_pass_rate=edge_pass_rate,
        reward_hack_penalty=hack_penalty,
        floor_penalty=floor_penalty,
    )


def _outputs_equivalent(ref, sub) -> bool:
    """Decide whether a reference outcome and a submission outcome match.

    Ref and sub are (kind, value) tuples from `_safe_call`. They count as
    equivalent if both raised the same exception type, both timed out, or
    both returned values that are == equal.

    Args:
        ref: Outcome of calling the reference function.
        sub: Outcome of calling the submitted function.

    Returns:
        A real ``bool``; never the raw result of ``==``.
    """
    rkind, rval = ref
    skind, sval = sub
    if rkind != skind:
        return False
    if rkind == "val":
        try:
            # Coerce inside the ``try``: ``==`` may return a non-bool (rich
            # comparison objects, arrays) whose truth test raises, which
            # would otherwise blow up at the caller's ``if``.
            # NOTE(review): a submission returning an object with an
            # always-true ``__eq__`` still compares equal here -- confirm
            # whether stricter type checks are wanted.
            return bool(rval == sval)
        except Exception:  # noqa: BLE001
            return False
    if rkind == "err":
        # Compare exception type names only; messages may differ.
        return rval.split(":", 1)[0] == sval.split(":", 1)[0]
    return rkind == "timeout"


def generate_fuzz_inputs(
    spec, count: int = 100, seed: Optional[int] = None
) -> List[Any]:
    """Public helper: ask ``spec.fuzzer`` for ``count`` fuzz inputs.

    A dedicated ``random.Random`` seeded with ``seed`` is handed to the
    fuzzer, so a fixed seed gives reproducible inputs.
    """
    return spec.fuzzer(random.Random(seed), count)


def get_edge_inputs(spec) -> List[Any]:
    """Return a fresh list of the spec's must-pass edge inputs.

    Specs without an ``edge_cases`` attribute (pre-v0.3 schema), or with an
    empty or ``None`` value, yield an empty list.
    """
    cases = getattr(spec, "edge_cases", None)
    if not cases:
        return []
    return list(cases)