"""Build the training dataset of (function_name, signature, probes) → prompt. v0.4 update: tasks and probe inputs are *discovered from the live env*, not hardcoded on the trainer side. This means a fresh task pushed to the ``anugrah55/opensleuth-tasks`` Hub dataset is picked up by the next trainer run with zero code changes here. Per-task probe inputs come from the env's ``/tasks/{name}/sample_inputs`` endpoint, which delegates to the same hand-written fuzzer (for the 9 builtins) or auto-fuzzer (for Hub-driven tasks) that the verifier uses. This guarantees the in-context probes the model trains on are drawn from the same distribution as the verifier's fuzz batch. Difficulty-weighted sampling: harder tasks get more rollouts (longer tail of unique seeds), since the agent needs more attempts to learn them. Defaults: ``easy=8, medium=16, hard=24`` rollouts per task. """ from __future__ import annotations import logging import random from typing import Iterable, List, Optional, Sequence from datasets import Dataset from .client import EnvClient from .prompt import build_prompt log = logging.getLogger("opensleuth.dataset") # Difficulty bucket → default rollouts per task. Caller can override per call # or via N_EASY / N_MEDIUM / N_HARD env vars in train.py. DEFAULT_N_BY_DIFFICULTY = {"easy": 8, "medium": 16, "hard": 24} # Tasks with no/unknown difficulty fall back to "medium". DEFAULT_N_FALLBACK = 16 # --------------------------------------------------------------------------- # Task discovery # --------------------------------------------------------------------------- def discover_functions( client: EnvClient, *, source: str = "all", include: Optional[Sequence[str]] = None, difficulty: Optional[str] = None, ) -> List[dict]: """Return the live task catalog from the env Space, optionally filtered. Parameters: ``source``: ``"builtin" | "hub" | "all"`` (default ``"all"``). ``include``: if non-empty, keep only tasks whose ``name`` is in it. ``difficulty``: ``"easy" | "medium" | "hard" | "all" | None``. ``None`` and ``"all"`` mean no filtering. """ tasks = client.list_tasks(source=source) if difficulty and difficulty.lower() != "all": tasks = [t for t in tasks if (t.get("difficulty") or "").lower() == difficulty.lower()] if include: wanted = {n.strip() for n in include if n and n.strip()} if wanted: tasks = [t for t in tasks if t["name"] in wanted] if not tasks: raise RuntimeError( f"discover_functions filtered down to 0 tasks " f"(source={source!r}, include={include!r}, difficulty={difficulty!r})." ) return tasks # Backwards-compat shim: old callers (eval/run_eval.py) imported a static # list. Now defaults to the 9 builtins so import-time consumers don't make # a network call. Use ``discover_functions(client)`` for the live catalog. FUNCTIONS_FOR_TRAINING: List[str] = [ "fibonacci", "reverse_string", "is_palindrome", "digit_sum", "count_vowels", "gcd", "sort_unique", "caesar_cipher", "is_prime", ] # --------------------------------------------------------------------------- # Probe sampling -- delegated to env's auto-fuzzer # --------------------------------------------------------------------------- def _make_probe_inputs( target_name: str, rng: random.Random, n: int, *, client: Optional[EnvClient] = None, seed: Optional[int] = None, ) -> List[str]: """Get ``n`` Python-literal repr strings appropriate for ``target_name``. Preferred path: hit the env's ``/tasks/{name}/sample_inputs`` endpoint so the trainer-side probe pool is always in lock-step with the verifier's fuzzer. 
# ---------------------------------------------------------------------------
# Probe sampling -- delegated to env's auto-fuzzer
# ---------------------------------------------------------------------------


def _make_probe_inputs(
    target_name: str,
    rng: random.Random,
    n: int,
    *,
    client: Optional[EnvClient] = None,
    seed: Optional[int] = None,
) -> List[str]:
    """Get ``n`` Python-literal repr strings appropriate for ``target_name``.

    Preferred path: hit the env's ``/tasks/{name}/sample_inputs`` endpoint so
    the trainer-side probe pool is always in lock-step with the verifier's
    fuzzer. Falls back to a tiny hardcoded pool only for the 9 legacy
    builtins, so callers without a client (e.g. unit tests) still work.

    ``rng`` is consulted only for the legacy fallback path; when ``client``
    is provided, we forward ``seed`` (or a fresh one drawn from ``rng``) to
    the env so the result is reproducible across runs.
    """
    if client is not None:
        if seed is None:
            seed = rng.randrange(0, 2**31) if rng is not None else 0
        try:
            return client.sample_inputs(target_name=target_name, n=n, seed=seed)
        except Exception as e:  # noqa: BLE001
            # Don't crash the dataset build if the env hiccups -- fall through
            # to the legacy pool for builtins, or "1" * n for unknowns.
            log.warning(
                "env sample_inputs(%s, n=%d, seed=%s) failed: %s; "
                "falling back to legacy pool",
                target_name, n, seed, e,
            )
    return _legacy_probe_pool(target_name, rng, n)


def _legacy_probe_pool(target_name: str, rng: random.Random, n: int) -> List[str]:
    """Hardcoded pool for the 9 builtin functions.

    Kept as a fallback only so unit tests / offline callers still work; the
    live trainer uses ``client.sample_inputs`` exclusively.

    Two pool conventions: numeric pools hold plain ints and are ``repr()``'d
    by the shared return at the bottom; the other pools already hold
    Python-literal repr strings and return directly.
    """
    if target_name == "fibonacci":
        pool = [1, 2, 5, 10, 20, 40, 89, -1, 0, 100]
    elif target_name == "reverse_string":
        pool = ['""', "'a'", "'hello'", "'racecar'", "'abc123'", "''", "'ab'"]
        return [rng.choice(pool) for _ in range(n)]
    elif target_name == "is_palindrome":
        pool = ["'racecar'", "'hello'", "'A man a plan a canal Panama'", "''", "'ab'", "'aba'"]
        return [rng.choice(pool) for _ in range(n)]
    elif target_name == "digit_sum":
        pool = [0, 1, 9, 10, 99, 100, 12345, -3]
    elif target_name == "count_vowels":
        pool = ["'hello'", "''", "'rhythm'", "'AEIOU'", "'xyz'", "'queueing'"]
        return [rng.choice(pool) for _ in range(n)]
    elif target_name == "gcd":
        pool = ["(12, 8)", "(7, 13)", "(0, 5)", "[15, 25]", "(100, 75)", "[6, 9]"]
        return [rng.choice(pool) for _ in range(n)]
    elif target_name == "sort_unique":
        pool = ["[3, 1, 2, 1]", "[]", "[5, 5, 5]", "[-1, 0, -1, 2]", "[10]"]
        return [rng.choice(pool) for _ in range(n)]
    elif target_name == "caesar_cipher":
        pool = ["'hello'", "'abc'", "'xyz'", "''", "'Hello!'", "'a b c'"]
        return [rng.choice(pool) for _ in range(n)]
    elif target_name == "is_prime":
        pool = [2, 3, 4, 7, 9, 11, 25, 29, 0, 1, -3]
    else:
        # Unknown task and no env to ask: emit a degenerate but valid literal.
        return ["1"] * n
    return [repr(rng.choice(pool)) for _ in range(n)]


# ---------------------------------------------------------------------------
# Single-row sampler
# ---------------------------------------------------------------------------


def _sample_probes(
    client: EnvClient,
    target_name: str,
    seed: int,
    n_probes: int,
) -> tuple[str, list[tuple[str, str, bool]]]:
    """Open an episode and feed it ``n_probes`` random valid inputs sourced
    from the env's own auto-fuzzer."""
    rng = random.Random(seed)
    ep = client.reset(target_name=target_name, seed=seed, max_steps=n_probes + 5)
    sig = ep["target_function_signature"]
    eid = ep["episode_id"]
    inputs = _make_probe_inputs(target_name, rng, n_probes, client=client, seed=seed)
    history: list[tuple[str, str, bool]] = []
    for inp_repr in inputs:
        try:
            resp = client.probe(eid, inp_repr)
        except Exception as e:  # noqa: BLE001
            log.warning("probe failed for %s with %r: %s", target_name, inp_repr, e)
            continue
        last = resp["observation"]["probe_history"][-1]
        history.append((last["input_repr"], last["output_repr"], bool(last["is_error"])))
    return sig, history
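# Shape sketch for a ``_sample_probes`` result (illustrative only -- the
# signature string and reprs below are invented, not captured from a run):
#
#     sig, history = _sample_probes(client, "gcd", seed=0, n_probes=3)
#     # sig     -> "gcd(a, b)"                     (whatever the env reports)
#     # history -> [("(12, 8)", "4", False), ...]  (input_repr, output_repr, is_error)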
# ---------------------------------------------------------------------------
# Dataset builder
# ---------------------------------------------------------------------------


def build_synthesis_dataset(
    client: EnvClient,
    *,
    n_per_function: Optional[int] = None,
    n_easy: int = DEFAULT_N_BY_DIFFICULTY["easy"],
    n_medium: int = DEFAULT_N_BY_DIFFICULTY["medium"],
    n_hard: int = DEFAULT_N_BY_DIFFICULTY["hard"],
    n_probes: int = 6,
    seed: int = 0,
    include: Optional[Sequence[str]] = None,
    difficulty: Optional[str] = None,
    tasks: Optional[Iterable[dict]] = None,
) -> Dataset:
    """Build a HuggingFace Dataset of {prompt, target_function_name} rows.

    ``n_per_function`` (legacy v0.3 knob) overrides the difficulty-weighted
    sampling and applies a uniform N to every task. The new default
    behaviour is to sample ``n_easy / n_medium / n_hard`` rollouts per task
    by difficulty bucket; harder tasks need more rollouts to learn.
    """
    if tasks is None:
        tasks = discover_functions(client, include=include, difficulty=difficulty)
    tasks = list(tasks)

    by_diff = {"easy": n_easy, "medium": n_medium, "hard": n_hard}
    rows = []
    rng = random.Random(seed)

    log.info(
        "building dataset over %d task(s); per-difficulty rollouts: %s%s",
        len(tasks),
        by_diff,
        f" (override n_per_function={n_per_function})" if n_per_function else "",
    )
    for task in tasks:
        fn_name = task["name"]
        diff = (task.get("difficulty") or "").lower()
        if n_per_function is not None:
            n_rollouts = int(n_per_function)
        else:
            n_rollouts = by_diff.get(diff, DEFAULT_N_FALLBACK)
        log.info(
            "  %-22s difficulty=%-8s rollouts=%d source=%s",
            fn_name, diff or "?", n_rollouts, task.get("source", "?"),
        )
        for _ in range(n_rollouts):
            row_seed = rng.randrange(0, 2**31)
            try:
                sig, probes = _sample_probes(client, fn_name, row_seed, n_probes)
            except Exception as e:  # noqa: BLE001
                log.warning(
                    "rollout build failed for %s seed=%d: %s; skipping row",
                    fn_name, row_seed, e,
                )
                continue
            prompt = build_prompt(fn_name, sig, probes)
            rows.append(
                {
                    "prompt": prompt,
                    "target_function_name": fn_name,
                    "row_seed": row_seed,
                    "difficulty": diff or "unknown",
                }
            )

    rng.shuffle(rows)
    log.info("built dataset: %d rows total", len(rows))
    return Dataset.from_list(rows)
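if __name__ == "__main__":
    # Smoke-test sketch: build a small dataset against a live env and print
    # its stats. Assumptions: the EnvClient constructor takes a base URL
    # (check client.py for the real signature), and OPENSLEUTH_ENV_URL is a
    # hypothetical variable naming the Space.
    import os

    logging.basicConfig(level=logging.INFO)
    client = EnvClient(os.environ["OPENSLEUTH_ENV_URL"])  # assumed constructor
    ds = build_synthesis_dataset(client, n_probes=4, seed=0, difficulty="easy")
    print(ds)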