# --- Hub commit note (was raw web-scrape residue; commented out so the file parses) ---
# anugrah55 — trainer v0.4: switch to Qwen2.5-3B-Instruct, dynamic task
# discovery, delegated probe sampling, difficulty-weighted rollouts, push to
# opensleuth-qwen2.5-3b-grpo-v2; sentinel cleared on FORCE_TRAIN=1.
# Commit 78575eb (verified).
"""Build the training dataset of (function_name, signature, probes) → prompt.
v0.4 update: tasks and probe inputs are *discovered from the live env*, not
hardcoded on the trainer side. This means a fresh task pushed to the
``anugrah55/opensleuth-tasks`` Hub dataset is picked up by the next
trainer run with zero code changes here.
Per-task probe inputs come from the env's ``/tasks/{name}/sample_inputs``
endpoint, which delegates to the same hand-written fuzzer (for the 9
builtins) or auto-fuzzer (for Hub-driven tasks) that the verifier uses.
This guarantees the in-context probes the model trains on are drawn from
the same distribution as the verifier's fuzz batch.
Difficulty-weighted sampling: harder tasks get more rollouts (longer tail
of unique seeds), since the agent needs more attempts to learn them.
Defaults: ``easy=8, medium=16, hard=24`` rollouts per task.
"""
from __future__ import annotations
import logging
import random
from typing import Iterable, List, Optional, Sequence
from datasets import Dataset
from .client import EnvClient
from .prompt import build_prompt
log = logging.getLogger("opensleuth.dataset")
# Difficulty bucket → default rollouts per task. Caller can override per call
# or via N_EASY / N_MEDIUM / N_HARD env vars in train.py.
DEFAULT_N_BY_DIFFICULTY: dict[str, int] = {"easy": 8, "medium": 16, "hard": 24}
# Tasks with no/unknown difficulty fall back to "medium".
DEFAULT_N_FALLBACK: int = 16
# ---------------------------------------------------------------------------
# Task discovery
# ---------------------------------------------------------------------------
def discover_functions(
    client: EnvClient,
    *,
    source: str = "all",
    include: Optional[Sequence[str]] = None,
    difficulty: Optional[str] = None,
) -> List[dict]:
    """Fetch the live task catalog from the env Space, with optional filters.

    Parameters:
        ``source``: ``"builtin" | "hub" | "all"`` (default ``"all"``).
        ``include``: if non-empty, keep only tasks whose ``name`` is in it.
        ``difficulty``: ``"easy" | "medium" | "hard" | "all" | None``.
            ``None`` and ``"all"`` both mean no difficulty filtering.

    Raises ``RuntimeError`` when the filters leave no tasks at all, since an
    empty catalog would silently produce an empty training dataset.
    """
    catalog = client.list_tasks(source=source)
    if difficulty and difficulty.lower() != "all":
        want = difficulty.lower()
        catalog = [task for task in catalog
                   if (task.get("difficulty") or "").lower() == want]
    # Blank / whitespace-only names in ``include`` are ignored entirely.
    wanted_names = {name.strip() for name in (include or []) if name and name.strip()}
    if wanted_names:
        catalog = [task for task in catalog if task["name"] in wanted_names]
    if not catalog:
        raise RuntimeError(
            f"discover_functions filtered down to 0 tasks "
            f"(source={source!r}, include={include!r}, difficulty={difficulty!r})."
        )
    return catalog
# Backwards-compat shim: old callers (eval/run_eval.py) imported a static
# list. Now defaults to the 9 builtins so import-time consumers don't make
# a network call. Use ``discover_functions(client)`` for the live catalog.
# NOTE(review): presumably this must stay in sync with the env's builtin
# task set (the "9 builtins" mentioned in the module docstring) — confirm
# against the env Space before adding/removing entries.
FUNCTIONS_FOR_TRAINING: List[str] = [
    "fibonacci",
    "reverse_string",
    "is_palindrome",
    "digit_sum",
    "count_vowels",
    "gcd",
    "sort_unique",
    "caesar_cipher",
    "is_prime",
]
# ---------------------------------------------------------------------------
# Probe sampling -- delegated to env's auto-fuzzer
# ---------------------------------------------------------------------------
def _make_probe_inputs(
target_name: str,
rng: random.Random,
n: int,
*,
client: Optional[EnvClient] = None,
seed: Optional[int] = None,
) -> List[str]:
"""Get ``n`` Python-literal repr strings appropriate for ``target_name``.
Preferred path: hit the env's ``/tasks/{name}/sample_inputs`` endpoint
so the trainer-side probe pool is always in lock-step with the
verifier's fuzzer. Falls back to a tiny hardcoded pool only for the
9 legacy builtins so callers without a client (e.g. unit tests) still
work.
``rng`` is consulted only for the legacy fallback path; when ``client``
is provided we forward ``seed`` (or a fresh one drawn from ``rng``) to
the env so the result is reproducible across runs.
"""
if client is not None:
if seed is None:
seed = rng.randrange(0, 2**31) if rng is not None else 0
try:
return client.sample_inputs(target_name=target_name, n=n, seed=seed)
except Exception as e: # noqa: BLE001
# Don't crash the dataset build if the env hiccups -- fall through
# to the legacy pool for builtins, or "1" * n for unknowns.
log.warning(
"env sample_inputs(%s, n=%d, seed=%s) failed: %s; falling back to legacy pool",
target_name, n, seed, e,
)
return _legacy_probe_pool(target_name, rng, n)
def _legacy_probe_pool(target_name: str, rng: random.Random, n: int) -> List[str]:
"""Hardcoded pool for the 9 builtin functions. Kept as a fallback only
so unit tests / offline callers still work; the live trainer uses
``client.sample_inputs`` exclusively."""
if target_name == "fibonacci":
pool = [1, 2, 5, 10, 20, 40, 89, -1, 0, 100]
elif target_name == "reverse_string":
pool = ['""', "'a'", "'hello'", "'racecar'", "'abc123'", "''", "'ab'"]
return [rng.choice(pool) for _ in range(n)]
elif target_name == "is_palindrome":
pool = ["'racecar'", "'hello'", "'A man a plan a canal Panama'", "''", "'ab'", "'aba'"]
return [rng.choice(pool) for _ in range(n)]
elif target_name == "digit_sum":
pool = [0, 1, 9, 10, 99, 100, 12345, -3]
elif target_name == "count_vowels":
pool = ["'hello'", "''", "'rhythm'", "'AEIOU'", "'xyz'", "'queueing'"]
return [rng.choice(pool) for _ in range(n)]
elif target_name == "gcd":
pool = ["(12, 8)", "(7, 13)", "(0, 5)", "[15, 25]", "(100, 75)", "[6, 9]"]
return [rng.choice(pool) for _ in range(n)]
elif target_name == "sort_unique":
pool = ["[3, 1, 2, 1]", "[]", "[5, 5, 5]", "[-1, 0, -1, 2]", "[10]"]
return [rng.choice(pool) for _ in range(n)]
elif target_name == "caesar_cipher":
pool = ["'hello'", "'abc'", "'xyz'", "''", "'Hello!'", "'a b c'"]
return [rng.choice(pool) for _ in range(n)]
elif target_name == "is_prime":
pool = [2, 3, 4, 7, 9, 11, 25, 29, 0, 1, -3]
else:
return ["1"] * n
return [repr(rng.choice(pool)) for _ in range(n)]
# ---------------------------------------------------------------------------
# Single-row sampler
# ---------------------------------------------------------------------------
def _sample_probes(
    client: EnvClient,
    target_name: str,
    seed: int,
    n_probes: int,
) -> tuple[str, list[tuple[str, str, bool]]]:
    """Open one episode and run ``n_probes`` fuzzer-sourced inputs through it.

    Returns the target's signature and a probe transcript of
    ``(input_repr, output_repr, is_error)`` triples; probes that raise are
    logged and skipped, so the transcript may be shorter than ``n_probes``.
    """
    rng = random.Random(seed)
    # A few spare steps beyond the probes, so the episode isn't exhausted.
    episode = client.reset(target_name=target_name, seed=seed, max_steps=n_probes + 5)
    signature = episode["target_function_signature"]
    episode_id = episode["episode_id"]
    probe_inputs = _make_probe_inputs(target_name, rng, n_probes, client=client, seed=seed)
    transcript: list[tuple[str, str, bool]] = []
    for probe_repr in probe_inputs:
        try:
            result = client.probe(episode_id, probe_repr)
        except Exception as e:  # noqa: BLE001
            log.warning("probe failed for %s with %r: %s", target_name, probe_repr, e)
            continue
        entry = result["observation"]["probe_history"][-1]
        transcript.append(
            (entry["input_repr"], entry["output_repr"], bool(entry["is_error"]))
        )
    return signature, transcript
# ---------------------------------------------------------------------------
# Dataset builder
# ---------------------------------------------------------------------------
def build_synthesis_dataset(
    client: EnvClient,
    *,
    n_per_function: Optional[int] = None,
    n_easy: int = DEFAULT_N_BY_DIFFICULTY["easy"],
    n_medium: int = DEFAULT_N_BY_DIFFICULTY["medium"],
    n_hard: int = DEFAULT_N_BY_DIFFICULTY["hard"],
    n_probes: int = 6,
    seed: int = 0,
    include: Optional[Sequence[str]] = None,
    difficulty: Optional[str] = None,
    tasks: Optional[Iterable[dict]] = None,
) -> Dataset:
    """Build a HuggingFace Dataset of {prompt, target_function_name} rows.

    ``n_per_function`` (legacy v0.3 knob) overrides the difficulty-weighted
    sampling and applies a uniform N to every task. The new default behaviour
    is to sample ``n_easy / n_medium / n_hard`` rollouts per task by
    difficulty bucket; harder tasks need more rollouts to learn. Rows whose
    episode setup fails are logged and skipped rather than aborting the build.
    """
    if tasks is None:
        tasks = discover_functions(
            client, include=include, difficulty=difficulty,
        )
    task_list = list(tasks)
    rollouts_for = {"easy": n_easy, "medium": n_medium, "hard": n_hard}
    master_rng = random.Random(seed)
    rows: list[dict] = []
    override_note = f" (override n_per_function={n_per_function})" if n_per_function else ""
    log.info("building dataset over %d task(s); per-difficulty rollouts: %s%s",
             len(task_list), rollouts_for, override_note)
    for task in task_list:
        fn_name = task["name"]
        bucket = (task.get("difficulty") or "").lower()
        if n_per_function is not None:
            n_rollouts = int(n_per_function)
        else:
            n_rollouts = rollouts_for.get(bucket, DEFAULT_N_FALLBACK)
        log.info(" %-22s difficulty=%-8s rollouts=%d source=%s",
                 fn_name, bucket or "?", n_rollouts, task.get("source", "?"))
        for _ in range(n_rollouts):
            # One fresh seed per row keeps every rollout's probes unique.
            row_seed = master_rng.randrange(0, 2**31)
            try:
                sig, probes = _sample_probes(client, fn_name, row_seed, n_probes)
            except Exception as e:  # noqa: BLE001
                log.warning("rollout build failed for %s seed=%d: %s; skipping row",
                            fn_name, row_seed, e)
                continue
            rows.append(
                {
                    "prompt": build_prompt(fn_name, sig, probes),
                    "target_function_name": fn_name,
                    "row_seed": row_seed,
                    "difficulty": bucket or "unknown",
                }
            )
    # Shuffle so a batch isn't dominated by a single task.
    master_rng.shuffle(rows)
    log.info("built dataset: %d rows total", len(rows))
    return Dataset.from_list(rows)