Overhaul trainer: TRL GRPO with env-backed reward, Qwen2.5-0.5B 4bit+LoRA, slim PyTorch CUDA base, heartbeat HTTP for HF Spaces health probe
Files changed:

- Dockerfile +29 -9
- README.md +56 -5
- __pycache__/train.cpython-313.pyc +0 -0 (binary)
- entrypoint.sh +42 -0
- opensleuth_train/__init__.py +14 -0
- opensleuth_train/client.py +63 -0
- opensleuth_train/dataset.py +112 -0
- opensleuth_train/prompt.py +60 -0
- opensleuth_train/reward.py +91 -0
- requirements.txt +20 -3
- train.py +196 -147
Dockerfile (CHANGED, +29 -9)

Replaces the old image, which copied the whole repo into the container and launched `hf_train_runner.py` directly, with:

```dockerfile
# Slim, well-tested CUDA + PyTorch base. Avoids HF transformers-pytorch-gpu's
# bloat and unsloth's CUDA-version sensitivity.
FROM pytorch/pytorch:2.5.1-cuda12.1-cudnn9-runtime

ENV PYTHONUNBUFFERED=1 \
    PIP_NO_CACHE_DIR=1 \
    PIP_DISABLE_PIP_VERSION_CHECK=1 \
    HF_HOME=/data/.cache/huggingface \
    TRANSFORMERS_CACHE=/data/.cache/huggingface/hub \
    TOKENIZERS_PARALLELISM=false

WORKDIR /app

# System deps for bitsandbytes / build
RUN apt-get update && apt-get install -y --no-install-recommends \
    git curl build-essential \
    && rm -rf /var/lib/apt/lists/*

COPY requirements.txt /app/
RUN pip install --no-cache-dir -r requirements.txt

# Project code
COPY opensleuth_train /app/opensleuth_train
COPY train.py /app/
COPY entrypoint.sh /app/
RUN chmod +x /app/entrypoint.sh \
    && mkdir -p /data/opensleuth-grpo /data/.cache/huggingface

# HF Spaces health probe expects the container to expose a port; keep it open
# so the orchestrator considers us alive while training runs.
EXPOSE 7860

CMD ["/app/entrypoint.sh"]
```
README.md (CHANGED, +56 -5)

Fills in the previously empty Space front matter (title, emoji, colors) and adds the trainer README:

````markdown
---
title: OpenSleuth Trainer
emoji: 🛰️
colorFrom: red
colorTo: yellow
sdk: docker
app_port: 7860
pinned: false
suggested_hardware: t4-small
---

# OpenSleuth — Trainer

GPU Space that fine-tunes a small Qwen2.5 model with TRL **GRPO** to do
in-context program synthesis against the live OpenSleuth env.

## Pipeline

1. Wait for the env Space to report healthy.
2. Build a dataset of synthesis prompts: each row pairs one black-box
   function with N pre-sampled `(input, output)` probes drawn from the env.
3. Load `Qwen/Qwen2.5-0.5B-Instruct` in 4-bit + LoRA via `bitsandbytes` and
   `peft`.
4. Train with `trl.GRPOTrainer`, generating `num_generations=4` candidate
   completions per prompt and rewarding each against the env's verifier.
5. Persist the LoRA adapter to `/data/opensleuth-grpo` and (if `HF_TOKEN` is
   set as a Space secret) push to `anugrah55/opensleuth-qwen2.5-0.5b-grpo`.

## Reward

* `env_verifier_reward = env.score_submission(...) / 100` — the headline
  shaped reward, ranging roughly `[-0.5, +1.5]`.
* `format_reward` — small bonus for emitting a fenced ```python``` block
  whose `def` matches the target function name; helps the model converge on
  parseable output early.

## Hardware

`t4-small` is sufficient for 0.5B + LoRA + bnb-4bit. `a10g-small` will train
faster if available.

## Required Space secrets

* `HF_TOKEN` — write token if you want the LoRA adapter pushed to the Hub at
  the end of training.

## Tuning knobs

All knobs are exposed as env vars (defaults shown):

| Env var | Default | Meaning |
|---------|---------|---------|
| `ENV_URL` | env Space URL | OpenSleuth env to target |
| `MODEL_NAME` | `Qwen/Qwen2.5-0.5B-Instruct` | Base policy |
| `N_PER_FUNCTION` | `16` | Prompts per black-box function |
| `N_PROBES` | `6` | Probes per prompt |
| `NUM_GENERATIONS` | `4` | GRPO group size |
| `LEARNING_RATE` | `1e-5` | |
| `NUM_TRAIN_EPOCHS` | `1` | |
| `PER_DEVICE_BATCH_SIZE` | `1` | |
| `GRAD_ACCUM` | `8` | |
````
__pycache__/train.cpython-313.pyc (ADDED)

Binary file (11.7 kB).
entrypoint.sh (ADDED, +42 -0)

```bash
#!/usr/bin/env bash
# OpenSleuth training Space entrypoint.
#
# Starts a tiny background HTTP server on $PORT (default 7860) so the HF
# Spaces health probe is satisfied, then runs the actual training script in
# the foreground. All training logs go to stdout and are visible in the
# Space's "Container logs" tab.
set -euo pipefail

PORT="${PORT:-7860}"

log() { echo "[entrypoint $(date -u +%H:%M:%S)] $*"; }

# 1. Background heartbeat HTTP server. Just returns 200 OK on every request.
log "starting heartbeat server on :${PORT}"
python -c "
import http.server, socketserver, os, threading, time
class H(http.server.BaseHTTPRequestHandler):
    def do_GET(self):
        self.send_response(200)
        self.send_header('Content-Type','text/plain')
        self.end_headers()
        self.wfile.write(b'opensleuth-trainer alive\n')
    def log_message(self, *a, **kw): pass
port = int(os.environ.get('PORT','7860'))
srv = socketserver.TCPServer(('0.0.0.0', port), H)
threading.Thread(target=srv.serve_forever, daemon=True).start()
print(f'[heartbeat] listening on :{port}', flush=True)
while True: time.sleep(3600)
" &
HB_PID=$!

# Give the heartbeat a moment to bind before the orchestrator probes it.
sleep 2

# 2. Run training in the foreground. Crash here = container exits, which is
# what we want: HF will mark the Space failed and surface the error.
log "starting training (PID $$)"
log "GPU info:"
python -c "import torch; print('cuda available:', torch.cuda.is_available()); print('device:', torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'cpu')"

exec python /app/train.py
```
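
A quick sanity check that the heartbeat answers the health probe; a minimal sketch, assuming the container is running locally with the default port:

```python
# Any GET against the heartbeat should return 200 with a short body.
import requests

r = requests.get("http://127.0.0.1:7860/", timeout=5)
print(r.status_code, r.text)  # expected: 200 opensleuth-trainer alive
```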
opensleuth_train/__init__.py (ADDED, +14 -0)

```python
"""OpenSleuth training-side helpers (env client, dataset, reward fn)."""

from .client import EnvClient
from .dataset import build_synthesis_dataset, FUNCTIONS_FOR_TRAINING
from .prompt import SYSTEM_PROMPT, build_prompt, extract_code

__all__ = [
    "EnvClient",
    "build_synthesis_dataset",
    "FUNCTIONS_FOR_TRAINING",
    "SYSTEM_PROMPT",
    "build_prompt",
    "extract_code",
]
```
opensleuth_train/client.py (ADDED, +63 -0)

```python
"""Thin HTTP client for the OpenSleuth env Space."""

from __future__ import annotations

import logging
import os
import time
from typing import Any, Dict

import requests

log = logging.getLogger("opensleuth.client")


class EnvClient:
    def __init__(self, base_url: str | None = None, timeout: float = 30.0, retries: int = 3):
        self.base_url = (base_url or os.environ.get("ENV_URL", "http://127.0.0.1:7860")).rstrip("/")
        self.timeout = timeout
        self.retries = retries

    def _post(self, path: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        last_exc: Exception | None = None
        for attempt in range(self.retries):
            try:
                r = requests.post(f"{self.base_url}{path}", json=payload, timeout=self.timeout)
                r.raise_for_status()
                return r.json()
            except (requests.RequestException, ValueError) as e:  # noqa: PERF203
                last_exc = e
                wait = 0.5 * (2 ** attempt)
                log.warning("env POST %s failed (%s); retrying in %.1fs", path, e, wait)
                time.sleep(wait)
        raise RuntimeError(f"env POST {path} failed after {self.retries} retries: {last_exc}")

    def health(self) -> Dict[str, Any]:
        r = requests.get(f"{self.base_url}/health", timeout=self.timeout)
        r.raise_for_status()
        return r.json()

    def list_functions(self) -> list[Dict[str, str]]:
        r = requests.get(f"{self.base_url}/functions", timeout=self.timeout)
        r.raise_for_status()
        return r.json()["functions"]

    def reset(self, target_name: str, seed: int = 0, max_steps: int = 25) -> Dict[str, Any]:
        return self._post("/reset", {"target_name": target_name, "seed": seed, "max_steps": max_steps})

    def step(self, episode_id: str, action: Dict[str, Any]) -> Dict[str, Any]:
        return self._post("/step", {"episode_id": episode_id, "action": action})

    # --- High-level helpers used by the reward function --------------------

    def submit(self, episode_id: str, code: str) -> Dict[str, Any]:
        return self.step(episode_id, {"action_type": "submit", "code": code})

    def probe(self, episode_id: str, input_repr: str) -> Dict[str, Any]:
        return self.step(episode_id, {"action_type": "probe", "input_repr": input_repr})

    def score_submission(self, target_name: str, code: str, seed: int = 0) -> float:
        """One-shot: open an episode, submit the code, return total reward."""
        ep = self.reset(target_name=target_name, seed=seed, max_steps=2)
        resp = self.submit(ep["episode_id"], code)
        return float(resp["reward"])
```
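
A minimal usage sketch of the client against a running env (the URL is a placeholder; `fibonacci` is one of the targets `dataset.py` trains on):

```python
from opensleuth_train import EnvClient

client = EnvClient("https://example-env.hf.space")  # placeholder URL
ep = client.reset(target_name="fibonacci", seed=0, max_steps=5)
print(client.probe(ep["episode_id"], "10"))  # one (input, output) observation

code = (
    "def fibonacci(n):\n"
    "    a, b = 0, 1\n"
    "    for _ in range(n):\n"
    "        a, b = b, a + b\n"
    "    return a\n"
)
print("env reward:", client.score_submission("fibonacci", code))
```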
opensleuth_train/dataset.py (ADDED, +112 -0)

```python
"""Build the training dataset of (function_name, signature, probes) → prompt.

We pre-sample probes server-side with deterministic seeds so the LLM trains
on a consistent set of in-context examples per task. The actual *reward* is
computed by re-submitting the model's code against the env with a fresh fuzz
seed, so the model can't memorise probe outputs.
"""

from __future__ import annotations

import logging
import random
from typing import List

from datasets import Dataset

from .client import EnvClient
from .prompt import build_prompt

log = logging.getLogger("opensleuth.dataset")

FUNCTIONS_FOR_TRAINING: List[str] = [
    "fibonacci",
    "reverse_string",
    "is_palindrome",
    "digit_sum",
    "count_vowels",
    "gcd",
    "sort_unique",
    "caesar_cipher",
    "is_prime",
]


def _sample_probes(client: EnvClient, target_name: str, seed: int, n_probes: int) -> tuple[str, list[tuple[str, str, bool]]]:
    """Open an episode and feed it `n_probes` random valid inputs. Inputs are
    synthesised locally (loosely mirroring the env's fuzz generators) rather
    than fetched from the env, to avoid coupling to a specific spec API."""
    rng = random.Random(seed)
    ep = client.reset(target_name=target_name, seed=seed, max_steps=n_probes + 5)
    sig = ep["target_function_signature"]
    eid = ep["episode_id"]

    inputs = _make_probe_inputs(target_name, rng, n_probes)
    history: list[tuple[str, str, bool]] = []
    for inp_repr in inputs:
        resp = client.probe(eid, inp_repr)
        last = resp["observation"]["probe_history"][-1]
        history.append((last["input_repr"], last["output_repr"], bool(last["is_error"])))
    return sig, history


def _make_probe_inputs(target_name: str, rng: random.Random, n: int) -> list[str]:
    """Generate `n` Python-literal repr strings appropriate for this function.

    Kept in lock-step (loosely) with the env's fuzz generators so probes
    almost always land on the function's valid domain, with a few intentional
    out-of-domain inputs to expose error-handling.
    """
    # String/tuple/list pools hold already-repr'd literals and return
    # directly; int pools fall through to the final repr() call.
    if target_name == "fibonacci":
        pool = [1, 2, 5, 10, 20, 40, 89, -1, 0, 100]
    elif target_name == "reverse_string":
        pool = ['""', "'a'", "'hello'", "'racecar'", "'abc123'", "''", "'ab'"]
        return [rng.choice(pool) for _ in range(n)]
    elif target_name == "is_palindrome":
        pool = ["'racecar'", "'hello'", "'A man a plan a canal Panama'", "''", "'ab'", "'aba'"]
        return [rng.choice(pool) for _ in range(n)]
    elif target_name == "digit_sum":
        pool = [0, 1, 9, 10, 99, 100, 12345, -3]
    elif target_name == "count_vowels":
        pool = ["'hello'", "''", "'rhythm'", "'AEIOU'", "'xyz'", "'queueing'"]
        return [rng.choice(pool) for _ in range(n)]
    elif target_name == "gcd":
        pool = ["(12, 8)", "(7, 13)", "(0, 5)", "[15, 25]", "(100, 75)", "[6, 9]"]
        return [rng.choice(pool) for _ in range(n)]
    elif target_name == "sort_unique":
        pool = ["[3, 1, 2, 1]", "[]", "[5, 5, 5]", "[-1, 0, -1, 2]", "[10]"]
        return [rng.choice(pool) for _ in range(n)]
    elif target_name == "caesar_cipher":
        pool = ["'hello'", "'abc'", "'xyz'", "''", "'Hello!'", "'a b c'"]
        return [rng.choice(pool) for _ in range(n)]
    elif target_name == "is_prime":
        pool = [2, 3, 4, 7, 9, 11, 25, 29, 0, 1, -3]
    else:
        return ["1"] * n
    return [repr(rng.choice(pool)) for _ in range(n)]


def build_synthesis_dataset(
    client: EnvClient,
    *,
    n_per_function: int = 24,
    n_probes: int = 6,
    seed: int = 0,
) -> Dataset:
    """Build a HuggingFace Dataset of {prompt, target_function_name} rows."""
    rows = []
    rng = random.Random(seed)
    for fn_name in FUNCTIONS_FOR_TRAINING:
        for _ in range(n_per_function):
            row_seed = rng.randrange(0, 2**31)
            sig, probes = _sample_probes(client, fn_name, row_seed, n_probes)
            prompt = build_prompt(fn_name, sig, probes)
            rows.append(
                {
                    "prompt": prompt,
                    "target_function_name": fn_name,
                    "row_seed": row_seed,
                }
            )
    rng.shuffle(rows)
    return Dataset.from_list(rows)
```
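
For reference, a small build against a local env (placeholder URL); with 9 functions and `n_per_function=2` this yields 18 rows:

```python
from opensleuth_train import EnvClient, build_synthesis_dataset

ds = build_synthesis_dataset(
    EnvClient("http://127.0.0.1:7860"),  # placeholder URL
    n_per_function=2, n_probes=3, seed=0,
)
print(len(ds))                          # 18 = 9 functions x 2
row = ds[0]
print(row["target_function_name"], row["row_seed"])
print(row["prompt"][:200])              # start of the rendered prompt
```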
opensleuth_train/prompt.py (ADDED, +60 -0)

````python
"""Prompt construction + code extraction for the OpenSleuth agent."""

from __future__ import annotations

import re
from typing import Iterable

SYSTEM_PROMPT = (
    "You are an algorithmic detective. You are given the public signature of a hidden "
    "Python function plus several (input, output) examples observed by probing it. "
    "Your job is to write a Python function that *exactly* reproduces the hidden "
    "function's behavior on all valid inputs. Match its return values AND its "
    "exception types on invalid inputs. Keep your implementation as simple and clean "
    "as possible (it is penalised for being needlessly branchy). Return ONLY the "
    "function definition wrapped in a single ```python ... ``` code block."
)


def build_prompt(target_name: str, signature: str, probes: Iterable[tuple[str, str, bool]]) -> str:
    """Build the user-side prompt.

    `probes` is an iterable of `(input_repr, output_repr, is_error)` tuples,
    typically pre-sampled by the dataset builder.
    """
    lines = [
        f"## Hidden function: {target_name}",
        "",
        "### Public signature & docstring",
        signature.strip() or "(no signature provided)",
        "",
        "### Observed probes",
    ]
    probe_list = list(probes)
    if not probe_list:
        lines.append("(none)")
    else:
        for inp, out, is_err in probe_list:
            tag = "raises" if is_err else "returns"
            lines.append(f"- input={inp} -> {tag} {out}")
    lines += [
        "",
        "### Task",
        f"Write a Python function named `{target_name}` that reproduces the hidden "
        "function's behaviour. Return ONLY the function definition in a single "
        "```python ... ``` code block. Do not add explanations.",
    ]
    return "\n".join(lines)


_CODE_RE = re.compile(r"```(?:python)?\s*(.*?)```", re.DOTALL | re.IGNORECASE)


def extract_code(completion: str) -> str:
    """Pull the Python source from a model completion. If no fenced block is
    present, fall back to the whole completion (the verifier will then judge
    it on its own)."""
    m = _CODE_RE.search(completion)
    if m:
        return m.group(1).strip()
    return completion.strip()
````
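
What `build_prompt` renders, with illustrative probe tuples:

```python
from opensleuth_train import build_prompt

print(build_prompt(
    "gcd",
    "def gcd(a: int, b: int) -> int: ...",
    [("(12, 8)", "4", False), ("('a', 1)", "TypeError('...')", True)],
))
# -> the "## Hidden function: gcd" header, the signature, one bullet per
#    probe ("returns" vs "raises"), then the "### Task" instructions.
```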
opensleuth_train/reward.py (ADDED, +91 -0)

````python
"""Reward functions for GRPO: env-backed verification + cheap shaping signals.

GRPO takes a list of `reward_funcs`. Each must accept `completions` and any
columns from the dataset as kwargs, and return one scalar per completion.
"""

from __future__ import annotations

import logging
import re
from typing import List

from .client import EnvClient
from .prompt import extract_code

log = logging.getLogger("opensleuth.reward")


def make_env_reward(client: EnvClient, *, scale: float = 1.0 / 100.0) -> callable:
    """Verifier-backed reward. Calls the env's `submit` and returns the env's
    reward multiplied by `scale` (default 1/100, i.e. divide by 100, so a
    perfect submission is ~+1.5 and a bad one is around -0.5); this keeps
    GRPO advantages well behaved without needing reward normalisation.
    """

    def env_reward(completions, target_function_name=None, row_seed=None, **kwargs):  # noqa: ANN001
        rewards: List[float] = []
        # GRPO calls the reward fn once per batch; both target_function_name
        # and row_seed come in as lists of length len(completions).
        for i, completion in enumerate(completions):
            text = _extract_text(completion)
            code = extract_code(text)
            tname = _index(target_function_name, i, default="fibonacci")
            seed = _index(row_seed, i, default=0)
            try:
                env_reward_value = client.score_submission(tname, code, seed=seed)
            except Exception as e:  # noqa: BLE001
                log.warning("env scoring failed for %s: %s", tname, e)
                env_reward_value = -50.0
            rewards.append(env_reward_value * scale)
        return rewards

    return env_reward


_FUNC_RE = re.compile(r"^def\s+(\w+)\s*\(", re.MULTILINE)


def format_reward(completions, target_function_name=None, **kwargs):  # noqa: ANN001
    """Cheap shaping reward: up to +0.2 if the completion contains a fenced
    python block AND defines a function with the right name. Encourages the
    model to converge on the expected output format quickly so the env reward
    becomes informative early in training."""
    rewards: List[float] = []
    for i, completion in enumerate(completions):
        text = _extract_text(completion)
        score = 0.0
        if "```python" in text or "```\n" in text:
            score += 0.1
        code = extract_code(text)
        m = _FUNC_RE.search(code)
        tname = _index(target_function_name, i, default=None)
        if m and (tname is None or m.group(1) == tname):
            score += 0.1
        rewards.append(score)
    return rewards


def _extract_text(completion):  # noqa: ANN001
    """GRPO can pass either a string or an OpenAI-style chat list of dicts.
    Normalise to a single string."""
    if isinstance(completion, str):
        return completion
    if isinstance(completion, list):
        # [{role: ..., content: ...}, ...]
        parts = []
        for msg in completion:
            if isinstance(msg, dict) and "content" in msg:
                parts.append(str(msg["content"]))
            else:
                parts.append(str(msg))
        return "\n".join(parts)
    return str(completion)


def _index(value, i: int, default):
    if value is None:
        return default
    if isinstance(value, list):
        return value[i] if i < len(value) else default
    return value
````
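
A quick sanity check of the shaping reward on hand-written completions; the fenced block in the string mirrors what the model is told to emit:

````python
from opensleuth_train.reward import format_reward

good = "```python\ndef gcd(a, b):\n    while b:\n        a, b = b, a % b\n    return a\n```"
print(format_reward([good], target_function_name=["gcd"]))            # [0.2]
print(format_reward(["no code here"], target_function_name=["gcd"]))  # [0.0]
````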
requirements.txt (CHANGED, +20 -3)

```text
# Core ML stack. torch is provided by the base image.
transformers==4.46.3
trl==0.13.0
peft==0.13.2
accelerate==1.1.1
bitsandbytes==0.44.1
datasets==3.1.0

# Tokenizer + utility deps
sentencepiece==0.2.0
tiktoken==0.8.0
einops==0.8.0
safetensors==0.4.5

# HTTP + Hub
requests==2.32.3
huggingface_hub==0.26.2

# Misc
numpy==1.26.4
```
train.py (CHANGED, +196 -147)

Replaces the previous hand-rolled PPO loop (per-step `queries`/`responses`/`rewards` collection fed to `gppo_trainer.step(...)`) wholesale with:

```python
"""OpenSleuth GRPO trainer.

Trains a small Qwen2.5 model with TRL's GRPOTrainer to do in-context program
synthesis — given the public signature of a hidden function plus a handful of
(input, output) probe examples, emit a Python function that reproduces it.

Reward comes from the live OpenSleuth env Space: the agent's code is executed
against the hidden reference under domain-aware fuzzing, and the verifier
returns an `execution_reward - complexity_penalty` score that we hand back to
GRPO as the per-completion reward (plus a tiny formatting shaping reward).
"""

from __future__ import annotations

import argparse
import logging
import os
import sys
import time

import torch
from peft import LoraConfig
from transformers import AutoTokenizer, BitsAndBytesConfig
from trl import GRPOConfig, GRPOTrainer

from opensleuth_train import (
    EnvClient,
    SYSTEM_PROMPT,
    build_synthesis_dataset,
)
from opensleuth_train.reward import format_reward, make_env_reward


logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    stream=sys.stdout,
)
log = logging.getLogger("opensleuth.train")


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser()
    p.add_argument("--env-url", default=os.environ.get("ENV_URL", "https://anugrah55-opensleuth-env-gemini-cli.hf.space"))
    p.add_argument("--model-name", default=os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-0.5B-Instruct"))
    p.add_argument("--output-dir", default=os.environ.get("OUTPUT_DIR", "/data/opensleuth-grpo"))
    p.add_argument("--push-to-hub", default=os.environ.get("PUSH_TO_HUB", "anugrah55/opensleuth-qwen2.5-0.5b-grpo"))
    p.add_argument("--n-per-function", type=int, default=int(os.environ.get("N_PER_FUNCTION", "16")))
    p.add_argument("--n-probes", type=int, default=int(os.environ.get("N_PROBES", "6")))
    p.add_argument("--num-generations", type=int, default=int(os.environ.get("NUM_GENERATIONS", "4")))
    p.add_argument("--max-completion-length", type=int, default=int(os.environ.get("MAX_COMPLETION_LENGTH", "320")))
    p.add_argument("--max-prompt-length", type=int, default=int(os.environ.get("MAX_PROMPT_LENGTH", "768")))
    p.add_argument("--learning-rate", type=float, default=float(os.environ.get("LEARNING_RATE", "1e-5")))
    p.add_argument("--num-train-epochs", type=float, default=float(os.environ.get("NUM_TRAIN_EPOCHS", "1")))
    p.add_argument("--per-device-batch-size", type=int, default=int(os.environ.get("PER_DEVICE_BATCH_SIZE", "1")))
    p.add_argument("--gradient-accumulation-steps", type=int, default=int(os.environ.get("GRAD_ACCUM", "8")))
    p.add_argument("--no-4bit", action="store_true", default=os.environ.get("NO_4BIT", "0") == "1")
    p.add_argument("--seed", type=int, default=int(os.environ.get("SEED", "42")))
    return p.parse_args()


def wait_for_env(client: EnvClient, max_wait_s: float = 300.0) -> None:
    log.info("waiting for env at %s ...", client.base_url)
    start = time.time()
    last_err = ""
    while time.time() - start < max_wait_s:
        try:
            h = client.health()
            log.info("env healthy: %s", h)
            return
        except Exception as e:  # noqa: BLE001
            last_err = str(e)
        time.sleep(5)
    raise RuntimeError(f"env never became healthy after {max_wait_s}s. Last error: {last_err}")


def main() -> int:
    args = parse_args()
    log.info("args: %s", vars(args))

    client = EnvClient(base_url=args.env_url, timeout=60.0, retries=4)
    wait_for_env(client)
    fns = client.list_functions()
    log.info("env exposes %d functions: %s", len(fns), [f["name"] for f in fns])

    log.info("building synthesis dataset (n_per_function=%d, n_probes=%d)", args.n_per_function, args.n_probes)
    dataset = build_synthesis_dataset(
        client, n_per_function=args.n_per_function, n_probes=args.n_probes, seed=args.seed
    )
    log.info("dataset size: %d rows", len(dataset))

    # GRPO with chat-templated prompts: each row needs a "prompt" field, which
    # we re-format as a chat message list so the trainer applies the chat
    # template under the hood.
    def to_chat(row):
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": row["prompt"]},
            ],
            "target_function_name": row["target_function_name"],
            "row_seed": row["row_seed"],
        }

    dataset = dataset.map(to_chat, remove_columns=["prompt"])

    # ---- Model + LoRA ----
    log.info("loading model %s (4bit=%s)", args.model_name, not args.no_4bit)
    bnb_config = None
    if not args.no_4bit:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )

    tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        task_type="CAUSAL_LM",
        bias="none",
    )

    grpo_config = GRPOConfig(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.per_device_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        num_train_epochs=args.num_train_epochs,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,
        num_generations=args.num_generations,
        beta=0.04,
        bf16=torch.cuda.is_bf16_supported() if torch.cuda.is_available() else False,
        fp16=False,
        logging_steps=1,
        save_steps=50,
        save_total_limit=2,
        report_to=[],
        seed=args.seed,
        push_to_hub=bool(args.push_to_hub) and bool(os.environ.get("HF_TOKEN")),
        hub_model_id=args.push_to_hub or None,
        hub_strategy="end",
        gradient_checkpointing=True,
    )

    env_reward_fn = make_env_reward(client)
    env_reward_fn.__name__ = "env_verifier_reward"
    format_reward.__name__ = "format_reward"

    log.info("instantiating GRPOTrainer")
    # Newer TRL passes the model name and instantiates internally; this works
    # across recent TRL versions because GRPOTrainer accepts a model id string.
    trainer_kwargs = dict(
        model=args.model_name,
        reward_funcs=[env_reward_fn, format_reward],
        args=grpo_config,
        train_dataset=dataset,
        peft_config=peft_config,
    )
    if bnb_config is not None:
        # Some TRL versions accept model_init_kwargs to pass through to from_pretrained.
        trainer_kwargs.setdefault("model_init_kwargs", {})
        trainer_kwargs["model_init_kwargs"].update(
            {"quantization_config": bnb_config, "torch_dtype": torch.bfloat16}
        )

    try:
        trainer = GRPOTrainer(**trainer_kwargs)
    except TypeError as e:
        # Older TRL (<0.16) doesn't accept model_init_kwargs at GRPOTrainer level;
        # fall back to loading the model first.
        log.warning("GRPOTrainer rejected kwargs (%s); falling back to manual model load", e)
        from transformers import AutoModelForCausalLM
        model_kwargs = {"trust_remote_code": True, "torch_dtype": torch.bfloat16}
        if bnb_config is not None:
            model_kwargs["quantization_config"] = bnb_config
        model = AutoModelForCausalLM.from_pretrained(args.model_name, **model_kwargs)
        trainer = GRPOTrainer(
            model=model,
            reward_funcs=[env_reward_fn, format_reward],
            args=grpo_config,
            train_dataset=dataset,
            peft_config=peft_config,
            processing_class=tokenizer,
        )

    log.info("starting GRPO training")
    trainer.train()
    log.info("training complete; saving to %s", args.output_dir)
    trainer.save_model(args.output_dir)
    if grpo_config.push_to_hub:
        log.info("pushing to hub: %s", args.push_to_hub)
        trainer.push_to_hub()
    return 0


if __name__ == "__main__":
    sys.exit(main())
```
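
After a run, the saved adapter can be loaded for a quick inference smoke test; a minimal sketch, assuming the default `OUTPUT_DIR` and the pinned `peft`/`transformers` versions:

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the frozen base model, then attach the trained LoRA adapter.
base = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct", torch_dtype=torch.bfloat16
)
model = PeftModel.from_pretrained(base, "/data/opensleuth-grpo")
tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

messages = [{"role": "user", "content": "Write a Python function fibonacci(n)."}]
inputs = tok.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
out = model.generate(inputs, max_new_tokens=160)
print(tok.decode(out[0][inputs.shape[-1]:], skip_special_tokens=True))
```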