anugrah55 committed on
Commit
78575eb
·
verified ·
1 Parent(s): 8c92f05

trainer v0.4: switch to Qwen2.5-3B-Instruct, dynamic task discovery, delegated probe sampling, difficulty-weighted rollouts, push to opensleuth-qwen2.5-3b-grpo-v2; sentinel cleared on FORCE_TRAIN=1.

Browse files
entrypoint.sh CHANGED
@@ -37,7 +37,17 @@ sleep 2
37
  # previous container start (the Space orchestrator restarts containers
38
  # that exit cleanly), so just idle on the heartbeat to avoid wasting GPU
39
  # on duplicate runs. Set FORCE_TRAIN=1 to override.
 
 
 
 
 
 
40
  SENTINEL="/data/.opensleuth-trained"
 
 
 
 
41
  if [[ -f "$SENTINEL" && -z "${FORCE_TRAIN:-}" ]]; then
42
  log "sentinel $SENTINEL exists; skipping training (set FORCE_TRAIN=1 to retrain). Idling..."
43
  sleep infinity
 
37
  # previous container start (the Space orchestrator restarts containers
38
  # that exit cleanly), so just idle on the heartbeat to avoid wasting GPU
39
  # on duplicate runs. Set FORCE_TRAIN=1 to override.
40
+ #
41
+ # v0.4 update: when FORCE_TRAIN is set (any non-empty value, e.g. 1), we explicitly *delete* the old
42
+ # sentinel up-front. Without this the sentinel from a previous v0.2 run
43
+ # (Qwen 0.5B / 9 builtins) blocks the v0.4 run (Qwen 3B / 15 tasks) on
44
+ # Space restart. The sentinel only ever gets re-touched after a fresh
45
+ # successful training run completes below.
46
  SENTINEL="/data/.opensleuth-trained"
47
+ if [[ -n "${FORCE_TRAIN:-}" && -f "$SENTINEL" ]]; then
48
+ log "FORCE_TRAIN=1 set; removing stale sentinel $SENTINEL so we re-train."
49
+ rm -f "$SENTINEL"
50
+ fi
51
  if [[ -f "$SENTINEL" && -z "${FORCE_TRAIN:-}" ]]; then
52
  log "sentinel $SENTINEL exists; skipping training (set FORCE_TRAIN=1 to retrain). Idling..."
53
  sleep infinity
opensleuth_train/__init__.py CHANGED
@@ -1,13 +1,20 @@
1
  """OpenSleuth training-side helpers (env client, dataset, reward fn)."""
2
 
3
  from .client import EnvClient
4
- from .dataset import build_synthesis_dataset, FUNCTIONS_FOR_TRAINING
 
 
 
 
 
5
  from .prompt import SYSTEM_PROMPT, build_prompt, extract_code
6
 
7
  __all__ = [
8
  "EnvClient",
9
- "build_synthesis_dataset",
10
  "FUNCTIONS_FOR_TRAINING",
 
 
11
  "SYSTEM_PROMPT",
12
  "build_prompt",
13
  "extract_code",
 
1
  """OpenSleuth training-side helpers (env client, dataset, reward fn)."""
2
 
3
  from .client import EnvClient
4
+ from .dataset import (
5
+ DEFAULT_N_BY_DIFFICULTY,
6
+ FUNCTIONS_FOR_TRAINING,
7
+ build_synthesis_dataset,
8
+ discover_functions,
9
+ )
10
  from .prompt import SYSTEM_PROMPT, build_prompt, extract_code
11
 
12
  __all__ = [
13
  "EnvClient",
14
+ "DEFAULT_N_BY_DIFFICULTY",
15
  "FUNCTIONS_FOR_TRAINING",
16
+ "build_synthesis_dataset",
17
+ "discover_functions",
18
  "SYSTEM_PROMPT",
19
  "build_prompt",
20
  "extract_code",
opensleuth_train/client.py CHANGED
@@ -5,7 +5,7 @@ from __future__ import annotations
5
  import logging
6
  import os
7
  import time
8
- from typing import Any, Dict
9
 
10
  import requests
11
 
@@ -32,16 +32,57 @@ class EnvClient:
32
  time.sleep(wait)
33
  raise RuntimeError(f"env POST {path} failed after {self.retries} retries: {last_exc}")
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  def health(self) -> Dict[str, Any]:
36
  r = requests.get(f"{self.base_url}/health", timeout=self.timeout)
37
  r.raise_for_status()
38
  return r.json()
39
 
40
  def list_functions(self) -> list[Dict[str, str]]:
 
41
  r = requests.get(f"{self.base_url}/functions", timeout=self.timeout)
42
  r.raise_for_status()
43
  return r.json()["functions"]
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  def reset(self, target_name: str, seed: int = 0, max_steps: int = 25) -> Dict[str, Any]:
46
  return self._post("/reset", {"target_name": target_name, "seed": seed, "max_steps": max_steps})
47
 
 
5
  import logging
6
  import os
7
  import time
8
+ from typing import Any, Dict, List, Optional
9
 
10
  import requests
11
 
 
32
  time.sleep(wait)
33
  raise RuntimeError(f"env POST {path} failed after {self.retries} retries: {last_exc}")
34
 
35
+ def _get(self, path: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
36
+ last_exc: Exception | None = None
37
+ for attempt in range(self.retries):
38
+ try:
39
+ r = requests.get(f"{self.base_url}{path}", params=params, timeout=self.timeout)
40
+ r.raise_for_status()
41
+ return r.json()
42
+ except (requests.RequestException, ValueError) as e: # noqa: PERF203
43
+ last_exc = e
44
+ wait = 0.5 * (2 ** attempt)
45
+ log.warning("env GET %s failed (%s); retrying in %.1fs", path, e, wait)
46
+ time.sleep(wait)
47
+ raise RuntimeError(f"env GET {path} failed after {self.retries} retries: {last_exc}")
48
+
49
  def health(self) -> Dict[str, Any]:
50
  r = requests.get(f"{self.base_url}/health", timeout=self.timeout)
51
  r.raise_for_status()
52
  return r.json()
53
 
54
  def list_functions(self) -> list[Dict[str, str]]:
55
+ """Legacy v0.3 endpoint -- only the 9 builtin functions."""
56
  r = requests.get(f"{self.base_url}/functions", timeout=self.timeout)
57
  r.raise_for_status()
58
  return r.json()["functions"]
59
 
60
+ def list_tasks(
61
+ self,
62
+ source: str = "all",
63
+ difficulty: Optional[str] = None,
64
+ ) -> List[Dict[str, Any]]:
65
+ """v0.4 catalog endpoint -- builtins + Hub-driven tasks.
66
+
67
+ Each item carries: ``name``, ``signature``, ``description``,
68
+ ``difficulty`` (``easy|medium|hard|None``), ``edge_case_count``,
69
+ ``source`` (``builtin|hub``).
70
+ """
71
+ params: Dict[str, Any] = {"source": source}
72
+ if difficulty:
73
+ params["difficulty"] = difficulty
74
+ return self._get("/tasks", params=params)["tasks"]
75
+
76
+ def sample_inputs(self, target_name: str, n: int = 8, seed: int = 0) -> List[str]:
77
+ """Pull ``n`` ready-to-probe input_repr strings from the env's own
78
+ auto-fuzzer. Encapsulates the fuzz logic on the env side so the
79
+ trainer doesn't have to keep its own per-task input pools in sync."""
80
+ resp = self._get(
81
+ f"/tasks/{target_name}/sample_inputs",
82
+ params={"n": n, "seed": seed},
83
+ )
84
+ return list(resp["inputs"])
85
+
86
  def reset(self, target_name: str, seed: int = 0, max_steps: int = 25) -> Dict[str, Any]:
87
  return self._post("/reset", {"target_name": target_name, "seed": seed, "max_steps": max_steps})
88
 
opensleuth_train/dataset.py CHANGED
@@ -1,16 +1,26 @@
1
  """Build the training dataset of (function_name, signature, probes) → prompt.
2
 
3
- We pre-sample probes server-side with deterministic seeds so the LLM trains
4
- on a consistent set of in-context examples per task. The actual *reward* is
5
- computed by re-submitting the model's code against the env with a fresh fuzz
6
- seed, so the model can't memorise probe outputs.
 
 
 
 
 
 
 
 
 
 
7
  """
8
 
9
  from __future__ import annotations
10
 
11
  import logging
12
  import random
13
- from typing import List
14
 
15
  from datasets import Dataset
16
 
@@ -19,6 +29,51 @@ from .prompt import build_prompt
19
 
20
  log = logging.getLogger("opensleuth.dataset")
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  FUNCTIONS_FOR_TRAINING: List[str] = [
23
  "fibonacci",
24
  "reverse_string",
@@ -32,31 +87,50 @@ FUNCTIONS_FOR_TRAINING: List[str] = [
32
  ]
33
 
34
 
35
- def _sample_probes(client: EnvClient, target_name: str, seed: int, n_probes: int) -> tuple[str, list[tuple[str, str, bool]]]:
36
- """Open an episode and feed it `n_probes` random valid inputs sourced from
37
- the env's own fuzz generator (we just hit /functions and synthesise inputs
38
- locally to avoid coupling to a specific spec API)."""
39
- rng = random.Random(seed)
40
- ep = client.reset(target_name=target_name, seed=seed, max_steps=n_probes + 5)
41
- sig = ep["target_function_signature"]
42
- eid = ep["episode_id"]
43
 
44
- inputs = _make_probe_inputs(target_name, rng, n_probes)
45
- history: list[tuple[str, str, bool]] = []
46
- for inp_repr in inputs:
47
- resp = client.probe(eid, inp_repr)
48
- last = resp["observation"]["probe_history"][-1]
49
- history.append((last["input_repr"], last["output_repr"], bool(last["is_error"])))
50
- return sig, history
51
 
 
 
 
 
 
 
 
 
 
52
 
53
- def _make_probe_inputs(target_name: str, rng: random.Random, n: int) -> list[str]:
54
- """Generate `n` Python-literal repr strings appropriate for this function.
 
 
 
55
 
56
- Kept in lock-step (loosely) with the env's fuzz generators so probes
57
- almost always land on the function's valid domain, with a few intentional
58
- out-of-domain inputs to expose error-handling.
59
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  if target_name == "fibonacci":
61
  pool = [1, 2, 5, 10, 20, 40, 89, -1, 0, 100]
62
  elif target_name == "reverse_string":
@@ -86,27 +160,101 @@ def _make_probe_inputs(target_name: str, rng: random.Random, n: int) -> list[str
86
  return [repr(rng.choice(pool)) for _ in range(n)]
87
 
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  def build_synthesis_dataset(
90
  client: EnvClient,
91
  *,
92
- n_per_function: int = 24,
 
 
 
93
  n_probes: int = 6,
94
  seed: int = 0,
 
 
 
95
  ) -> Dataset:
96
- """Build a HuggingFace Dataset of {prompt, target_function_name} rows."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  rows = []
98
  rng = random.Random(seed)
99
- for fn_name in FUNCTIONS_FOR_TRAINING:
100
- for k in range(n_per_function):
 
 
 
 
 
 
 
 
 
 
 
101
  row_seed = rng.randrange(0, 2**31)
102
- sig, probes = _sample_probes(client, fn_name, row_seed, n_probes)
 
 
 
 
 
103
  prompt = build_prompt(fn_name, sig, probes)
104
  rows.append(
105
  {
106
  "prompt": prompt,
107
  "target_function_name": fn_name,
108
  "row_seed": row_seed,
 
109
  }
110
  )
111
  rng.shuffle(rows)
 
112
  return Dataset.from_list(rows)
 
1
  """Build the training dataset of (function_name, signature, probes) → prompt.
2
 
3
+ v0.4 update: tasks and probe inputs are *discovered from the live env*, not
4
+ hardcoded on the trainer side. This means a fresh task pushed to the
5
+ ``anugrah55/opensleuth-tasks`` Hub dataset is picked up by the next
6
+ trainer run with zero code changes here.
7
+
8
+ Per-task probe inputs come from the env's ``/tasks/{name}/sample_inputs``
9
+ endpoint, which delegates to the same hand-written fuzzer (for the 9
10
+ builtins) or auto-fuzzer (for Hub-driven tasks) that the verifier uses.
11
+ This guarantees the in-context probes the model trains on are drawn from
12
+ the same distribution as the verifier's fuzz batch.
13
+
14
+ Difficulty-weighted sampling: harder tasks get more rollouts (longer tail
15
+ of unique seeds), since the agent needs more attempts to learn them.
16
+ Defaults: ``easy=8, medium=16, hard=24`` rollouts per task.
17
  """
18
 
19
  from __future__ import annotations
20
 
21
  import logging
22
  import random
23
+ from typing import Iterable, List, Optional, Sequence
24
 
25
  from datasets import Dataset
26
 
 
29
 
30
  log = logging.getLogger("opensleuth.dataset")
31
 
32
+ # Difficulty bucket → default rollouts per task. Caller can override per call
33
+ # or via N_EASY / N_MEDIUM / N_HARD env vars in train.py.
34
+ DEFAULT_N_BY_DIFFICULTY = {"easy": 8, "medium": 16, "hard": 24}
35
+ # Tasks with no/unknown difficulty fall back to "medium".
36
+ DEFAULT_N_FALLBACK = 16
37
+
38
+
39
+ # ---------------------------------------------------------------------------
40
+ # Task discovery
41
+ # ---------------------------------------------------------------------------
42
+
43
+
44
+ def discover_functions(
45
+ client: EnvClient,
46
+ *,
47
+ source: str = "all",
48
+ include: Optional[Sequence[str]] = None,
49
+ difficulty: Optional[str] = None,
50
+ ) -> List[dict]:
51
+ """Return the live task catalog from the env Space, optionally filtered.
52
+
53
+ Parameters:
54
+ ``source``: ``"builtin" | "hub" | "all"`` (default ``"all"``).
55
+ ``include``: if non-empty, keep only tasks whose ``name`` is in it.
56
+ ``difficulty``: ``"easy" | "medium" | "hard" | "all" | None``.
57
+ ``None`` and ``"all"`` mean no filtering.
58
+ """
59
+ tasks = client.list_tasks(source=source)
60
+ if difficulty and difficulty.lower() != "all":
61
+ tasks = [t for t in tasks if (t.get("difficulty") or "").lower() == difficulty.lower()]
62
+ if include:
63
+ wanted = {n.strip() for n in include if n and n.strip()}
64
+ if wanted:
65
+ tasks = [t for t in tasks if t["name"] in wanted]
66
+ if not tasks:
67
+ raise RuntimeError(
68
+ f"discover_functions filtered down to 0 tasks "
69
+ f"(source={source!r}, include={include!r}, difficulty={difficulty!r})."
70
+ )
71
+ return tasks
72
+
73
+
74
+ # Backwards-compat shim: old callers (eval/run_eval.py) imported a static
75
+ # list. Now defaults to the 9 builtins so import-time consumers don't make
76
+ # a network call. Use ``discover_functions(client)`` for the live catalog.
77
  FUNCTIONS_FOR_TRAINING: List[str] = [
78
  "fibonacci",
79
  "reverse_string",
 
87
  ]
88
 
89
 
90
+ # ---------------------------------------------------------------------------
91
+ # Probe sampling -- delegated to env's auto-fuzzer
92
+ # ---------------------------------------------------------------------------
 
 
 
 
 
93
 
 
 
 
 
 
 
 
94
 
95
+ def _make_probe_inputs(
96
+ target_name: str,
97
+ rng: random.Random,
98
+ n: int,
99
+ *,
100
+ client: Optional[EnvClient] = None,
101
+ seed: Optional[int] = None,
102
+ ) -> List[str]:
103
+ """Get ``n`` Python-literal repr strings appropriate for ``target_name``.
104
 
105
+ Preferred path: hit the env's ``/tasks/{name}/sample_inputs`` endpoint
106
+ so the trainer-side probe pool is always in lock-step with the
107
+ verifier's fuzzer. Falls back to a tiny hardcoded pool only for the
108
+ 9 legacy builtins so callers without a client (e.g. unit tests) still
109
+ work.
110
 
111
+ ``rng`` is consulted only for the legacy fallback path; when ``client``
112
+ is provided we forward ``seed`` (or a fresh one drawn from ``rng``) to
113
+ the env so the result is reproducible across runs.
114
  """
115
+ if client is not None:
116
+ if seed is None:
117
+ seed = rng.randrange(0, 2**31) if rng is not None else 0
118
+ try:
119
+ return client.sample_inputs(target_name=target_name, n=n, seed=seed)
120
+ except Exception as e: # noqa: BLE001
121
+ # Don't crash the dataset build if the env hiccups -- fall through
122
+ # to the legacy pool for builtins, or "1" * n for unknowns.
123
+ log.warning(
124
+ "env sample_inputs(%s, n=%d, seed=%s) failed: %s; falling back to legacy pool",
125
+ target_name, n, seed, e,
126
+ )
127
+ return _legacy_probe_pool(target_name, rng, n)
128
+
129
+
130
+ def _legacy_probe_pool(target_name: str, rng: random.Random, n: int) -> List[str]:
131
+ """Hardcoded pool for the 9 builtin functions. Kept as a fallback only
132
+ so unit tests / offline callers still work; the live trainer uses
133
+ ``client.sample_inputs`` exclusively."""
134
  if target_name == "fibonacci":
135
  pool = [1, 2, 5, 10, 20, 40, 89, -1, 0, 100]
136
  elif target_name == "reverse_string":
 
160
  return [repr(rng.choice(pool)) for _ in range(n)]
161
 
162
 
163
+ # ---------------------------------------------------------------------------
164
+ # Single-row sampler
165
+ # ---------------------------------------------------------------------------
166
+
167
+
168
+ def _sample_probes(
169
+ client: EnvClient,
170
+ target_name: str,
171
+ seed: int,
172
+ n_probes: int,
173
+ ) -> tuple[str, list[tuple[str, str, bool]]]:
174
+ """Open an episode and feed it ``n_probes`` random valid inputs sourced
175
+ from the env's own auto-fuzzer."""
176
+ rng = random.Random(seed)
177
+ ep = client.reset(target_name=target_name, seed=seed, max_steps=n_probes + 5)
178
+ sig = ep["target_function_signature"]
179
+ eid = ep["episode_id"]
180
+
181
+ inputs = _make_probe_inputs(target_name, rng, n_probes, client=client, seed=seed)
182
+ history: list[tuple[str, str, bool]] = []
183
+ for inp_repr in inputs:
184
+ try:
185
+ resp = client.probe(eid, inp_repr)
186
+ except Exception as e: # noqa: BLE001
187
+ log.warning("probe failed for %s with %r: %s", target_name, inp_repr, e)
188
+ continue
189
+ last = resp["observation"]["probe_history"][-1]
190
+ history.append((last["input_repr"], last["output_repr"], bool(last["is_error"])))
191
+ return sig, history
192
+
193
+
194
+ # ---------------------------------------------------------------------------
195
+ # Dataset builder
196
+ # ---------------------------------------------------------------------------
197
+
198
+
199
  def build_synthesis_dataset(
200
  client: EnvClient,
201
  *,
202
+ n_per_function: Optional[int] = None,
203
+ n_easy: int = DEFAULT_N_BY_DIFFICULTY["easy"],
204
+ n_medium: int = DEFAULT_N_BY_DIFFICULTY["medium"],
205
+ n_hard: int = DEFAULT_N_BY_DIFFICULTY["hard"],
206
  n_probes: int = 6,
207
  seed: int = 0,
208
+ include: Optional[Sequence[str]] = None,
209
+ difficulty: Optional[str] = None,
210
+ tasks: Optional[Iterable[dict]] = None,
211
  ) -> Dataset:
212
+ """Build a HuggingFace Dataset of {prompt, target_function_name} rows.
213
+
214
+ ``n_per_function`` (legacy v0.3 knob) overrides the difficulty-weighted
215
+ sampling and applies a uniform N to every task. The new default behaviour
216
+ is to sample ``n_easy / n_medium / n_hard`` rollouts per task by
217
+ difficulty bucket; harder tasks need more rollouts to learn.
218
+ """
219
+ if tasks is None:
220
+ tasks = discover_functions(
221
+ client, include=include, difficulty=difficulty,
222
+ )
223
+ tasks = list(tasks)
224
+
225
+ by_diff = {"easy": n_easy, "medium": n_medium, "hard": n_hard}
226
+
227
  rows = []
228
  rng = random.Random(seed)
229
+ log.info("building dataset over %d task(s); per-difficulty rollouts: %s%s",
230
+ len(tasks), by_diff,
231
+ f" (override n_per_function={n_per_function})" if n_per_function else "")
232
+ for task in tasks:
233
+ fn_name = task["name"]
234
+ diff = (task.get("difficulty") or "").lower()
235
+ if n_per_function is not None:
236
+ n_rollouts = int(n_per_function)
237
+ else:
238
+ n_rollouts = by_diff.get(diff, DEFAULT_N_FALLBACK)
239
+ log.info(" %-22s difficulty=%-8s rollouts=%d source=%s",
240
+ fn_name, diff or "?", n_rollouts, task.get("source", "?"))
241
+ for _ in range(n_rollouts):
242
  row_seed = rng.randrange(0, 2**31)
243
+ try:
244
+ sig, probes = _sample_probes(client, fn_name, row_seed, n_probes)
245
+ except Exception as e: # noqa: BLE001
246
+ log.warning("rollout build failed for %s seed=%d: %s; skipping row",
247
+ fn_name, row_seed, e)
248
+ continue
249
  prompt = build_prompt(fn_name, sig, probes)
250
  rows.append(
251
  {
252
  "prompt": prompt,
253
  "target_function_name": fn_name,
254
  "row_seed": row_seed,
255
+ "difficulty": diff or "unknown",
256
  }
257
  )
258
  rng.shuffle(rows)
259
+ log.info("built dataset: %d rows total", len(rows))
260
  return Dataset.from_list(rows)
train.py CHANGED
@@ -27,6 +27,7 @@ from opensleuth_train import (
27
  EnvClient,
28
  SYSTEM_PROMPT,
29
  build_synthesis_dataset,
 
30
  )
31
  from opensleuth_train.reward import format_reward, make_env_reward
32
 
@@ -39,23 +40,58 @@ logging.basicConfig(
39
  log = logging.getLogger("opensleuth.train")
40
 
41
 
 
 
 
 
42
  def parse_args() -> argparse.Namespace:
43
  p = argparse.ArgumentParser()
44
  p.add_argument("--env-url", default=os.environ.get("ENV_URL", "https://anugrah55-opensleuth-env-gemini-cli.hf.space"))
45
- p.add_argument("--model-name", default=os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-0.5B-Instruct"))
 
 
 
46
  p.add_argument("--output-dir", default=os.environ.get("OUTPUT_DIR", "/data/opensleuth-grpo"))
47
- p.add_argument("--push-to-hub", default=os.environ.get("PUSH_TO_HUB", "anugrah55/opensleuth-qwen2.5-0.5b-grpo"))
48
- p.add_argument("--n-per-function", type=int, default=int(os.environ.get("N_PER_FUNCTION", "16")))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  p.add_argument("--n-probes", type=int, default=int(os.environ.get("N_PROBES", "6")))
50
- p.add_argument("--num-generations", type=int, default=int(os.environ.get("NUM_GENERATIONS", "4")))
51
- p.add_argument("--max-completion-length", type=int, default=int(os.environ.get("MAX_COMPLETION_LENGTH", "320")))
52
- p.add_argument("--max-prompt-length", type=int, default=int(os.environ.get("MAX_PROMPT_LENGTH", "768")))
 
53
  p.add_argument("--learning-rate", type=float, default=float(os.environ.get("LEARNING_RATE", "1e-5")))
54
  p.add_argument("--num-train-epochs", type=float, default=float(os.environ.get("NUM_TRAIN_EPOCHS", "1")))
55
  # GRPO requires per_device_train_batch_size to be a multiple of num_generations
56
  # (one prompt is repeated num_generations times, all in the same forward pass).
57
- # Default to 1 prompt × num_generations completions per device step.
58
- p.add_argument("--per-device-batch-size", type=int, default=int(os.environ.get("PER_DEVICE_BATCH_SIZE", "0")))
59
  p.add_argument("--gradient-accumulation-steps", type=int, default=int(os.environ.get("GRAD_ACCUM", "4")))
60
  p.add_argument("--no-4bit", action="store_true", default=os.environ.get("NO_4BIT", "0") == "1")
61
  p.add_argument("--seed", type=int, default=int(os.environ.get("SEED", "42")))
@@ -83,12 +119,34 @@ def main() -> int:
83
 
84
  client = EnvClient(base_url=args.env_url, timeout=60.0, retries=4)
85
  wait_for_env(client)
86
- fns = client.list_functions()
87
- log.info("env exposes %d functions: %s", len(fns), [f["name"] for f in fns])
88
 
89
- log.info("building synthesis dataset (n_per_function=%d, n_probes=%d)", args.n_per_function, args.n_probes)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  dataset = build_synthesis_dataset(
91
- client, n_per_function=args.n_per_function, n_probes=args.n_probes, seed=args.seed
 
 
 
 
 
 
 
92
  )
93
  log.info("dataset size: %d rows", len(dataset))
94
 
@@ -105,7 +163,10 @@ def main() -> int:
105
  "row_seed": row["row_seed"],
106
  }
107
 
108
- dataset = dataset.map(to_chat, remove_columns=["prompt"])
 
 
 
109
 
110
  # ---- Model + LoRA ----
111
  log.info("loading model %s (4bit=%s)", args.model_name, not args.no_4bit)
 
27
  EnvClient,
28
  SYSTEM_PROMPT,
29
  build_synthesis_dataset,
30
+ discover_functions,
31
  )
32
  from opensleuth_train.reward import format_reward, make_env_reward
33
 
 
40
  log = logging.getLogger("opensleuth.train")
41
 
42
 
43
+ def _split_csv(s: str) -> list[str]:
44
+ return [x.strip() for x in s.split(",") if x.strip()]
45
+
46
+
47
  def parse_args() -> argparse.Namespace:
48
  p = argparse.ArgumentParser()
49
  p.add_argument("--env-url", default=os.environ.get("ENV_URL", "https://anugrah55-opensleuth-env-gemini-cli.hf.space"))
50
+ # v0.4 default: switch to Qwen2.5-3B-Instruct for the open-ended task pool.
51
+ # The 0.5B baseline saturated easy tasks but couldn't solve the hard /
52
+ # Hub-driven ones. 3B + LoRA + 4-bit fits T4-small (16GB).
53
+ p.add_argument("--model-name", default=os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-3B-Instruct"))
54
  p.add_argument("--output-dir", default=os.environ.get("OUTPUT_DIR", "/data/opensleuth-grpo"))
55
+ p.add_argument(
56
+ "--push-to-hub",
57
+ default=os.environ.get(
58
+ "PUSH_TO_HUB", "anugrah55/opensleuth-qwen2.5-3b-grpo-v2"
59
+ ),
60
+ )
61
+ # Task selection / curriculum knobs (v0.4).
62
+ p.add_argument(
63
+ "--functions",
64
+ default=os.environ.get("FUNCTIONS_INCLUDE", ""),
65
+ help="Comma-separated subset of task names to train on. Empty = all "
66
+ "tasks the env exposes (builtin + Hub).",
67
+ )
68
+ p.add_argument(
69
+ "--difficulty",
70
+ default=os.environ.get("DIFFICULTY_FILTER", "all"),
71
+ choices=["easy", "medium", "hard", "all"],
72
+ help="Curriculum filter: only sample tasks at this difficulty level.",
73
+ )
74
+ # Difficulty-weighted rollout counts. Replaces the v0.3 single
75
+ # n-per-function knob (kept as an optional override).
76
+ p.add_argument("--n-easy", type=int, default=int(os.environ.get("N_EASY", "8")))
77
+ p.add_argument("--n-medium", type=int, default=int(os.environ.get("N_MEDIUM", "16")))
78
+ p.add_argument("--n-hard", type=int, default=int(os.environ.get("N_HARD", "24")))
79
+ p.add_argument(
80
+ "--n-per-function",
81
+ type=int,
82
+ default=int(os.environ.get("N_PER_FUNCTION", "0")),
83
+ help="If >0, overrides per-difficulty rollout counts with a uniform N.",
84
+ )
85
  p.add_argument("--n-probes", type=int, default=int(os.environ.get("N_PROBES", "6")))
86
+ # GRPO/optim defaults sized for T4-small (16GB) + Qwen2.5-3B-4bit + LoRA.
87
+ p.add_argument("--num-generations", type=int, default=int(os.environ.get("NUM_GENERATIONS", "2")))
88
+ p.add_argument("--max-completion-length", type=int, default=int(os.environ.get("MAX_COMPLETION_LENGTH", "384")))
89
+ p.add_argument("--max-prompt-length", type=int, default=int(os.environ.get("MAX_PROMPT_LENGTH", "1024")))
90
  p.add_argument("--learning-rate", type=float, default=float(os.environ.get("LEARNING_RATE", "1e-5")))
91
  p.add_argument("--num-train-epochs", type=float, default=float(os.environ.get("NUM_TRAIN_EPOCHS", "1")))
92
  # GRPO requires per_device_train_batch_size to be a multiple of num_generations
93
  # (one prompt is repeated num_generations times, all in the same forward pass).
94
+ p.add_argument("--per-device-batch-size", type=int, default=int(os.environ.get("PER_DEVICE_BATCH_SIZE", "2")))
 
95
  p.add_argument("--gradient-accumulation-steps", type=int, default=int(os.environ.get("GRAD_ACCUM", "4")))
96
  p.add_argument("--no-4bit", action="store_true", default=os.environ.get("NO_4BIT", "0") == "1")
97
  p.add_argument("--seed", type=int, default=int(os.environ.get("SEED", "42")))
 
119
 
120
  client = EnvClient(base_url=args.env_url, timeout=60.0, retries=4)
121
  wait_for_env(client)
 
 
122
 
123
+ include = _split_csv(args.functions) if args.functions else None
124
+ difficulty = None if args.difficulty == "all" else args.difficulty
125
+ tasks = discover_functions(client, include=include, difficulty=difficulty)
126
+ log.info(
127
+ "env catalog: %d task(s) after filter (include=%s, difficulty=%s):",
128
+ len(tasks), include, difficulty,
129
+ )
130
+ for t in tasks:
131
+ log.info(
132
+ " - %-22s difficulty=%-8s source=%s",
133
+ t["name"], t.get("difficulty"), t.get("source"),
134
+ )
135
+
136
+ n_per_function_override = args.n_per_function if args.n_per_function > 0 else None
137
+ log.info(
138
+ "building synthesis dataset (n_easy=%d n_medium=%d n_hard=%d override=%s n_probes=%d)",
139
+ args.n_easy, args.n_medium, args.n_hard, n_per_function_override, args.n_probes,
140
+ )
141
  dataset = build_synthesis_dataset(
142
+ client,
143
+ tasks=tasks,
144
+ n_per_function=n_per_function_override,
145
+ n_easy=args.n_easy,
146
+ n_medium=args.n_medium,
147
+ n_hard=args.n_hard,
148
+ n_probes=args.n_probes,
149
+ seed=args.seed,
150
  )
151
  log.info("dataset size: %d rows", len(dataset))
152
 
 
163
  "row_seed": row["row_seed"],
164
  }
165
 
166
+ # Drop the human-readable difficulty column from the GRPO-visible map so
167
+ # the trainer doesn't try to forward it as a reward-fn kwarg.
168
+ drop_cols = [c for c in ("prompt", "difficulty") if c in dataset.column_names]
169
+ dataset = dataset.map(to_chat, remove_columns=drop_cols)
170
 
171
  # ---- Model + LoRA ----
172
  log.info("loading model %s (4bit=%s)", args.model_name, not args.no_4bit)