"""TaskCatalog: resolve a target function from one of three sources.
OpenSleuth Level 2 makes the env open-ended. Where v0.3 only knew about the
9 hand-written ``BLACK_BOX_FUNCTIONS``, the catalog accepts targets from:
1. **Caller-supplied** -- passed in the /reset payload; the most specific source.
The caller passes ``target_code`` + ``target_function_name`` (and
optionally ``edge_cases`` / ``fuzz_spec``) and we compile the source
in the same hardened sandbox the verifier uses for submissions.
2. **Hub dataset** -- ``anugrah55/opensleuth-tasks`` on Hugging Face Hub.
Each row carries ``{name, signature, description, difficulty,
source_code, edge_cases_json, fuzz_spec_json}``. Loaded lazily on
first reset and cached in-process.
3. **Builtin registry** -- the original 9 ``BLACK_BOX_FUNCTIONS``. Kept
as the safety-net so the in-flight trainer keeps working unchanged.
Resolution priority: caller-supplied wins, then builtin by name, then Hub.
This makes "trainer asks for fibonacci" still resolve to the builtin
fibonacci even when the Hub copy exists, *unless* the caller explicitly
overrides via ``target_code``.
Sandbox: caller-supplied / Hub source code is executed via the same
``_make_safe_globals`` whitelist as agent submissions (no ``__import__``,
``open``, ``eval``, ...). On top we statically reject any source that
imports any ``opensleuth*`` module to prevent oracle-cheesing.
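
Example (illustrative sketch; ``add`` and ``some_hub_task`` are made-up names)::

    catalog = TaskCatalog()
    catalog.resolve(target_name="fibonacci")       # builtin registry
    catalog.resolve(target_name="some_hub_task")   # Hub row, if one exists
    catalog.resolve(                               # caller-supplied, highest priority
        target_code="def add(a, b): return a + b",
        target_function_name="add",
    )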
"""
from __future__ import annotations
import ast
import inspect
import json
import logging
import threading
from typing import Any, Callable, Dict, List, Optional
from .auto_fuzzer import auto_fuzz, make_fuzzer
from .black_box import BLACK_BOX_FUNCTIONS, FunctionSpec
from .verifier import _make_safe_globals # reuse the hardened sandbox
log = logging.getLogger("opensleuth.task_catalog")
HUB_DATASET_ID = "anugrah55/opensleuth-tasks"
class TaskResolutionError(ValueError):
"""Raised when a /reset request can't be turned into a FunctionSpec."""
# ---------------------------------------------------------------------------
# Caller / Hub source-code compilation
# ---------------------------------------------------------------------------
_FORBIDDEN_PREFIXES = ("opensleuth", "opensleuth_env")
def _statically_reject_oracle_import(code: str) -> Optional[str]:
"""Return an error string if the source statically imports the env's own
oracle module (which would let the agent / Hub author cheese the
verifier). The hardened sandbox already blocks ``__import__``, but we
fail fast and surface a clear error.
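
    Illustrative behaviour (sketch, not a doctest)::

        _statically_reject_oracle_import("import opensleuth.verifier")
        # -> "target_code is not allowed to import 'opensleuth.verifier' (oracle import)."
        _statically_reject_oracle_import("def f(x): return x")
        # -> None (nothing to reject)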
"""
try:
tree = ast.parse(code)
except SyntaxError as e:
return f"target_code is not valid Python: {e}"
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
if any(alias.name.startswith(p) for p in _FORBIDDEN_PREFIXES):
return (
f"target_code is not allowed to import {alias.name!r} "
"(oracle import)."
)
elif isinstance(node, ast.ImportFrom):
mod = node.module or ""
if any(mod.startswith(p) for p in _FORBIDDEN_PREFIXES):
return (
f"target_code is not allowed to import from {mod!r} "
"(oracle import)."
)
return None
def _compile_target_in_sandbox(code: str, function_name: str) -> Callable[..., Any]:
"""Compile ``code`` in the same restricted globals the verifier uses for
agent submissions, then return the named callable. Raises
``TaskResolutionError`` on any problem so /reset can return a clean 400.
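
    Sketch of the intended use (``double`` is a made-up target)::

        fn = _compile_target_in_sandbox("def double(x): return 2 * x", "double")
        fn(21)  # -> 42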
"""
err = _statically_reject_oracle_import(code)
if err:
raise TaskResolutionError(err)
safe_globals = _make_safe_globals()
local_scope: Dict[str, Any] = {}
try:
exec(code, safe_globals, local_scope)
except Exception as e: # noqa: BLE001
raise TaskResolutionError(
f"target_code raised at definition time: {type(e).__name__}: {e}"
) from e
fn = local_scope.get(function_name) or safe_globals.get(function_name)
if not callable(fn):
raise TaskResolutionError(
f"target_code does not define a callable named {function_name!r}."
)
return fn
def _arity_of(fn: Callable[..., Any]) -> int:
"""Number of positional / positional-or-keyword params on ``fn``."""
try:
sig = inspect.signature(fn)
except (TypeError, ValueError):
return 1
n = 0
for p in sig.parameters.values():
if p.kind in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
):
n += 1
return max(n, 1)
def _signature_string(fn: Callable[..., Any], name: str) -> str:
try:
sig = inspect.signature(fn)
return f"{name}{sig}"
except (TypeError, ValueError):
return f"{name}(...)"
def _description_of(fn: Callable[..., Any]) -> str:
return inspect.getdoc(fn) or ""
def _parse_edge_cases(edge_cases: Optional[List[Any]]) -> List[Any]:
"""Edge cases arrive as a list of strings (Python literal reprs) when
coming from the API or from the Hub's ``edge_cases_json`` column. Each
string is parsed via ``ast.literal_eval``. Already-parsed values
(e.g. ints from the bootstrap script) are passed through unchanged.
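
    For example (illustrative), ``["(3, 4)", "0", "'abc'"]`` parses to
    ``[(3, 4), 0, 'abc']``.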
"""
if not edge_cases:
return []
parsed: List[Any] = []
for raw in edge_cases:
if isinstance(raw, str):
try:
parsed.append(ast.literal_eval(raw))
except (ValueError, SyntaxError) as e:
raise TaskResolutionError(
f"edge_cases entry {raw!r} is not a Python literal: {e}"
) from e
else:
parsed.append(raw)
return parsed
def _flatten_unary_edges(arity: int, edges: List[Any]) -> List[Any]:
"""For unary fns we accept either ``[5, 10]`` or ``[(5,), (10,)]`` and
normalise to flat values; for multi-arg fns we require tuples and pass
them through."""
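    # Illustration: arity 1 turns [(5,), 10] into [5, 10]; arity 2 accepts
    # [(1, 2), (3, 4)] as-is and rejects bare values like 5.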
if arity == 1:
out = []
for e in edges:
if isinstance(e, tuple) and len(e) == 1:
out.append(e[0])
else:
out.append(e)
return out
out = []
for e in edges:
if not isinstance(e, tuple):
raise TaskResolutionError(
f"edge_cases for a {arity}-arg target must be tuples, "
f"got {type(e).__name__}: {e!r}"
)
out.append(e)
return out
def _spec_from_callable(
name: str,
fn: Callable[..., Any],
*,
description: Optional[str] = None,
signature: Optional[str] = None,
difficulty: str = "medium",
edge_cases: Optional[List[Any]] = None,
fuzz_spec: Optional[Dict[str, Dict[str, Any]]] = None,
source: str = "user",
) -> FunctionSpec:
"""Build a FunctionSpec from a Python callable + optional metadata.
Wraps ``auto_fuzz`` for the fuzzer. The arity is auto-detected from
``inspect.signature`` so ``unpack_args`` is set correctly: unary fns
behave like the existing builtins (single-arg call), N-arg fns flow
through the tuple-unpacking path in env / verifier.
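
    Sketch (``clamp`` is a made-up 3-arg target)::

        spec = _spec_from_callable("clamp", lambda x, lo, hi: max(lo, min(x, hi)))
        spec.unpack_args  # -> True; arity 3 detected from the signature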
"""
arity = _arity_of(fn)
unpack = arity > 1
parsed_edges = _flatten_unary_edges(arity, _parse_edge_cases(edge_cases))
    # The spec stores the original callable in both cases: for multi-arg
    # targets the env / verifier respect ``unpack_args`` and call ``fn(*args)``
    # (env._handle_probe and verify_submission do the unpacking); the unary
    # path keeps the single-arg call style of the existing builtins.
    ref_fn: Callable[..., Any] = fn
    if unpack:
        def _fuzzer(rng, n):
            return auto_fuzz(fn, n, rng, fuzz_spec=fuzz_spec)
    else:
        def _fuzzer(rng, n):
            # Flatten any 1-tuples returned by auto_fuzz to bare values so
            # unary fuzz cases match the builtin fuzzers.
            tuples = auto_fuzz(fn, n, rng, fuzz_spec=fuzz_spec)
            return [t[0] if isinstance(t, tuple) and len(t) == 1 else t for t in tuples]
return FunctionSpec(
name=name,
fn=ref_fn,
signature=signature or _signature_string(fn, name),
description=description or _description_of(fn),
fuzzer=_fuzzer,
difficulty=difficulty,
edge_cases=parsed_edges,
unpack_args=unpack,
source=source,
)
# ---------------------------------------------------------------------------
# Hub loader
# ---------------------------------------------------------------------------
class _HubCache:
"""Lazily loads the Hub dataset into ``{name: FunctionSpec}``. Thread-
safe initialisation; subsequent reads are lock-free."""
def __init__(self, dataset_id: str):
self.dataset_id = dataset_id
self._lock = threading.Lock()
self._loaded: bool = False
self._specs: Dict[str, FunctionSpec] = {}
self._raw_rows: List[Dict[str, Any]] = []
self._load_error: Optional[str] = None
@property
def loaded(self) -> bool:
return self._loaded
@property
def load_error(self) -> Optional[str]:
return self._load_error
def _row_to_spec(self, row: Dict[str, Any]) -> Optional[FunctionSpec]:
name = row.get("name")
code = row.get("source_code")
if not name or not code:
return None
fn_name = row.get("target_function_name") or name
try:
fn = _compile_target_in_sandbox(code, fn_name)
except TaskResolutionError as e:
log.warning("hub task %r failed to compile: %s", name, e)
return None
edge_cases_raw = row.get("edge_cases_json") or "[]"
fuzz_spec_raw = row.get("fuzz_spec_json") or "null"
try:
edge_cases = json.loads(edge_cases_raw) if isinstance(edge_cases_raw, str) else edge_cases_raw
except json.JSONDecodeError:
edge_cases = []
try:
fuzz_spec = json.loads(fuzz_spec_raw) if isinstance(fuzz_spec_raw, str) else fuzz_spec_raw
except json.JSONDecodeError:
fuzz_spec = None
try:
return _spec_from_callable(
name=name,
fn=fn,
description=row.get("description") or _description_of(fn),
signature=row.get("signature") or _signature_string(fn, name),
difficulty=row.get("difficulty") or "medium",
edge_cases=edge_cases,
fuzz_spec=fuzz_spec,
source="hub",
)
except TaskResolutionError as e:
log.warning("hub task %r could not be specced: %s", name, e)
return None
def ensure_loaded(self) -> None:
if self._loaded:
return
with self._lock:
if self._loaded:
return
try:
from datasets import load_dataset # type: ignore
ds = load_dataset(self.dataset_id, split="train")
rows = list(ds)
specs: Dict[str, FunctionSpec] = {}
for row in rows:
spec = self._row_to_spec(row)
if spec is not None:
specs[spec.name] = spec
self._specs = specs
self._raw_rows = rows
log.info(
"loaded %d task(s) from %s (%d row(s) total)",
len(specs),
self.dataset_id,
len(rows),
)
except Exception as e: # noqa: BLE001
# Hub unreachable / not yet bootstrapped / offline. We swallow
# the error so the env keeps working from the builtin
# registry alone -- this is what lets the trainer keep
# running even if the Hub goes down mid-rollout.
self._load_error = f"{type(e).__name__}: {e}"
log.warning("hub dataset %s unavailable: %s", self.dataset_id, self._load_error)
finally:
self._loaded = True
def specs(self) -> Dict[str, FunctionSpec]:
self.ensure_loaded()
return self._specs
def rows(self) -> List[Dict[str, Any]]:
self.ensure_loaded()
return self._raw_rows
# ---------------------------------------------------------------------------
# TaskCatalog
# ---------------------------------------------------------------------------
class TaskCatalog:
"""Resolves /reset payloads to FunctionSpecs from caller / Hub / builtin."""
def __init__(
self,
hub_dataset_id: str = HUB_DATASET_ID,
*,
enable_hub: bool = True,
) -> None:
self.hub_dataset_id = hub_dataset_id
self.enable_hub = enable_hub
self._hub = _HubCache(hub_dataset_id) if enable_hub else None
# --- Resolution --------------------------------------------------------
def resolve(
self,
target_name: Optional[str] = None,
target_code: Optional[str] = None,
target_function_name: Optional[str] = None,
edge_cases: Optional[List[Any]] = None,
fuzz_spec: Optional[Dict[str, Dict[str, Any]]] = None,
) -> FunctionSpec:
# 1. Caller-supplied: highest priority.
if target_code is not None:
if not target_function_name:
raise TaskResolutionError(
"target_code requires target_function_name to identify "
"which callable in the source to use."
)
fn = _compile_target_in_sandbox(target_code, target_function_name)
return _spec_from_callable(
name=target_function_name,
fn=fn,
edge_cases=edge_cases,
fuzz_spec=fuzz_spec,
source="user",
)
        # 2. & 3. Name lookup: builtin first, then Hub. Builtin wins for
        #    legacy names (so the trainer's "fibonacci" always means the
        #    in-process oracle, never a possibly-modified Hub copy).
if not target_name:
raise TaskResolutionError(
"Either target_name or (target_code + target_function_name) must be set."
)
if target_name in BLACK_BOX_FUNCTIONS:
return BLACK_BOX_FUNCTIONS[target_name]
if self._hub is not None:
hub_specs = self._hub.specs()
if target_name in hub_specs:
return hub_specs[target_name]
available = self.list_known_names()
raise TaskResolutionError(
f"Unknown target function: {target_name!r}. Available: {sorted(available)[:25]}"
)
# --- Listing -----------------------------------------------------------
def list_known_names(self) -> List[str]:
names = set(BLACK_BOX_FUNCTIONS)
if self._hub is not None:
try:
names.update(self._hub.specs())
except Exception: # noqa: BLE001 -- best effort
pass
return sorted(names)
def list_builtin(self) -> List[Dict[str, Any]]:
return [
{
"name": s.name,
"signature": s.signature,
"description": s.description,
"difficulty": s.difficulty,
"edge_case_count": len(s.edge_cases or []),
"source": "builtin",
}
for s in BLACK_BOX_FUNCTIONS.values()
]
def list_hub(self) -> List[Dict[str, Any]]:
if self._hub is None:
return []
out = []
for s in self._hub.specs().values():
# Don't shadow builtins in the Hub list (avoids surprising the
# caller with a "fibonacci@hub" entry that's never used).
if s.name in BLACK_BOX_FUNCTIONS:
continue
out.append(
{
"name": s.name,
"signature": s.signature,
"description": s.description,
"difficulty": s.difficulty,
"edge_case_count": len(s.edge_cases or []),
"source": "hub",
}
)
return out
def list_all(self) -> List[Dict[str, Any]]:
return self.list_builtin() + self.list_hub()
# --- Diagnostics -------------------------------------------------------
def hub_status(self) -> Dict[str, Any]:
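        # Return shape (illustrative values):
        #   {"enabled": True, "dataset_id": "...", "loaded": True,
        #    "task_count": <count or None>, "error": None}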
if self._hub is None:
return {"enabled": False}
return {
"enabled": True,
"dataset_id": self.hub_dataset_id,
"loaded": self._hub.loaded,
"task_count": len(self._hub.specs()) if self._hub.loaded else None,
"error": self._hub.load_error,
}