# Source: anugrah55 -- commit 77e65fb (verified)
# Level 2 open-ended env: auto-fuzzer + TaskCatalog + Hub-driven catalog + extended /reset
"""Core OpenSleuth episodic environment.
A single OpenSleuthEnv holds a *registry of episodes* keyed by episode_id, so
multiple training rollouts can hit the same FastAPI server in parallel without
stepping on each other's state.
Reward shaping (v0.3 -- paper-driven update):
* ``PROBE_STEP_COST`` -- per-step cost so the agent doesn't probe forever.
* ``NEW_OUTPUT_BONUS`` -- first-visit bonus for an output value the target
hasn't produced yet (existing behaviour, kept).
* ``NEW_ERROR_TYPE_BONUS`` -- first-visit bonus for an exception type the
target hasn't raised yet (existing behaviour, kept).
* ``NEW_BUCKET_BONUS`` -- *new*: TF-IDF / count-based exploration bonus
(CovRL-Fuzz; Eom et al. 2024 in Masud et al. 2026 §3.5.2 and SimHash;
Ibrahim et al. 2024 §IV-C-1). Encourages probing *under-explored regions
of the input domain* (negative ints, empty strings, edge values, ...) not
just under-observed outputs. Small magnitude so it doesn't drown out the
output/error-type bonuses.
* ``PERFECT_SUBMISSION_BONUS`` -- existing terminal bonus, gated to require
100% match (including all edge cases).
The submission reward formula is now::

    reward = execution_reward
             - complexity_penalty
             - reward_hack_penalty   # new: import-of-reference detector etc.
             - floor_penalty         # new: -25 floor below 50% match rate
             + (PERFECT_SUBMISSION_BONUS if execution_reward >= 99.999 else 0)

This keeps the ``reward`` field a single float (so the in-flight trainer's
``reward / 100`` GRPO scaling still works) but pushes wrong submissions
clearly into the negative regime.
"""
from __future__ import annotations
import ast
import logging
import uuid
from typing import Any, Optional, Tuple
from .black_box import BLACK_BOX_FUNCTIONS, FunctionSpec
from .models import (
Action,
Observation,
ProbeAction,
ProbeRecord,
State,
StepResponse,
SubmitAction,
)
from .task_catalog import TaskCatalog, TaskResolutionError
from .verifier import generate_fuzz_inputs, get_edge_inputs, verify_submission
# Module-level logger for the environment.
log = logging.getLogger("opensleuth.env")
# Reward shaping knobs (kept here so they're easy to tune).
# Added to the reward of every probe step, including probes whose input
# fails to parse.
PROBE_STEP_COST = -1.0
# Added the first time a probe yields an output repr not yet seen in the
# episode.
NEW_OUTPUT_BONUS = 2.0
# Added the first time a probe raises an exception type not yet seen in the
# episode.
NEW_ERROR_TYPE_BONUS = 5.0
NEW_BUCKET_BONUS = 0.5 # CovRL-style coverage bonus; small to avoid drowning the rest.
# Added to a submission's reward when its execution_reward is >= 99.999.
PERFECT_SUBMISSION_BONUS = 50.0
# Observations carry only this many of the most recent probe records.
MAX_PROBE_HISTORY_IN_OBS = 25
def _bucket_of(x: Any) -> str:
"""Coarse, deterministic bucketisation of a probe input, used for
coverage-based intrinsic reward (CovRL-Fuzz inspired). Buckets are by
type + a few qualitative magnitudes (sign / size / emptiness) so that
e.g. ``-1`` and ``-99`` share a bucket, while ``-1`` and ``0`` don't.
"""
if isinstance(x, bool):
return f"bool:{x}"
if isinstance(x, int):
if x < 0:
return "int:negative"
if x == 0:
return "int:zero"
if x < 10:
return "int:small"
if x < 100:
return "int:medium"
if x < 10_000:
return "int:large"
return "int:huge"
if isinstance(x, float):
if x != x: # NaN
return "float:nan"
if x < 0:
return "float:negative"
if x == 0:
return "float:zero"
return "float:positive"
if isinstance(x, str):
if x == "":
return "str:empty"
if len(x) == 1:
return "str:singleton"
if len(x) <= 5:
return "str:short"
if len(x) <= 20:
return "str:medium"
return "str:long"
if isinstance(x, (list, tuple)):
kind = type(x).__name__
if len(x) == 0:
return f"{kind}:empty"
if len(x) == 1:
return f"{kind}:singleton"
if len(x) <= 5:
return f"{kind}:short"
return f"{kind}:long"
if isinstance(x, dict):
return f"dict:{len(x)}"
if x is None:
return "none"
return f"other:{type(x).__name__}"
class OpenSleuthEnv:
    """Multi-episode environment registry.

    Each call to :meth:`reset` creates a new episode keyed by a fresh
    ``episode_id``; :meth:`step` then advances that episode with either a
    probe or a submission. Per-episode state, configuration and the resolved
    target spec are kept in separate dicts keyed by ``episode_id``.
    """

    def __init__(
        self,
        fuzz_count: int = 100,
        catalog: Optional["TaskCatalog"] = None,
    ) -> None:
        """Create an empty registry.

        Args:
            fuzz_count: Number of fuzz inputs requested when verifying a
                submission.
            catalog: Task catalog used to resolve targets on reset. A new
                ``TaskCatalog`` is created only when this is ``None``.
        """
        self._states: dict[str, State] = {}
        self._configs: dict[str, dict] = {}
        # Per-episode resolved spec. We cache it here (rather than looking it
        # up by name on every step from BLACK_BOX_FUNCTIONS) because
        # caller-supplied / Hub-loaded specs aren't in BLACK_BOX_FUNCTIONS.
        self._episode_specs: dict[str, FunctionSpec] = {}
        self.fuzz_count = fuzz_count
        # Explicit None check: `catalog or TaskCatalog()` would silently
        # discard a caller-supplied catalog that happens to be falsy
        # (e.g. one that defines __len__ and is currently empty).
        self._catalog = catalog if catalog is not None else TaskCatalog()

    @property
    def catalog(self) -> "TaskCatalog":
        """The task catalog used to resolve targets."""
        return self._catalog

    # --- Lifecycle ---------------------------------------------------------

    def reset(
        self,
        target_name: Optional[str] = None,
        seed: int = 0,
        max_steps: int = 25,
        *,
        target_code: Optional[str] = None,
        target_function_name: Optional[str] = None,
        edge_cases: Optional[list] = None,
        fuzz_spec: Optional[dict] = None,
    ) -> Observation:
        """Start a new episode and return its initial observation.

        Args:
            target_name: Name of the target to resolve through the catalog.
            seed: Seed stored on the episode state; later passed to the fuzz
                input generator on submission.
            max_steps: Step budget for the episode.
            target_code: Forwarded to the catalog resolver.
            target_function_name: Forwarded to the catalog resolver.
            edge_cases: Forwarded to the catalog resolver.
            fuzz_spec: Forwarded to the catalog resolver.

        Returns:
            The observation for the freshly created episode.

        Raises:
            ValueError: If the catalog cannot resolve the requested task
                (wraps ``TaskResolutionError``).
        """
        # Backwards-compat: legacy callers pass ``target_name="fibonacci"``
        # only. The catalog handles that path identically to before.
        try:
            spec = self._catalog.resolve(
                target_name=target_name,
                target_code=target_code,
                target_function_name=target_function_name,
                edge_cases=edge_cases,
                fuzz_spec=fuzz_spec,
            )
        except TaskResolutionError as e:
            raise ValueError(str(e)) from e

        episode_id = uuid.uuid4().hex
        self._states[episode_id] = State(
            episode_id=episode_id,
            target_function_name=spec.name,
            seed=seed,
        )
        self._configs[episode_id] = {"max_steps": max_steps}
        self._episode_specs[episode_id] = spec
        return self._build_observation(episode_id, spec, last_error="")

    def _spec_for(self, state: State) -> FunctionSpec:
        """Return the spec cached for ``state``'s episode.

        Falls back to a by-name lookup in ``BLACK_BOX_FUNCTIONS`` when no
        spec was cached for the episode; that lookup raises ``KeyError`` if
        the name is not a builtin target.
        """
        spec = self._episode_specs.get(state.episode_id)
        if spec is not None:
            return spec
        # Legacy fallback: if an episode was created before we started
        # caching specs (or via a code path that bypassed reset), look up
        # by name in the builtin registry.
        return BLACK_BOX_FUNCTIONS[state.target_function_name]

    def step(self, episode_id: str, action: Action) -> StepResponse:
        """Apply ``action`` to an episode and return the step result.

        Args:
            episode_id: Identifier returned by :meth:`reset`.
            action: A ``ProbeAction`` or ``SubmitAction``. Any other type
                ends the episode with a reward of -20.0.

        Returns:
            A ``StepResponse`` with the new observation, reward, done flag
            and an ``info`` dict.

        Raises:
            KeyError: If ``episode_id`` is not a known episode.
        """
        state = self._states.get(episode_id)
        if state is None:
            raise KeyError(f"Unknown episode_id {episode_id!r}. Did you /reset first?")

        if state.done:
            # Stepping a finished episode is a no-op with zero reward.
            spec = self._spec_for(state)
            obs = self._build_observation(episode_id, spec, last_error="Episode already terminated.")
            return StepResponse(observation=obs, reward=0.0, done=True, info={"reason": "already_done"})

        spec = self._spec_for(state)
        state.steps_taken += 1
        max_steps = self._configs[episode_id]["max_steps"]

        if isinstance(action, ProbeAction):
            obs, reward, done, info = self._handle_probe(state, spec, action)
        elif isinstance(action, SubmitAction):
            obs, reward, done, info = self._handle_submit(state, spec, action)
        else:
            obs = self._build_observation(
                episode_id, spec, last_error=f"Invalid action type: {type(action).__name__}"
            )
            reward, done, info = -20.0, True, {"reason": "invalid_action"}

        # Step-budget exhaustion ends the episode with no extra reward.
        # A reason already set by the handler is kept as-is.
        if not done and state.steps_taken >= max_steps:
            done = True
            info = {**info, "reason": info.get("reason", "step_limit")}

        if done:
            state.done = True
        return StepResponse(observation=obs, reward=reward, done=done, info=info)

    # --- Action handlers ---------------------------------------------------

    @staticmethod
    def _parse_probe_input(input_repr: str) -> Tuple[Any, Optional[str]]:
        """Parse a probe's ``input_repr`` as a Python literal.

        Returns:
            ``(value, None)`` on success, or ``(None, message)`` when the
            text is not a valid literal. Callers must test the message, not
            the value, because ``None`` is itself a valid literal.
        """
        try:
            return ast.literal_eval(input_repr), None
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
            # literal_eval is documented to raise any of these on malformed
            # input (e.g. TypeError for an unhashable dict key such as
            # ``{[1]: 2}``, RecursionError for very deep nesting). All of
            # them are agent input errors, not server faults.
            return None, f"Could not parse input_repr as a Python literal: {e}"

    def _handle_probe(
        self, state: State, spec: FunctionSpec, action: ProbeAction
    ) -> Tuple[Observation, float, bool, dict]:
        """Run one probe against the target and compute its shaped reward.

        The reward is ``PROBE_STEP_COST`` plus first-visit bonuses for a new
        input bucket, a new output repr, or a new exception type. A probe
        never terminates the episode by itself.

        Returns:
            ``(observation, reward, done, info)`` with ``done`` always
            ``False``.
        """
        parsed, parse_error = self._parse_probe_input(action.input_repr)
        if parse_error is not None:
            # Unparseable input: record it, charge the step cost, no bonuses.
            state.probe_history.append(
                ProbeRecord(
                    input_repr=action.input_repr,
                    output_repr=parse_error,
                    is_error=True,
                    error_type="ParseError",
                )
            )
            obs = self._build_observation(state.episode_id, spec, last_error=parse_error)
            return obs, PROBE_STEP_COST, False, {"reason": "parse_error"}

        # Coverage bonus is decided from the input alone, before the target
        # is called.
        bucket = _bucket_of(parsed)
        bucket_bonus = 0.0
        if bucket not in state.seen_buckets:
            bucket_bonus = NEW_BUCKET_BONUS
            state.seen_buckets.add(bucket)

        intrinsic = 0.0
        last_error = ""
        try:
            if spec.unpack_args:
                if not isinstance(parsed, tuple):
                    # Raised inside the try so it is recorded like any
                    # other error produced by the probe.
                    raise TypeError(
                        f"Multi-parameter target {spec.name!r} expects a tuple "
                        f"of args, got {type(parsed).__name__}."
                    )
                output = spec.fn(*parsed)
            else:
                output = spec.fn(parsed)
            output_repr = repr(output)
        except Exception as e:  # noqa: BLE001 - the target may raise anything
            error_type = type(e).__name__
            err_repr = f"{error_type}: {e}"
            state.probe_history.append(
                ProbeRecord(
                    input_repr=repr(parsed),
                    output_repr=err_repr,
                    is_error=True,
                    error_type=error_type,
                    bucket=bucket,
                )
            )
            last_error = err_repr
            if error_type not in state.seen_error_types:
                intrinsic += NEW_ERROR_TYPE_BONUS
                state.seen_error_types.add(error_type)
        else:
            state.probe_history.append(
                ProbeRecord(
                    input_repr=repr(parsed),
                    output_repr=output_repr,
                    is_error=False,
                    bucket=bucket,
                )
            )
            if output_repr not in state.seen_outputs:
                intrinsic += NEW_OUTPUT_BONUS
                state.seen_outputs.add(output_repr)

        reward = intrinsic + bucket_bonus + PROBE_STEP_COST
        obs = self._build_observation(state.episode_id, spec, last_error=last_error)
        return obs, reward, False, {
            "intrinsic": intrinsic,
            "coverage_bonus": bucket_bonus,
            "bucket": bucket,
            "buckets_seen": len(state.seen_buckets),
        }

    def _handle_submit(
        self, state: State, spec: FunctionSpec, action: SubmitAction
    ) -> Tuple[Observation, float, bool, dict]:
        """Verify a submitted implementation and compute the terminal reward.

        The reward is the verifier's ``execution_reward`` minus its
        complexity, reward-hack and floor penalties, plus
        ``PERFECT_SUBMISSION_BONUS`` when ``execution_reward >= 99.999``.

        Returns:
            ``(observation, reward, done, info)`` with ``done`` always
            ``True``.
        """
        fuzz_inputs = generate_fuzz_inputs(spec, count=self.fuzz_count, seed=state.seed)
        edge_inputs = get_edge_inputs(spec)
        result = verify_submission(
            action.code,
            spec.fn,
            fuzz_inputs,
            target_name=spec.name,
            edge_inputs=edge_inputs,
            unpack_args=spec.unpack_args,
        )

        is_perfect = result.execution_reward >= 99.999
        total = (
            result.execution_reward
            - result.complexity_penalty
            - result.reward_hack_penalty
            - result.floor_penalty
        )
        if is_perfect:
            total += PERFECT_SUBMISSION_BONUS

        obs = self._build_observation(
            state.episode_id,
            spec,
            last_error=result.define_error or "",
        )
        info = {
            # --- Existing fields the live trainer + eval already read. ----
            "execution_reward": result.execution_reward,
            "complexity_penalty": result.complexity_penalty,
            "matches": result.matches,
            "fuzz_count": result.fuzz_count,
            "define_error": result.define_error,
            "reason": "submission",
            # --- New, additive fields. -----------------------------------
            "matches_by_category": result.matches_by_category,
            "counts_by_category": result.counts_by_category,
            "edge_pass_rate": result.edge_pass_rate,
            "reward_hack_penalty": result.reward_hack_penalty,
            "floor_penalty": result.floor_penalty,
            "perfect_bonus": PERFECT_SUBMISSION_BONUS if is_perfect else 0.0,
        }
        return obs, total, True, info

    # --- Helpers -----------------------------------------------------------

    def _build_observation(
        self, episode_id: str, spec: FunctionSpec, last_error: str
    ) -> Observation:
        """Assemble the observation for an episode from its current state.

        Only the most recent ``MAX_PROBE_HISTORY_IN_OBS`` probe records are
        included.
        """
        state = self._states[episode_id]
        max_steps = self._configs[episode_id]["max_steps"]
        history = state.probe_history[-MAX_PROBE_HISTORY_IN_OBS:]
        return Observation(
            episode_id=episode_id,
            target_function_name=state.target_function_name,
            target_function_signature=f"{spec.signature}\n\n{spec.description}",
            probe_history=history,
            last_error=last_error,
            steps_taken=state.steps_taken,
            max_steps=max_steps,
            # Specs without a ``difficulty`` attribute report None.
            difficulty=getattr(spec, "difficulty", None),
            coverage_buckets_seen=len(state.seen_buckets),
            seen_outputs_count=len(state.seen_outputs),
            seen_error_types_count=len(state.seen_error_types),
        )

    # --- Introspection -----------------------------------------------------

    def get_state(self, episode_id: str) -> dict:
        """Return a plain-dict snapshot of an episode's state.

        Returns an empty dict for an unknown ``episode_id``. Set-valued
        fields are returned as sorted lists, and the full probe history is
        included (not truncated).
        """
        s = self._states.get(episode_id)
        if s is None:
            return {}
        return {
            "episode_id": s.episode_id,
            "target_function_name": s.target_function_name,
            "steps_taken": s.steps_taken,
            "done": s.done,
            "seen_outputs": sorted(s.seen_outputs),
            "seen_error_types": sorted(s.seen_error_types),
            "seen_buckets": sorted(s.seen_buckets),
            "probe_history": [r.model_dump() for r in s.probe_history],
        }