"""TaskCatalog: resolve a target function from one of three sources.

OpenSleuth Level 2 makes the env open-ended. Where v0.3 only knew about the
9 hand-written ``BLACK_BOX_FUNCTIONS``, the catalog accepts targets from:

  1. **Caller-supplied** -- the ``/reset`` payload, the most specific source.
     The caller passes ``target_code`` + ``target_function_name`` (and
     optionally ``edge_cases`` / ``fuzz_spec``) and we compile the source
     in the same hardened sandbox the verifier uses for submissions.

  2. **Hub dataset** -- ``anugrah55/opensleuth-tasks`` on Hugging Face Hub.
     Each row carries ``{name, signature, description, difficulty,
     source_code, edge_cases_json, fuzz_spec_json}``. Loaded lazily on
     first reset and cached in-process.

  3. **Builtin registry** -- the original 9 ``BLACK_BOX_FUNCTIONS``. Kept
     as the safety-net so the in-flight trainer keeps working unchanged.

Resolution priority: caller-supplied wins, then builtin by name, then Hub.
This makes "trainer asks for fibonacci" still resolve to the builtin
fibonacci even when a Hub copy exists, *unless* the caller explicitly
overrides via ``target_code``.

Sandbox: caller-supplied / Hub source code is executed via the same
``_make_safe_globals`` whitelist as agent submissions (no ``__import__``,
``open``, ``eval``, ...). On top we statically reject any source that
imports ``opensleuth_*`` to prevent oracle-cheesing.
"""

from __future__ import annotations

import ast
import inspect
import json
import logging
import threading
from typing import Any, Callable, Dict, List, Optional

from .auto_fuzzer import auto_fuzz, make_fuzzer
from .black_box import BLACK_BOX_FUNCTIONS, FunctionSpec
from .verifier import _make_safe_globals  # reuse the hardened sandbox

log = logging.getLogger("opensleuth.task_catalog")

HUB_DATASET_ID = "anugrah55/opensleuth-tasks"


class TaskResolutionError(ValueError):
    """Raised when a /reset request can't be turned into a FunctionSpec."""


# ---------------------------------------------------------------------------
# Caller / Hub source-code compilation
# ---------------------------------------------------------------------------


_FORBIDDEN_PREFIXES = ("opensleuth", "opensleuth_env")


def _statically_reject_oracle_import(code: str) -> Optional[str]:
    """Return an error string if the source statically imports the env's own
    oracle module (which would let the agent / Hub author cheese the
    verifier). The hardened sandbox already blocks ``__import__``, but we
    fail fast and surface a clear error.
    """
    try:
        tree = ast.parse(code)
    except SyntaxError as e:
        return f"target_code is not valid Python: {e}"
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                if any(alias.name.startswith(p) for p in _FORBIDDEN_PREFIXES):
                    return (
                        f"target_code is not allowed to import {alias.name!r} "
                        "(oracle import)."
                    )
        elif isinstance(node, ast.ImportFrom):
            mod = node.module or ""
            if any(mod.startswith(p) for p in _FORBIDDEN_PREFIXES):
                return (
                    f"target_code is not allowed to import from {mod!r} "
                    "(oracle import)."
                )
    return None
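

# Example (illustrative) of the static check above: the oracle import is
# caught before the sandbox ever runs the code, so /reset can return a
# clear 400 instead of an opaque sandbox failure.
#
#   _statically_reject_oracle_import("from opensleuth_env import black_box")
#   # -> "target_code is not allowed to import from 'opensleuth_env' (oracle import)."
#
#   _statically_reject_oracle_import("def f(x):\n    return x")
#   # -> None (no forbidden imports)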


def _compile_target_in_sandbox(code: str, function_name: str) -> Callable[..., Any]:
    """Compile ``code`` in the same restricted globals the verifier uses for
    agent submissions, then return the named callable. Raises
    ``TaskResolutionError`` on any problem so /reset can return a clean 400.
    """
    err = _statically_reject_oracle_import(code)
    if err:
        raise TaskResolutionError(err)
    safe_globals = _make_safe_globals()
    try:
        # Execute with a single namespace for globals *and* locals: with two
        # separate dicts, top-level helpers land in the locals dict while the
        # defined functions resolve names through the globals dict, so a
        # target that calls its own helper would NameError at call time.
        exec(code, safe_globals)
    except Exception as e:  # noqa: BLE001
        raise TaskResolutionError(
            f"target_code raised at definition time: {type(e).__name__}: {e}"
        ) from e
    fn = safe_globals.get(function_name)
    if not callable(fn):
        raise TaskResolutionError(
            f"target_code does not define a callable named {function_name!r}."
        )
    return fn


def _arity_of(fn: Callable[..., Any]) -> int:
    """Number of positional / positional-or-keyword params on ``fn``."""
    try:
        sig = inspect.signature(fn)
    except (TypeError, ValueError):
        return 1
    n = 0
    for p in sig.parameters.values():
        if p.kind in (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        ):
            n += 1
    return max(n, 1)


def _signature_string(fn: Callable[..., Any], name: str) -> str:
    try:
        sig = inspect.signature(fn)
        return f"{name}{sig}"
    except (TypeError, ValueError):
        return f"{name}(...)"


def _description_of(fn: Callable[..., Any]) -> str:
    return inspect.getdoc(fn) or ""


def _parse_edge_cases(edge_cases: Optional[List[Any]]) -> List[Any]:
    """Edge cases arrive as a list of strings (Python literal reprs) when
    coming from the API or from the Hub's ``edge_cases_json`` column. Each
    string is parsed via ``ast.literal_eval``. Already-parsed values
    (e.g. ints from the bootstrap script) are passed through unchanged.
    """
    if not edge_cases:
        return []
    parsed: List[Any] = []
    for raw in edge_cases:
        if isinstance(raw, str):
            try:
                parsed.append(ast.literal_eval(raw))
            except (ValueError, SyntaxError) as e:
                raise TaskResolutionError(
                    f"edge_cases entry {raw!r} is not a Python literal: {e}"
                ) from e
        else:
            parsed.append(raw)
    return parsed


def _flatten_unary_edges(arity: int, edges: List[Any]) -> List[Any]:
    """For unary fns we accept either ``[5, 10]`` or ``[(5,), (10,)]`` and
    normalise to flat values; for multi-arg fns we require tuples and pass
    them through."""
    if arity == 1:
        out = []
        for e in edges:
            if isinstance(e, tuple) and len(e) == 1:
                out.append(e[0])
            else:
                out.append(e)
        return out
    out = []
    for e in edges:
        if not isinstance(e, tuple):
            raise TaskResolutionError(
                f"edge_cases for a {arity}-arg target must be tuples, "
                f"got {type(e).__name__}: {e!r}"
            )
        out.append(e)
    return out
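

# Example (illustrative) of the normalisation above:
#
#   _flatten_unary_edges(1, [5, (10,)])        # -> [5, 10]
#   _flatten_unary_edges(2, [(1, 2), (3, 4)])  # -> [(1, 2), (3, 4)]
#   _flatten_unary_edges(2, [5])               # raises TaskResolutionError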


def _spec_from_callable(
    name: str,
    fn: Callable[..., Any],
    *,
    description: Optional[str] = None,
    signature: Optional[str] = None,
    difficulty: str = "medium",
    edge_cases: Optional[List[Any]] = None,
    fuzz_spec: Optional[Dict[str, Dict[str, Any]]] = None,
    source: str = "user",
) -> FunctionSpec:
    """Build a FunctionSpec from a Python callable + optional metadata.

    Wraps ``auto_fuzz`` for the fuzzer. The arity is auto-detected from
    ``inspect.signature`` so ``unpack_args`` is set correctly: unary fns
    behave like the existing builtins (single-arg call), N-arg fns flow
    through the tuple-unpacking path in env / verifier.
    """
    arity = _arity_of(fn)
    unpack = arity > 1

    parsed_edges = _flatten_unary_edges(arity, _parse_edge_cases(edge_cases))

    if unpack:
        # Multi-arg targets: store the original callable unchanged. The env
        # and verifier see ``unpack_args=True`` and call ``fn(*args)`` --
        # env._handle_probe and verify_submission do the tuple unpacking, so
        # no catalog-level adapter is needed here.
        ref_fn: Callable[..., Any] = fn

        def _fuzzer(rng, n):
            return auto_fuzz(fn, n, rng, fuzz_spec=fuzz_spec)

    else:
        ref_fn = fn

        def _unary_fuzzer(rng, n):
            tuples = auto_fuzz(fn, n, rng, fuzz_spec=fuzz_spec)
            return [t[0] if isinstance(t, tuple) and len(t) == 1 else t for t in tuples]

        _fuzzer = _unary_fuzzer

    return FunctionSpec(
        name=name,
        fn=ref_fn,
        signature=signature or _signature_string(fn, name),
        description=description or _description_of(fn),
        fuzzer=_fuzzer,
        difficulty=difficulty,
        edge_cases=parsed_edges,
        unpack_args=unpack,
        source=source,
    )


# ---------------------------------------------------------------------------
# Hub loader
# ---------------------------------------------------------------------------


class _HubCache:
    """Lazily loads the Hub dataset into ``{name: FunctionSpec}``.
    Thread-safe initialisation; subsequent reads are lock-free."""

    def __init__(self, dataset_id: str):
        self.dataset_id = dataset_id
        self._lock = threading.Lock()
        self._loaded: bool = False
        self._specs: Dict[str, FunctionSpec] = {}
        self._raw_rows: List[Dict[str, Any]] = []
        self._load_error: Optional[str] = None

    @property
    def loaded(self) -> bool:
        return self._loaded

    @property
    def load_error(self) -> Optional[str]:
        return self._load_error

    def _row_to_spec(self, row: Dict[str, Any]) -> Optional[FunctionSpec]:
        name = row.get("name")
        code = row.get("source_code")
        if not name or not code:
            return None
        fn_name = row.get("target_function_name") or name
        try:
            fn = _compile_target_in_sandbox(code, fn_name)
        except TaskResolutionError as e:
            log.warning("hub task %r failed to compile: %s", name, e)
            return None
        edge_cases_raw = row.get("edge_cases_json") or "[]"
        fuzz_spec_raw = row.get("fuzz_spec_json") or "null"
        try:
            edge_cases = (
                json.loads(edge_cases_raw)
                if isinstance(edge_cases_raw, str)
                else edge_cases_raw
            )
        except json.JSONDecodeError:
            edge_cases = []
        try:
            fuzz_spec = (
                json.loads(fuzz_spec_raw)
                if isinstance(fuzz_spec_raw, str)
                else fuzz_spec_raw
            )
        except json.JSONDecodeError:
            fuzz_spec = None
        try:
            return _spec_from_callable(
                name=name,
                fn=fn,
                description=row.get("description") or _description_of(fn),
                signature=row.get("signature") or _signature_string(fn, name),
                difficulty=row.get("difficulty") or "medium",
                edge_cases=edge_cases,
                fuzz_spec=fuzz_spec,
                source="hub",
            )
        except TaskResolutionError as e:
            log.warning("hub task %r could not be specced: %s", name, e)
            return None

    def ensure_loaded(self) -> None:
        if self._loaded:
            return
        with self._lock:
            if self._loaded:
                return
            try:
                from datasets import load_dataset  # type: ignore

                ds = load_dataset(self.dataset_id, split="train")
                rows = list(ds)
                specs: Dict[str, FunctionSpec] = {}
                for row in rows:
                    spec = self._row_to_spec(row)
                    if spec is not None:
                        specs[spec.name] = spec
                self._specs = specs
                self._raw_rows = rows
                log.info(
                    "loaded %d task(s) from %s (%d row(s) total)",
                    len(specs),
                    self.dataset_id,
                    len(rows),
                )
            except Exception as e:  # noqa: BLE001
                # Hub unreachable / not yet bootstrapped / offline. We swallow
                # the error so the env keeps working from the builtin
                # registry alone -- this is what lets the trainer keep
                # running even if the Hub goes down mid-rollout.
                self._load_error = f"{type(e).__name__}: {e}"
                log.warning("hub dataset %s unavailable: %s", self.dataset_id, self._load_error)
            finally:
                self._loaded = True

    def specs(self) -> Dict[str, FunctionSpec]:
        self.ensure_loaded()
        return self._specs

    def rows(self) -> List[Dict[str, Any]]:
        self.ensure_loaded()
        return self._raw_rows


# ---------------------------------------------------------------------------
# TaskCatalog
# ---------------------------------------------------------------------------


class TaskCatalog:
    """Resolves /reset payloads to FunctionSpecs from caller / Hub / builtin."""

    def __init__(
        self,
        hub_dataset_id: str = HUB_DATASET_ID,
        *,
        enable_hub: bool = True,
    ) -> None:
        self.hub_dataset_id = hub_dataset_id
        self.enable_hub = enable_hub
        self._hub = _HubCache(hub_dataset_id) if enable_hub else None

    # --- Resolution --------------------------------------------------------

    def resolve(
        self,
        target_name: Optional[str] = None,
        target_code: Optional[str] = None,
        target_function_name: Optional[str] = None,
        edge_cases: Optional[List[Any]] = None,
        fuzz_spec: Optional[Dict[str, Dict[str, Any]]] = None,
    ) -> FunctionSpec:
        # 1. Caller-supplied: highest priority.
        if target_code is not None:
            if not target_function_name:
                raise TaskResolutionError(
                    "target_code requires target_function_name to identify "
                    "which callable in the source to use."
                )
            fn = _compile_target_in_sandbox(target_code, target_function_name)
            return _spec_from_callable(
                name=target_function_name,
                fn=fn,
                edge_cases=edge_cases,
                fuzz_spec=fuzz_spec,
                source="user",
            )

        # 2. & 3. Builtin-by-name / Hub-by-name. Builtin wins for legacy names
        # (so the trainer's "fibonacci" always means the in-process oracle,
        # never a possibly-modified Hub copy).
        if not target_name:
            raise TaskResolutionError(
                "Either target_name or (target_code + target_function_name) must be set."
            )

        if target_name in BLACK_BOX_FUNCTIONS:
            return BLACK_BOX_FUNCTIONS[target_name]

        if self._hub is not None:
            hub_specs = self._hub.specs()
            if target_name in hub_specs:
                return hub_specs[target_name]

        available = self.list_known_names()  # already sorted
        raise TaskResolutionError(
            f"Unknown target function: {target_name!r}. Available: {available[:25]}"
        )

    # --- Listing -----------------------------------------------------------

    def list_known_names(self) -> List[str]:
        names = set(BLACK_BOX_FUNCTIONS)
        if self._hub is not None:
            try:
                names.update(self._hub.specs())
            except Exception:  # noqa: BLE001 -- best effort
                pass
        return sorted(names)

    def list_builtin(self) -> List[Dict[str, Any]]:
        return [
            {
                "name": s.name,
                "signature": s.signature,
                "description": s.description,
                "difficulty": s.difficulty,
                "edge_case_count": len(s.edge_cases or []),
                "source": "builtin",
            }
            for s in BLACK_BOX_FUNCTIONS.values()
        ]

    def list_hub(self) -> List[Dict[str, Any]]:
        if self._hub is None:
            return []
        out = []
        for s in self._hub.specs().values():
            # Don't shadow builtins in the Hub list (avoids surprising the
            # caller with a "fibonacci@hub" entry that's never used).
            if s.name in BLACK_BOX_FUNCTIONS:
                continue
            out.append(
                {
                    "name": s.name,
                    "signature": s.signature,
                    "description": s.description,
                    "difficulty": s.difficulty,
                    "edge_case_count": len(s.edge_cases or []),
                    "source": "hub",
                }
            )
        return out

    def list_all(self) -> List[Dict[str, Any]]:
        return self.list_builtin() + self.list_hub()

    # --- Diagnostics -------------------------------------------------------

    def hub_status(self) -> Dict[str, Any]:
        if self._hub is None:
            return {"enabled": False}
        return {
            "enabled": True,
            "dataset_id": self.hub_dataset_id,
            "loaded": self._hub.loaded,
            "task_count": len(self._hub.specs()) if self._hub.loaded else None,
            "error": self._hub.load_error,
        }
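

# Usage sketch (illustrative; assumes the surrounding package is importable):
#
#   catalog = TaskCatalog(enable_hub=False)
#   spec = catalog.resolve(
#       target_code="def double(x):\n    return 2 * x\n",
#       target_function_name="double",
#       edge_cases=["0", "-1"],
#   )
#   assert spec.name == "double"
#   assert spec.unpack_args is False      # unary target: no tuple unpacking
#   assert spec.edge_cases == [0, -1]     # literal strings parsed to values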