anugrah55 committed on
Commit d597642 · verified · 1 Parent(s): d8aa978

Overhaul trainer: TRL GRPO with env-backed reward, Qwen2.5-0.5B 4bit+LoRA, slim PyTorch CUDA base, heartbeat HTTP for HF Spaces health probe

Dockerfile CHANGED
@@ -1,13 +1,33 @@
- # Use a base image with CUDA and Python
- FROM huggingface/transformers-pytorch-gpu:latest
+ # Slim, well-tested CUDA + PyTorch base. Avoids HF transformers-pytorch-gpu's
+ # bloat and unsloth's CUDA-version sensitivity.
+ FROM pytorch/pytorch:2.5.1-cuda12.1-cudnn9-runtime
+
+ ENV PYTHONUNBUFFERED=1 \
+     PIP_NO_CACHE_DIR=1 \
+     PIP_DISABLE_PIP_VERSION_CHECK=1 \
+     HF_HOME=/data/.cache/huggingface \
+     TRANSFORMERS_CACHE=/data/.cache/huggingface/hub \
+     TOKENIZERS_PARALLELISM=false

- # Copy all the files from the repo to the container
- COPY . /app
WORKDIR /app

- # Install dependencies
- RUN pip install -r requirements.txt
- RUN pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git" trl peft bitsandbytes requests
+ # System deps for bitsandbytes / build
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     git curl build-essential \
+     && rm -rf /var/lib/apt/lists/*
+
+ COPY requirements.txt /app/
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Project code
+ COPY opensleuth_train /app/opensleuth_train
+ COPY train.py /app/
+ COPY entrypoint.sh /app/
+ RUN chmod +x /app/entrypoint.sh \
+     && mkdir -p /data/opensleuth-grpo /data/.cache/huggingface
+
+ # HF Spaces health probe expects the container to expose a port; keep it open
+ # so the orchestrator considers us alive while training runs.
+ EXPOSE 7860

- # Run the training script using python3
- CMD ["python3", "hf_train_runner.py"]
+ CMD ["/app/entrypoint.sh"]
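+
+ # Build/run sketch (image tag and env URL are placeholders):
+ #   docker build -t opensleuth-trainer .
+ #   docker run --gpus all -p 7860:7860 -e ENV_URL=https://<env-space>.hf.space opensleuth-trainer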
 
README.md CHANGED
@@ -1,10 +1,61 @@
---
- title: Opensleuth Training Gemini Cli
- emoji: 🌖
- colorFrom: purple
- colorTo: green
+ title: OpenSleuth Trainer
+ emoji: 🛰️
+ colorFrom: red
+ colorTo: yellow
sdk: docker
+ app_port: 7860
pinned: false
+ suggested_hardware: t4-small
---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # OpenSleuth Trainer
+
+ GPU Space that fine-tunes a small Qwen2.5 model with TRL **GRPO** to do
+ in-context program synthesis against the live OpenSleuth env.
+
+ ## Pipeline
+
+ 1. Wait for the env Space to report healthy.
+ 2. Build a dataset of synthesis prompts: each row pairs one black-box
+    function with N pre-sampled `(input, output)` probes drawn from the env
+    (see the sketch after this list).
+ 3. Load `Qwen/Qwen2.5-0.5B-Instruct` in 4-bit + LoRA via `bitsandbytes` and
+    `peft`.
+ 4. Train with `trl.GRPOTrainer`, generating `num_generations=4` candidate
+    completions per prompt and rewarding each against the env's verifier.
+ 5. Persist the LoRA adapter to `/data/opensleuth-grpo` and (if `HF_TOKEN` is
+    set as a Space secret) push to `anugrah55/opensleuth-qwen2.5-0.5b-grpo`.
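+
+ A sketch of the row step 2 produces, with invented signature and probe
+ values (the real ones come from the env):
+
+ ```python
+ row = {
+     "target_function_name": "digit_sum",
+     "row_seed": 1234567,  # replayed against the env at reward time
+     "prompt": (
+         "## Hidden function: digit_sum\n\n"
+         "### Public signature & docstring\n"
+         "def digit_sum(n: int) -> int\n\n"
+         "### Observed probes\n"
+         "- input=99 -> returns 18\n"
+         "- input=-3 -> raises ValueError(...)\n\n"
+         "### Task\n"
+         "Write a Python function named `digit_sum` ..."
+     ),
+ }
+ ```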
+
+ ## Reward
+
+ Two reward functions are summed per completion (see the arithmetic sketch
+ below):
+
+ * `env_verifier_reward = env.score_submission(...) / 100` — the headline
+   shaped reward, ranging roughly `[-0.5, +1.5]`.
+ * `format_reward` — small bonus for emitting a fenced ```python``` block
+   whose `def` matches the target function name; helps the model converge on
+   parseable output early.
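+
+ Back-of-envelope, with invented numbers:
+
+ ```python
+ raw = 150.0            # env verifier: execution reward minus complexity penalty
+ env_r = raw / 100      # -> 1.5, the headline reward
+ fmt_r = 0.1 + 0.1      # fenced-block bonus + correct def-name bonus
+ total = env_r + fmt_r  # -> 1.7, the scalar GRPO sees for this completion
+ ```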
+
+ ## Hardware
+
+ `t4-small` is sufficient for 0.5B + LoRA + bnb-4bit. `a10g-small` will train
+ faster if available.
+
+ ## Required Space secrets
+
+ * `HF_TOKEN` — write token if you want the LoRA adapter pushed to the Hub at
+   the end of training.
+
+ ## Tuning knobs
+
+ All knobs are exposed as env vars (defaults shown):
+
+ | Env var | Default | Meaning |
+ |---------|---------|---------|
+ | `ENV_URL` | `https://anugrah55-opensleuth-env-gemini-cli.hf.space` | OpenSleuth env to target |
+ | `MODEL_NAME` | `Qwen/Qwen2.5-0.5B-Instruct` | Base policy |
+ | `N_PER_FUNCTION` | `16` | Prompts per black-box function |
+ | `N_PROBES` | `6` | Probes per prompt |
+ | `NUM_GENERATIONS` | `4` | GRPO group size |
+ | `LEARNING_RATE` | `1e-5` | GRPO learning rate |
+ | `NUM_TRAIN_EPOCHS` | `1` | Passes over the dataset |
+ | `PER_DEVICE_BATCH_SIZE` | `1` | Prompts per device per step |
+ | `GRAD_ACCUM` | `8` | Gradient accumulation steps |
__pycache__/train.cpython-313.pyc ADDED
Binary file (11.7 kB).
 
entrypoint.sh ADDED
@@ -0,0 +1,42 @@
+ #!/usr/bin/env bash
+ # OpenSleuth training Space entrypoint.
+ #
+ # Starts a tiny background HTTP server on $PORT (default 7860) so the HF
+ # Spaces health probe is satisfied, then runs the actual training script in
+ # the foreground. All training logs go to stdout and are visible in the
+ # Space's "Container logs" tab.
+ set -euo pipefail
+
+ PORT="${PORT:-7860}"
+
+ log() { echo "[entrypoint $(date -u +%H:%M:%S)] $*"; }
+
+ # 1. Background heartbeat HTTP server. Just returns 200 OK on every request.
+ log "starting heartbeat server on :${PORT}"
+ python -c "
+ import http.server, socketserver, os, threading, time
+ class H(http.server.BaseHTTPRequestHandler):
+     def do_GET(self):
+         self.send_response(200)
+         self.send_header('Content-Type','text/plain')
+         self.end_headers()
+         self.wfile.write(b'opensleuth-trainer alive\n')
+     def log_message(self, *a, **kw): pass
+ port = int(os.environ.get('PORT','7860'))
+ srv = socketserver.TCPServer(('0.0.0.0', port), H)
+ threading.Thread(target=srv.serve_forever, daemon=True).start()
+ print(f'[heartbeat] listening on :{port}', flush=True)
+ while True: time.sleep(3600)
+ " &
+ HB_PID=$!
+
+ # Give the heartbeat a moment to bind before the orchestrator probes it.
+ sleep 2
+
+ # 2. Run training in the foreground. Crash here = container exits, which is
+ # what we want: HF will mark the Space failed and surface the error.
+ log "starting training (PID $$)"
+ log "GPU info:"
+ python -c "import torch; print('cuda available:', torch.cuda.is_available()); print('device:', torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'cpu')"
+
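+ # While train.py runs, the heartbeat keeps :${PORT} answering; for example
+ # (illustrative): curl http://localhost:7860/  ->  "opensleuth-trainer alive"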
+ exec python /app/train.py
opensleuth_train/__init__.py ADDED
@@ -0,0 +1,14 @@
+ """OpenSleuth training-side helpers (env client, dataset, reward fn)."""
+
+ from .client import EnvClient
+ from .dataset import build_synthesis_dataset, FUNCTIONS_FOR_TRAINING
+ from .prompt import SYSTEM_PROMPT, build_prompt, extract_code
+
+ __all__ = [
+     "EnvClient",
+     "build_synthesis_dataset",
+     "FUNCTIONS_FOR_TRAINING",
+     "SYSTEM_PROMPT",
+     "build_prompt",
+     "extract_code",
+ ]
opensleuth_train/client.py ADDED
@@ -0,0 +1,63 @@
+ """Thin HTTP client for the OpenSleuth env Space."""
+
+ from __future__ import annotations
+
+ import logging
+ import os
+ import time
+ from typing import Any, Dict
+
+ import requests
+
+ log = logging.getLogger("opensleuth.client")
+
+
+ class EnvClient:
+     def __init__(self, base_url: str | None = None, timeout: float = 30.0, retries: int = 3):
+         self.base_url = (base_url or os.environ.get("ENV_URL", "http://127.0.0.1:7860")).rstrip("/")
+         self.timeout = timeout
+         self.retries = retries
+
+     def _post(self, path: str, payload: Dict[str, Any]) -> Dict[str, Any]:
+         last_exc: Exception | None = None
+         for attempt in range(self.retries):
+             try:
+                 r = requests.post(f"{self.base_url}{path}", json=payload, timeout=self.timeout)
+                 r.raise_for_status()
+                 return r.json()
+             except (requests.RequestException, ValueError) as e:  # noqa: PERF203
+                 last_exc = e
+                 wait = 0.5 * (2 ** attempt)
+                 log.warning("env POST %s failed (%s); retrying in %.1fs", path, e, wait)
+                 time.sleep(wait)
+         raise RuntimeError(f"env POST {path} failed after {self.retries} retries: {last_exc}")
+
+     def health(self) -> Dict[str, Any]:
+         r = requests.get(f"{self.base_url}/health", timeout=self.timeout)
+         r.raise_for_status()
+         return r.json()
+
+     def list_functions(self) -> list[Dict[str, str]]:
+         r = requests.get(f"{self.base_url}/functions", timeout=self.timeout)
+         r.raise_for_status()
+         return r.json()["functions"]
+
+     def reset(self, target_name: str, seed: int = 0, max_steps: int = 25) -> Dict[str, Any]:
+         return self._post("/reset", {"target_name": target_name, "seed": seed, "max_steps": max_steps})
+
+     def step(self, episode_id: str, action: Dict[str, Any]) -> Dict[str, Any]:
+         return self._post("/step", {"episode_id": episode_id, "action": action})
+
+     # --- High-level helpers used by the reward function --------------------
+
+     def submit(self, episode_id: str, code: str) -> Dict[str, Any]:
+         return self.step(episode_id, {"action_type": "submit", "code": code})
+
+     def probe(self, episode_id: str, input_repr: str) -> Dict[str, Any]:
+         return self.step(episode_id, {"action_type": "probe", "input_repr": input_repr})
+
+     def score_submission(self, target_name: str, code: str, seed: int = 0) -> float:
+         """One-shot: open an episode, submit the code, return total reward."""
+         ep = self.reset(target_name=target_name, seed=seed, max_steps=2)
+         resp = self.submit(ep["episode_id"], code)
+         return float(resp["reward"])
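+
+ # Usage sketch (hypothetical URL; this is what the reward function does):
+ #   client = EnvClient("https://example-env.hf.space")
+ #   client.score_submission("fibonacci", "def fibonacci(n): ...", seed=7)  # -> e.g. 150.0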
opensleuth_train/dataset.py ADDED
@@ -0,0 +1,112 @@
+ """Build the training dataset of (function_name, signature, probes) → prompt.
+
+ We pre-sample probes server-side with deterministic seeds so the LLM trains
+ on a consistent set of in-context examples per task. The actual *reward* is
+ computed by re-submitting the model's code against the env with a fresh fuzz
+ seed, so the model can't memorise probe outputs.
+ """
+
+ from __future__ import annotations
+
+ import logging
+ import random
+ from typing import List
+
+ from datasets import Dataset
+
+ from .client import EnvClient
+ from .prompt import build_prompt
+
+ log = logging.getLogger("opensleuth.dataset")
+
+ FUNCTIONS_FOR_TRAINING: List[str] = [
+     "fibonacci",
+     "reverse_string",
+     "is_palindrome",
+     "digit_sum",
+     "count_vowels",
+     "gcd",
+     "sort_unique",
+     "caesar_cipher",
+     "is_prime",
+ ]
+
+
+ def _sample_probes(client: EnvClient, target_name: str, seed: int, n_probes: int) -> tuple[str, list[tuple[str, str, bool]]]:
+     """Open an episode and feed it `n_probes` probe inputs. Inputs are
+     synthesised locally, loosely mirroring the env's fuzz generators, to
+     avoid coupling to a specific spec API."""
+     rng = random.Random(seed)
+     ep = client.reset(target_name=target_name, seed=seed, max_steps=n_probes + 5)
+     sig = ep["target_function_signature"]
+     eid = ep["episode_id"]
+
+     inputs = _make_probe_inputs(target_name, rng, n_probes)
+     history: list[tuple[str, str, bool]] = []
+     for inp_repr in inputs:
+         resp = client.probe(eid, inp_repr)
+         last = resp["observation"]["probe_history"][-1]
+         history.append((last["input_repr"], last["output_repr"], bool(last["is_error"])))
+     return sig, history
+
+
+ def _make_probe_inputs(target_name: str, rng: random.Random, n: int) -> list[str]:
+     """Generate `n` Python-literal repr strings appropriate for this function.
+
+     Kept in lock-step (loosely) with the env's fuzz generators so probes
+     almost always land on the function's valid domain, with a few intentional
+     out-of-domain inputs to expose error-handling.
+     """
+     if target_name == "fibonacci":
+         pool = [1, 2, 5, 10, 20, 40, 89, -1, 0, 100]
+     elif target_name == "reverse_string":
+         pool = ['""', "'a'", "'hello'", "'racecar'", "'abc123'", "''", "'ab'"]
+         return [rng.choice(pool) for _ in range(n)]
+     elif target_name == "is_palindrome":
+         pool = ["'racecar'", "'hello'", "'A man a plan a canal Panama'", "''", "'ab'", "'aba'"]
+         return [rng.choice(pool) for _ in range(n)]
+     elif target_name == "digit_sum":
+         pool = [0, 1, 9, 10, 99, 100, 12345, -3]
+     elif target_name == "count_vowels":
+         pool = ["'hello'", "''", "'rhythm'", "'AEIOU'", "'xyz'", "'queueing'"]
+         return [rng.choice(pool) for _ in range(n)]
+     elif target_name == "gcd":
+         pool = ["(12, 8)", "(7, 13)", "(0, 5)", "[15, 25]", "(100, 75)", "[6, 9]"]
+         return [rng.choice(pool) for _ in range(n)]
+     elif target_name == "sort_unique":
+         pool = ["[3, 1, 2, 1]", "[]", "[5, 5, 5]", "[-1, 0, -1, 2]", "[10]"]
+         return [rng.choice(pool) for _ in range(n)]
+     elif target_name == "caesar_cipher":
+         pool = ["'hello'", "'abc'", "'xyz'", "''", "'Hello!'", "'a b c'"]
+         return [rng.choice(pool) for _ in range(n)]
+     elif target_name == "is_prime":
+         pool = [2, 3, 4, 7, 9, 11, 25, 29, 0, 1, -3]
+     else:
+         return ["1"] * n
+     # Numeric pools (fibonacci, digit_sum, is_prime) fall through to here and
+     # get repr()'d into literal strings; the string pools above are already reprs.
+     return [repr(rng.choice(pool)) for _ in range(n)]
+
+
+ def build_synthesis_dataset(
+     client: EnvClient,
+     *,
+     n_per_function: int = 24,
+     n_probes: int = 6,
+     seed: int = 0,
+ ) -> Dataset:
+     """Build a HuggingFace Dataset of {prompt, target_function_name} rows."""
+     rows = []
+     rng = random.Random(seed)
+     for fn_name in FUNCTIONS_FOR_TRAINING:
+         for _ in range(n_per_function):
+             row_seed = rng.randrange(0, 2**31)
+             sig, probes = _sample_probes(client, fn_name, row_seed, n_probes)
+             prompt = build_prompt(fn_name, sig, probes)
+             rows.append(
+                 {
+                     "prompt": prompt,
+                     "target_function_name": fn_name,
+                     "row_seed": row_seed,
+                 }
+             )
+     rng.shuffle(rows)
+     return Dataset.from_list(rows)
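+
+ # Sizing sketch (with train.py's defaults): 9 functions x 16 rows each = 144 rows.
+ #   ds = build_synthesis_dataset(EnvClient(), n_per_function=16, n_probes=6, seed=42)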
opensleuth_train/prompt.py ADDED
@@ -0,0 +1,60 @@
+ """Prompt construction + code extraction for the OpenSleuth agent."""
+
+ from __future__ import annotations
+
+ import re
+ from typing import Iterable
+
+ SYSTEM_PROMPT = (
+     "You are an algorithmic detective. You are given the public signature of a hidden "
+     "Python function plus several (input, output) examples observed by probing it. "
+     "Your job is to write a Python function that *exactly* reproduces the hidden "
+     "function's behavior on all valid inputs. Match its return values AND its "
+     "exception types on invalid inputs. Keep your implementation as simple and clean "
+     "as possible (it is penalised for being needlessly branchy). Return ONLY the "
+     "function definition wrapped in a single ```python ... ``` code block."
+ )
+
+
+ def build_prompt(target_name: str, signature: str, probes: Iterable[tuple[str, str, bool]]) -> str:
+     """Build the user-side prompt.
+
+     `probes` is an iterable of `(input_repr, output_repr, is_error)` tuples,
+     typically pre-sampled by the dataset builder.
+     """
+     lines = [
+         f"## Hidden function: {target_name}",
+         "",
+         "### Public signature & docstring",
+         signature.strip() or "(no signature provided)",
+         "",
+         "### Observed probes",
+     ]
+     probe_list = list(probes)
+     if not probe_list:
+         lines.append("(none)")
+     else:
+         for inp, out, is_err in probe_list:
+             tag = "raises" if is_err else "returns"
+             lines.append(f"- input={inp} -> {tag} {out}")
+     lines += [
+         "",
+         "### Task",
+         f"Write a Python function named `{target_name}` that reproduces the hidden "
+         "function's behaviour. Return ONLY the function definition in a single "
+         "```python ... ``` code block. Do not add explanations.",
+     ]
+     return "\n".join(lines)
+
+
+ _CODE_RE = re.compile(r"```(?:python)?\s*(.*?)```", re.DOTALL | re.IGNORECASE)
+
+
+ def extract_code(completion: str) -> str:
+     """Pull the python source from a model completion. If no fenced block is
+     present we fall back to the whole completion (the verifier will then judge
+     it on its own)."""
+     m = _CODE_RE.search(completion)
+     if m:
+         return m.group(1).strip()
+     return completion.strip()
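+
+ # Example (illustrative): a completion of
+ #   "Here you go:\n```python\ndef gcd(a, b): ...\n```"
+ # extracts to "def gcd(a, b): ..."; a completion with no fence passes
+ # through stripped.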
opensleuth_train/reward.py ADDED
@@ -0,0 +1,91 @@
+ """Reward functions for GRPO: env-backed verification + cheap shaping signals.
+
+ GRPO takes a list of `reward_funcs`. Each must accept `completions` and any
+ columns from the dataset as kwargs, and return one scalar per completion.
+ """
+
+ from __future__ import annotations
+
+ import logging
+ import re
+ from typing import Callable, List
+
+ from .client import EnvClient
+ from .prompt import extract_code
+
+ log = logging.getLogger("opensleuth.reward")
+
+
+ def make_env_reward(client: EnvClient, *, scale: float = 1.0 / 100.0) -> Callable:
+     """Verifier-backed reward. Calls the env's `submit` and returns the env's
+     reward multiplied by `scale` (default 1/100, so a perfect submission is
+     ~+1.5 and a bad one around -0.5; this keeps GRPO advantages well behaved
+     without needing reward normalisation).
+     """
+
+     def env_reward(completions, target_function_name=None, row_seed=None, **kwargs):  # noqa: ANN001
+         rewards: List[float] = []
+         # GRPO calls the reward fn once per batch; both target_function_name
+         # and row_seed come in as lists of length len(completions).
+         for i, completion in enumerate(completions):
+             text = _extract_text(completion)
+             code = extract_code(text)
+             tname = _index(target_function_name, i, default="fibonacci")
+             seed = _index(row_seed, i, default=0)
+             try:
+                 env_reward_value = client.score_submission(tname, code, seed=seed)
+             except Exception as e:  # noqa: BLE001
+                 log.warning("env scoring failed for %s: %s", tname, e)
+                 env_reward_value = -50.0
+             rewards.append(env_reward_value * scale)
+         return rewards
+
+     return env_reward
+
+
+ _FUNC_RE = re.compile(r"^def\s+(\w+)\s*\(", re.MULTILINE)
+
+
+ def format_reward(completions, target_function_name=None, **kwargs):  # noqa: ANN001
+     """Cheap shaping reward: +0.2 if the completion contains a fenced python
+     block AND defines a function with the right name. Encourages the model to
+     converge on the expected output format quickly so the env reward becomes
+     informative early in training."""
+     rewards: List[float] = []
+     for i, completion in enumerate(completions):
+         text = _extract_text(completion)
+         score = 0.0
+         if "```python" in text or "```\n" in text:
+             score += 0.1
+         code = extract_code(text)
+         m = _FUNC_RE.search(code)
+         tname = _index(target_function_name, i, default=None)
+         if m and (tname is None or m.group(1) == tname):
+             score += 0.1
+         rewards.append(score)
+     return rewards
+
+
+ def _extract_text(completion):  # noqa: ANN001
+     """GRPO can pass either a string or an OpenAI-style chat list of dicts.
+     Normalise to a single string."""
+     if isinstance(completion, str):
+         return completion
+     if isinstance(completion, list):
+         # [{role: ..., content: ...}, ...]
+         parts = []
+         for msg in completion:
+             if isinstance(msg, dict) and "content" in msg:
+                 parts.append(str(msg["content"]))
+             else:
+                 parts.append(str(msg))
+         return "\n".join(parts)
+     return str(completion)
+
+
+ def _index(value, i: int, default):
+     if value is None:
+         return default
+     if isinstance(value, list):
+         return value[i] if i < len(value) else default
+     return value
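+
+ # Shape sketch (invented values): for a batch of 4 completions GRPO calls
+ #   make_env_reward(client)(completions=[c0, c1, c2, c3],
+ #                           target_function_name=["gcd"] * 4, row_seed=[7] * 4)
+ # and gets back one float per completion, e.g. [1.42, -0.5, 0.87, 1.05].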
requirements.txt CHANGED
@@ -1,3 +1,20 @@
- fastapi
- uvicorn
- pydantic
+ # Core ML stack. torch is provided by the base image.
+ transformers==4.46.3
+ trl==0.13.0
+ peft==0.13.2
+ accelerate==1.1.1
+ bitsandbytes==0.44.1
+ datasets==3.1.0
+
+ # Tokenizer + utility deps
+ sentencepiece==0.2.0
+ tiktoken==0.8.0
+ einops==0.8.0
+ safetensors==0.4.5
+
+ # HTTP + Hub
+ requests==2.32.3
+ huggingface_hub==0.26.2
+
+ # Misc
+ numpy==1.26.4
train.py CHANGED
@@ -1,157 +1,206 @@
- import torch
- import requests
- from transformers import AutoTokenizer
- from unsloth import FastLanguageModel
- from trl import GPPOTrainer, PPOConfig
- import json
- import re
-
- # == 1. Constants ==
- MAX_STEPS_PER_EPISODE = 15
- ENV_URL = "https://anugrah55-opensleuth-env-gemini-cli.hf.space"
- MODEL_NAME = "unsloth/qwen2-0.5b-instruct-sft-bnb-4bit"
-
- # == 2. Prompt Engineering ==
- def build_prompt(probe_history):
-     """
-     Creates the prompt for the LLM based on the probe history.
-     """
-     prompt = "You are a reverse-engineering AI. Your goal is to understand a hidden black-box function by probing it and then writing a Python replica.\n\n"
-     prompt += "== Probe History ==\n"
-     if not probe_history:
-         prompt += "No probes yet. Your first action should be a probe.\n"
-     else:
-         for i, (inp, out) in enumerate(probe_history):
-             prompt += f"{i+1}. IN: {inp} -> OUT: {out}\n"
-
-     prompt += "\n== Your Action ==\n"
-     prompt += "You can either PROBE or SUBMIT.\n"
-     prompt += "To probe, respond with: PROBE(input)\n"
-     prompt += "To submit your code, respond with: SUBMIT\n```python\n[your code here]\n```\n"
-     prompt += "Your decision: "
-     return prompt
-
- # == 3. Action Parsing ==
- def parse_action_from_response(response_text):
-     """
-     Parses the model's text response to determine the action.
-     """
-     probe_match = re.search(r"PROBE\((.*)\)", response_text)
-     if probe_match:
-         inp = probe_match.group(1).strip()
-         return {"action_type": "probe", "input": inp}
-
-     submit_match = re.search(r"SUBMIT\s*```python\n(.*)```", response_text, re.DOTALL)
-     if submit_match:
-         code = submit_match.group(1).strip()
-         return {"action_type": "submit", "code": code}
-
-     # Default to a probe if parsing fails
-     return {"action_type": "probe", "input": "1"}
-
-
- # == 4. Main Training Script ==
- def main():
-     # --- Initialize Model ---
-     model, tokenizer = FastLanguageModel.from_pretrained(
-         model_name = MODEL_NAME,
-         max_seq_length = 2048,
-         dtype = None,
-         load_in_4bit = True,
-     )
-     # LoRA configuration
-     model = FastLanguageModel.get_peft_model(
-         model,
-         r = 16,
-         target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
-         lora_alpha = 16,
-         lora_dropout = 0,
-         bias = "none",
-         use_gradient_checkpointing = True,
-         random_state = 3407,
-         use_rslora = False,
-         loftq_config = None,
-     )
-
-     # --- Initialize GPPO Trainer ---
-     # Note: GPPO is a new trainer in TRL and might require specific config.
-     # This is a placeholder configuration.
-     ppo_config = PPOConfig(
-         batch_size=4,
-         mini_batch_size=1,
-         learning_rate=1.41e-5,
-         adap_kl_ctrl=False,
-         log_with="tensorboard",
-         project_kwargs={"logging_dir": "./logs"}
-     )
-
-     # We need a dataset for the trainer, even if it's just a dummy one for initialization
-     # In a real RL loop, we provide the experiences directly to the `step` method.
-     dummy_dataset = [{"query": "dummy"}]
-     gppo_trainer = GPPOTrainer(
-         config=ppo_config,
-         model=model,
-         tokenizer=tokenizer,
-         dataset=dummy_dataset,
-     )
-
-     # --- Training Loop ---
-     for episode in range(10):  # Run for 10 episodes for demonstration
-         print(f"--- Episode {episode+1} ---")
-
-         # Reset environment
-         try:
-             resp = requests.post(f"{ENV_URL}/reset", json={"target_name": "fibonacci"})
-             obs = resp.json()
-         except requests.exceptions.ConnectionError as e:
-             print(f"ERROR: Could not connect to environment at {ENV_URL}. Is it running?")
-             print("Please run 'uvicorn server:app --host 0.0.0.0 --port 8000' in the 'opensleuth_env' directory.")
-             return
-
-         queries, responses, rewards = [], [], []
-
-         for step in range(MAX_STEPS_PER_EPISODE):
-             # Build prompt and generate action
-             prompt = build_prompt(obs.get("probe_history", []))
-             query_tensor = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
-
-             # Generate a response from the model
-             generation_kwargs = {"min_new_tokens": -1, "top_k": 0.0, "top_p": 1.0, "do_sample": True, "pad_token_id": tokenizer.eos_token_id, "max_new_tokens": 150}
-             response_tensor = gppo_trainer.generate(query_tensor, **generation_kwargs)
-             response_text = tokenizer.decode(response_tensor[0])
-
-             # Parse action and execute in environment
-             action = parse_action_from_response(response_text)
-             step_resp = requests.post(f"{ENV_URL}/step", json=action)
-             step_data = step_resp.json()
-
-             reward = torch.tensor(step_data["reward"], dtype=torch.float32)
-             obs = step_data["observation"]
-             done = step_data["done"]
-
-             # Store experience
-             queries.append(query_tensor.squeeze())
-             responses.append(response_tensor.squeeze())
-             rewards.append(reward)
-
-             print(f"Step {step+1}: Action: {action['action_type']}, Reward: {reward.item():.2f}")
-
-             if done:
-                 break
-
-         # --- Perform PPO Step ---
-         # This is a simplified view. The actual step requires careful handling of tensors.
-         # The `queries`, `responses`, `rewards` lists need to be formatted correctly.
-         try:
-             stats = gppo_trainer.step(queries, responses, rewards)
-             gppo_trainer.log_stats(stats, {}, rewards)
-             print(f"  PPO Step done. Mean reward: {stats['ppo/returns/mean']:.2f}")
-         except Exception as e:
-             print(f"ERROR during trainer.step: {e}")
-             print("  Skipping PPO step for this episode. This might happen if all trajectories are truncated.")
-
-
- if __name__ == "__main__":
-     # Ensure the server is running before starting training.
-     # We will run the server in the background from the CLI.
-     main()
+ """OpenSleuth GRPO trainer.
+
+ Trains a small Qwen2.5 model with TRL's GRPOTrainer to do in-context program
+ synthesis — given the public signature of a hidden function plus a handful of
+ (input, output) probe examples, emit a Python function that reproduces it.
+
+ Reward comes from the live OpenSleuth env Space: the agent's code is executed
+ against the hidden reference under domain-aware fuzzing, and the verifier
+ returns an `execution_reward - complexity_penalty` score that we hand back to
+ GRPO as the per-completion reward (plus a tiny formatting shaping reward).
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import logging
+ import os
+ import sys
+ import time
+
+ import torch
+ from peft import LoraConfig
+ from transformers import AutoTokenizer, BitsAndBytesConfig
+ from trl import GRPOConfig, GRPOTrainer
+
+ from opensleuth_train import (
+     EnvClient,
+     SYSTEM_PROMPT,
+     build_synthesis_dataset,
+ )
+ from opensleuth_train.reward import format_reward, make_env_reward
+
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s %(levelname)s %(name)s: %(message)s",
+     stream=sys.stdout,
+ )
+ log = logging.getLogger("opensleuth.train")
+
+
+ def parse_args() -> argparse.Namespace:
+     p = argparse.ArgumentParser()
+     p.add_argument("--env-url", default=os.environ.get("ENV_URL", "https://anugrah55-opensleuth-env-gemini-cli.hf.space"))
+     p.add_argument("--model-name", default=os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-0.5B-Instruct"))
+     p.add_argument("--output-dir", default=os.environ.get("OUTPUT_DIR", "/data/opensleuth-grpo"))
+     p.add_argument("--push-to-hub", default=os.environ.get("PUSH_TO_HUB", "anugrah55/opensleuth-qwen2.5-0.5b-grpo"))
+     p.add_argument("--n-per-function", type=int, default=int(os.environ.get("N_PER_FUNCTION", "16")))
+     p.add_argument("--n-probes", type=int, default=int(os.environ.get("N_PROBES", "6")))
+     p.add_argument("--num-generations", type=int, default=int(os.environ.get("NUM_GENERATIONS", "4")))
+     p.add_argument("--max-completion-length", type=int, default=int(os.environ.get("MAX_COMPLETION_LENGTH", "320")))
+     p.add_argument("--max-prompt-length", type=int, default=int(os.environ.get("MAX_PROMPT_LENGTH", "768")))
+     p.add_argument("--learning-rate", type=float, default=float(os.environ.get("LEARNING_RATE", "1e-5")))
+     p.add_argument("--num-train-epochs", type=float, default=float(os.environ.get("NUM_TRAIN_EPOCHS", "1")))
+     p.add_argument("--per-device-batch-size", type=int, default=int(os.environ.get("PER_DEVICE_BATCH_SIZE", "1")))
+     p.add_argument("--gradient-accumulation-steps", type=int, default=int(os.environ.get("GRAD_ACCUM", "8")))
+     p.add_argument("--no-4bit", action="store_true", default=os.environ.get("NO_4BIT", "0") == "1")
+     p.add_argument("--seed", type=int, default=int(os.environ.get("SEED", "42")))
+     return p.parse_args()
+
+
+ def wait_for_env(client: EnvClient, max_wait_s: float = 300.0) -> None:
+     log.info("waiting for env at %s ...", client.base_url)
+     start = time.time()
+     last_err = ""
+     while time.time() - start < max_wait_s:
+         try:
+             h = client.health()
+             log.info("env healthy: %s", h)
+             return
+         except Exception as e:  # noqa: BLE001
+             last_err = str(e)
+             time.sleep(5)
+     raise RuntimeError(f"env never became healthy after {max_wait_s}s. Last error: {last_err}")
+
+
+ def main() -> int:
+     args = parse_args()
+     log.info("args: %s", vars(args))
+
+     client = EnvClient(base_url=args.env_url, timeout=60.0, retries=4)
+     wait_for_env(client)
+     fns = client.list_functions()
+     log.info("env exposes %d functions: %s", len(fns), [f["name"] for f in fns])
+
+     log.info("building synthesis dataset (n_per_function=%d, n_probes=%d)", args.n_per_function, args.n_probes)
+     dataset = build_synthesis_dataset(
+         client, n_per_function=args.n_per_function, n_probes=args.n_probes, seed=args.seed
+     )
+     log.info("dataset size: %d rows", len(dataset))
+
+     # GRPO with chat-templated prompts: each row needs a "prompt" field, which
+     # we re-format as a chat message list so the trainer applies the chat
+     # template under the hood.
+     def to_chat(row):
+         return {
+             "prompt": [
+                 {"role": "system", "content": SYSTEM_PROMPT},
+                 {"role": "user", "content": row["prompt"]},
+             ],
+             "target_function_name": row["target_function_name"],
+             "row_seed": row["row_seed"],
+         }
+
+     dataset = dataset.map(to_chat, remove_columns=["prompt"])
+
+     # ---- Model + LoRA ----
+     log.info("loading model %s (4bit=%s)", args.model_name, not args.no_4bit)
+     bnb_config = None
+     if not args.no_4bit:
+         bnb_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_compute_dtype=torch.bfloat16,
+             bnb_4bit_use_double_quant=True,
+             bnb_4bit_quant_type="nf4",
+         )
+
+     tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     peft_config = LoraConfig(
+         r=16,
+         lora_alpha=32,
+         lora_dropout=0.05,
+         target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
+         task_type="CAUSAL_LM",
+         bias="none",
+     )
+
+     grpo_config = GRPOConfig(
+         output_dir=args.output_dir,
+         per_device_train_batch_size=args.per_device_batch_size,
+         gradient_accumulation_steps=args.gradient_accumulation_steps,
+         learning_rate=args.learning_rate,
+         num_train_epochs=args.num_train_epochs,
+         max_prompt_length=args.max_prompt_length,
+         max_completion_length=args.max_completion_length,
+         num_generations=args.num_generations,
+         beta=0.04,
+         bf16=torch.cuda.is_bf16_supported() if torch.cuda.is_available() else False,
+         fp16=False,
+         logging_steps=1,
+         save_steps=50,
+         save_total_limit=2,
+         report_to=[],
+         seed=args.seed,
+         push_to_hub=bool(args.push_to_hub) and bool(os.environ.get("HF_TOKEN")),
+         hub_model_id=args.push_to_hub or None,
+         hub_strategy="end",
+         gradient_checkpointing=True,
+     )
+
+     env_reward_fn = make_env_reward(client)
+     env_reward_fn.__name__ = "env_verifier_reward"
+     format_reward.__name__ = "format_reward"
+
+     log.info("instantiating GRPOTrainer")
+     # Newer TRL passes the model name and instantiates internally; this works
+     # across recent TRL versions because GRPOTrainer accepts a model id string.
+     trainer_kwargs = dict(
+         model=args.model_name,
+         reward_funcs=[env_reward_fn, format_reward],
+         args=grpo_config,
+         train_dataset=dataset,
+         peft_config=peft_config,
+     )
+     if bnb_config is not None:
+         # Some TRL versions accept model_init_kwargs to pass through to from_pretrained.
+         trainer_kwargs.setdefault("model_init_kwargs", {})
+         trainer_kwargs["model_init_kwargs"].update(
+             {"quantization_config": bnb_config, "torch_dtype": torch.bfloat16}
+         )
+
+     try:
+         trainer = GRPOTrainer(**trainer_kwargs)
+     except TypeError as e:
+         # Older TRL (<0.16) doesn't accept model_init_kwargs at GRPOTrainer level;
+         # fall back to loading the model first.
+         log.warning("GRPOTrainer rejected kwargs (%s); falling back to manual model load", e)
+         from transformers import AutoModelForCausalLM
+         model_kwargs = {"trust_remote_code": True, "torch_dtype": torch.bfloat16}
+         if bnb_config is not None:
+             model_kwargs["quantization_config"] = bnb_config
+         model = AutoModelForCausalLM.from_pretrained(args.model_name, **model_kwargs)
+         trainer = GRPOTrainer(
+             model=model,
+             reward_funcs=[env_reward_fn, format_reward],
+             args=grpo_config,
+             train_dataset=dataset,
+             peft_config=peft_config,
+             processing_class=tokenizer,
+         )
+
+     log.info("starting GRPO training")
+     trainer.train()
+     log.info("training complete; saving to %s", args.output_dir)
+     trainer.save_model(args.output_dir)
+     if grpo_config.push_to_hub:
+         log.info("pushing to hub: %s", args.push_to_hub)
+         trainer.push_to_hub()
+     return 0
+
+
+ if __name__ == "__main__":
+     sys.exit(main())
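+
+ # Smoke-test sketch (assumes the env is reachable and a GPU is present):
+ #   N_PER_FUNCTION=2 N_PROBES=3 NUM_GENERATIONS=2 python train.py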