| """Build the training dataset of (function_name, signature, probes) → prompt. |
| |
| v0.4 update: tasks and probe inputs are *discovered from the live env*, not |
| hardcoded on the trainer side. This means a fresh task pushed to the |
| ``anugrah55/opensleuth-tasks`` Hub dataset is picked up by the next |
| trainer run with zero code changes here. |
| |
| Per-task probe inputs come from the env's ``/tasks/{name}/sample_inputs`` |
| endpoint, which delegates to the same hand-written fuzzer (for the 9 |
| builtins) or auto-fuzzer (for Hub-driven tasks) that the verifier uses. |
| This guarantees the in-context probes the model trains on are drawn from |
| the same distribution as the verifier's fuzz batch. |
| |
| Difficulty-weighted sampling: harder tasks get more rollouts (longer tail |
| of unique seeds), since the agent needs more attempts to learn them. |
| Defaults: ``easy=8, medium=16, hard=24`` rollouts per task. |
| """ |

from __future__ import annotations

import logging
import random
from typing import Iterable, List, Optional, Sequence

from datasets import Dataset

from .client import EnvClient
from .prompt import build_prompt

log = logging.getLogger("opensleuth.dataset")

DEFAULT_N_BY_DIFFICULTY = {"easy": 8, "medium": 16, "hard": 24}
DEFAULT_N_FALLBACK = 16
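# Worked example of the defaults (hypothetical catalog, for illustration):
# a catalog of 4 easy, 3 medium, and 2 hard tasks builds
# 4*8 + 3*16 + 2*24 = 128 prompt rows; tasks reporting a missing or unknown
# difficulty fall back to DEFAULT_N_FALLBACK rollouts each.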


def discover_functions(
    client: EnvClient,
    *,
    source: str = "all",
    include: Optional[Sequence[str]] = None,
    difficulty: Optional[str] = None,
) -> List[dict]:
    """Return the live task catalog from the env Space, optionally filtered.

    Parameters:
        ``source``: ``"builtin" | "hub" | "all"`` (default ``"all"``).
        ``include``: if non-empty, keep only tasks whose ``name`` is in it.
        ``difficulty``: ``"easy" | "medium" | "hard" | "all" | None``;
            ``None`` and ``"all"`` mean no filtering.
    """
    tasks = client.list_tasks(source=source)
    if difficulty and difficulty.lower() != "all":
        tasks = [t for t in tasks if (t.get("difficulty") or "").lower() == difficulty.lower()]
    if include:
        wanted = {n.strip() for n in include if n and n.strip()}
        if wanted:
            tasks = [t for t in tasks if t["name"] in wanted]
    if not tasks:
        raise RuntimeError(
            f"discover_functions filtered down to 0 tasks "
            f"(source={source!r}, include={include!r}, difficulty={difficulty!r})."
        )
    return tasks
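

# A sketch of typical trainer-side usage (the EnvClient construction shown
# here is an assumption; build the client however your deployment does):
#
#     client = EnvClient(base_url="https://your-env.hf.space")
#     hard_tasks = discover_functions(client, source="hub", difficulty="hard")
#     names = [t["name"] for t in hard_tasks]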


# The 9 legacy builtin tasks, mirrored by ``_legacy_probe_pool`` below;
# live runs discover the full catalog via ``discover_functions`` instead.
FUNCTIONS_FOR_TRAINING: List[str] = [
    "fibonacci",
    "reverse_string",
    "is_palindrome",
    "digit_sum",
    "count_vowels",
    "gcd",
    "sort_unique",
    "caesar_cipher",
    "is_prime",
]


def _make_probe_inputs(
    target_name: str,
    rng: random.Random,
    n: int,
    *,
    client: Optional[EnvClient] = None,
    seed: Optional[int] = None,
) -> List[str]:
    """Get ``n`` Python-literal repr strings appropriate for ``target_name``.

    Preferred path: hit the env's ``/tasks/{name}/sample_inputs`` endpoint
    so the trainer-side probe pool is always in lock-step with the
    verifier's fuzzer. Falls back to a tiny hardcoded pool only for the
    9 legacy builtins so callers without a client (e.g. unit tests) still
    work.

    ``rng`` is consulted only for the legacy fallback path; when ``client``
    is provided we forward ``seed`` (or a fresh one drawn from ``rng``) to
    the env so the result is reproducible across runs.
    """
    if client is not None:
        if seed is None:
            seed = rng.randrange(0, 2**31) if rng is not None else 0
        try:
            return client.sample_inputs(target_name=target_name, n=n, seed=seed)
        except Exception as e:
            log.warning(
                "env sample_inputs(%s, n=%d, seed=%s) failed: %s; falling back to legacy pool",
                target_name, n, seed, e,
            )
    return _legacy_probe_pool(target_name, rng, n)
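
# For gcd, for instance, the reprs produced above look like "(12, 8)" or
# "[6, 9]" (illustrative; real values come from the env fuzzer, or from the
# legacy pool below when no client is available).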


def _legacy_probe_pool(target_name: str, rng: random.Random, n: int) -> List[str]:
    """Hardcoded pool for the 9 builtin functions.

    Kept as a fallback only so unit tests / offline callers still work; the
    live trainer uses ``client.sample_inputs`` exclusively. Every pool entry
    is already a Python-literal repr string.
    """
    if target_name == "fibonacci":
        pool = [repr(x) for x in (1, 2, 5, 10, 20, 40, 89, -1, 0, 100)]
    elif target_name == "reverse_string":
        pool = ['""', "'a'", "'hello'", "'racecar'", "'abc123'", "''", "'ab'"]
    elif target_name == "is_palindrome":
        pool = ["'racecar'", "'hello'", "'A man a plan a canal Panama'", "''", "'ab'", "'aba'"]
    elif target_name == "digit_sum":
        pool = [repr(x) for x in (0, 1, 9, 10, 99, 100, 12345, -3)]
    elif target_name == "count_vowels":
        pool = ["'hello'", "''", "'rhythm'", "'AEIOU'", "'xyz'", "'queueing'"]
    elif target_name == "gcd":
        pool = ["(12, 8)", "(7, 13)", "(0, 5)", "[15, 25]", "(100, 75)", "[6, 9]"]
    elif target_name == "sort_unique":
        pool = ["[3, 1, 2, 1]", "[]", "[5, 5, 5]", "[-1, 0, -1, 2]", "[10]"]
    elif target_name == "caesar_cipher":
        pool = ["'hello'", "'abc'", "'xyz'", "''", "'Hello!'", "'a b c'"]
    elif target_name == "is_prime":
        pool = [repr(x) for x in (2, 3, 4, 7, 9, 11, 25, 29, 0, 1, -3)]
    else:
        # Unknown task and no client to ask: emit a harmless constant probe.
        return ["1"] * n
    return [rng.choice(pool) for _ in range(n)]
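
# Every pool entry above is a Python-literal repr, so a consumer could
# round-trip the values with the stdlib, e.g. (sketch; ``ast`` import assumed):
#
#     [ast.literal_eval(p) for p in _legacy_probe_pool("gcd", random.Random(0), 4)]
#
# which yields tuples and/or lists of ints for the gcd pool.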


def _sample_probes(
    client: EnvClient,
    target_name: str,
    seed: int,
    n_probes: int,
) -> tuple[str, list[tuple[str, str, bool]]]:
    """Open an episode and feed it ``n_probes`` random valid inputs sourced
    from the env's own auto-fuzzer."""
    rng = random.Random(seed)
    ep = client.reset(target_name=target_name, seed=seed, max_steps=n_probes + 5)
    sig = ep["target_function_signature"]
    eid = ep["episode_id"]

    inputs = _make_probe_inputs(target_name, rng, n_probes, client=client, seed=seed)
    history: list[tuple[str, str, bool]] = []
    for inp_repr in inputs:
        try:
            resp = client.probe(eid, inp_repr)
        except Exception as e:
            log.warning("probe failed for %s with %r: %s", target_name, inp_repr, e)
            continue
        last = resp["observation"]["probe_history"][-1]
        history.append((last["input_repr"], last["output_repr"], bool(last["is_error"])))
    return sig, history
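
# Each history entry is an (input_repr, output_repr, is_error) triple; a
# reverse_string probe might be recorded as ("'hello'", "'olleh'", False)
# (illustrative; the actual reprs are whatever the env's probe endpoint reports).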


def build_synthesis_dataset(
    client: EnvClient,
    *,
    n_per_function: Optional[int] = None,
    n_easy: int = DEFAULT_N_BY_DIFFICULTY["easy"],
    n_medium: int = DEFAULT_N_BY_DIFFICULTY["medium"],
    n_hard: int = DEFAULT_N_BY_DIFFICULTY["hard"],
    n_probes: int = 6,
    seed: int = 0,
    include: Optional[Sequence[str]] = None,
    difficulty: Optional[str] = None,
    tasks: Optional[Iterable[dict]] = None,
) -> Dataset:
    """Build a HuggingFace Dataset of {prompt, target_function_name} rows.

    ``n_per_function`` (legacy v0.3 knob) overrides the difficulty-weighted
    sampling and applies a uniform N to every task. The new default behaviour
    is to sample ``n_easy / n_medium / n_hard`` rollouts per task by
    difficulty bucket; harder tasks need more rollouts to learn.
    """
    if tasks is None:
        tasks = discover_functions(client, include=include, difficulty=difficulty)
    tasks = list(tasks)

    by_diff = {"easy": n_easy, "medium": n_medium, "hard": n_hard}

    rows = []
    rng = random.Random(seed)
    log.info("building dataset over %d task(s); per-difficulty rollouts: %s%s",
             len(tasks), by_diff,
             f" (override n_per_function={n_per_function})" if n_per_function else "")
    for task in tasks:
        fn_name = task["name"]
        diff = (task.get("difficulty") or "").lower()
        if n_per_function is not None:
            n_rollouts = int(n_per_function)
        else:
            n_rollouts = by_diff.get(diff, DEFAULT_N_FALLBACK)
        log.info("  %-22s difficulty=%-8s rollouts=%d source=%s",
                 fn_name, diff or "?", n_rollouts, task.get("source", "?"))
        for _ in range(n_rollouts):
            row_seed = rng.randrange(0, 2**31)
            try:
                sig, probes = _sample_probes(client, fn_name, row_seed, n_probes)
            except Exception as e:
                log.warning("rollout build failed for %s seed=%d: %s; skipping row",
                            fn_name, row_seed, e)
                continue
            prompt = build_prompt(fn_name, sig, probes)
            rows.append(
                {
                    "prompt": prompt,
                    "target_function_name": fn_name,
                    "row_seed": row_seed,
                    "difficulty": diff or "unknown",
                }
            )
    rng.shuffle(rows)
    log.info("built dataset: %d rows total", len(rows))
    return Dataset.from_list(rows)
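

if __name__ == "__main__":
    # Minimal smoke-test sketch for local runs. The EnvClient construction
    # below is an assumption (substitute your deployment's real base URL and
    # constructor); everything else uses only the helpers defined above.
    import argparse

    logging.basicConfig(level=logging.INFO)
    ap = argparse.ArgumentParser(description="Build the synthesis dataset against a live env")
    ap.add_argument("--base-url", required=True, help="base URL of the env Space")
    ap.add_argument("--difficulty", default=None, help="easy | medium | hard | all")
    args = ap.parse_args()

    client = EnvClient(args.base_url)  # hypothetical constructor signature
    ds = build_synthesis_dataset(client, difficulty=args.difficulty)
    print(ds)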