Initial commit: OpenSleuth Colab quickstart notebook + Gradio landing page

- README.md +42 -5
- app.py +74 -0
- requirements.txt +1 -0
- train_opensleuth_grpo.ipynb +821 -0
README.md CHANGED
@@ -1,12 +1,49 @@
 ---
-title:
+title: OpenSleuth Colab
-emoji:
+emoji: 🕵️
-colorFrom:
+colorFrom: indigo
 colorTo: green
 sdk: gradio
-sdk_version:
+sdk_version: 4.44.0
 app_file: app.py
 pinned: false
+license: apache-2.0
 ---
 
-
+# OpenSleuth — Colab quickstart Space
+
+This Space is a thin landing page for the [`train_opensleuth_grpo.ipynb`](./train_opensleuth_grpo.ipynb) notebook — the **minimum reproducible Colab** for training an OpenSleuth agent end-to-end against the live env Space.
+
+## What is OpenSleuth?
+
+An **Algorithmic Detective** RL environment. An LLM agent reverse-engineers an unknown black-box Python function by **probing** it with inputs and then **submitting** a Python replica. The environment scores submissions by domain-aware fuzz-testing against the hidden reference, with a complexity penalty so the agent can't just memorise its probes inside a giant `if/else`.
+
+## Try it
+
+Click the badge to open the notebook in Google Colab:
+
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/#fileId=https%3A//huggingface.co/spaces/anugrah55/opensleuth-colab/blob/main/train_opensleuth_grpo.ipynb)
+
+Or download `train_opensleuth_grpo.ipynb` from the **Files** tab and upload it to Colab manually. Set the runtime to **GPU → T4** and hit **Runtime → Run all** — end-to-end training completes in roughly 15–25 minutes on a free-tier T4 with the default Qwen2.5-0.5B-Instruct config.
+
+## What the notebook does
+
+1. Pip-installs the pinned trainer stack (`transformers==4.51.3`, `trl==0.16.1`, `peft==0.14.0`, `accelerate==1.4.0`, `bitsandbytes==0.45.5`, `datasets==3.3.2`).
+2. Hits the live env Space [`anugrah55/opensleuth-env-gemini-cli`](https://huggingface.co/spaces/anugrah55/opensleuth-env-gemini-cli) at `https://anugrah55-opensleuth-env-gemini-cli.hf.space` to discover all 15 tasks (9 builtins + 6 from the Hub task dataset).
+3. Builds a synthesis dataset where each row is `(signature + observed probes) → expected python implementation`.
+4. Loads `Qwen2.5-0.5B-Instruct` in 4-bit + LoRA so it fits on a T4.
+5. Trains with HF TRL's `GRPOTrainer` using a two-part reward:
+   - **env-verifier reward**: real fuzz-tested correctness against the hidden reference, with a complexity penalty.
+   - **format reward**: tiny shaping signal for emitting a fenced ```python``` code block with the right function name.
+6. Optionally pushes the trained LoRA adapter to your own Hub account.
+7. Runs a 3-episode smoke eval and prints the agent's emitted code.
+
+## Links
+
+- **Env Space (REST API the notebook calls):** https://huggingface.co/spaces/anugrah55/opensleuth-env-gemini-cli
+- **Training Space (full 3B retrain):** https://huggingface.co/spaces/anugrah55/opensleuth-training-gemini-cli
+- **Open-ended task catalog (Hub dataset):** https://huggingface.co/datasets/anugrah55/opensleuth-tasks
+
+## License
+
+Apache-2.0.
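The README's GRPO setup (`num_generations` = 4 per prompt) relies on group-relative advantages instead of a learned value baseline: each completion's total reward (scaled env-verifier score plus format bonus) is centred on its group's mean and scaled by the group's standard deviation. A minimal sketch of that arithmetic with made-up reward numbers; TRL's `GRPOTrainer` computes this internally, so this is illustrative only:

```python
from statistics import mean, stdev

def group_relative_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """GRPO-style advantage: centre each completion's reward on the group
    mean and divide by the group std (eps avoids division by zero when all
    completions in the group score identically)."""
    mu = mean(rewards)
    sd = stdev(rewards) if len(rewards) > 1 else 0.0
    return [(r - mu) / (sd + eps) for r in rewards]

# One prompt, four sampled completions. Each total is an illustrative
# scaled env-verifier score plus the small format-shaping bonus.
totals = [1.5 + 0.2, 0.3 + 0.2, -0.5 + 0.1, -0.5 + 0.0]
advantages = group_relative_advantages(totals)
```

Because the advantage is relative within the group, the four advantages always sum to zero: only completions that beat their siblings get pushed up.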
app.py ADDED
@@ -0,0 +1,74 @@
+"""Tiny Gradio landing page for the OpenSleuth Colab notebook Space.
+
+The actual training happens in the notebook (`train_opensleuth_grpo.ipynb` in
+this same repo, downloadable from the Files tab). This app just renders a
+clickable Open-In-Colab card so visitors can launch it in one click.
+"""
+
+from __future__ import annotations
+
+import gradio as gr
+
+NOTEBOOK_PATH = "train_opensleuth_grpo.ipynb"
+SPACE_ID = "anugrah55/opensleuth-colab"
+COLAB_URL = (
+    "https://colab.research.google.com/#fileId="
+    f"https%3A//huggingface.co/spaces/{SPACE_ID}/blob/main/{NOTEBOOK_PATH}"
+)
+
+LANDING_MD = f"""
+# OpenSleuth — Colab quickstart
+
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)]({COLAB_URL})
+
+OpenSleuth is an *Algorithmic Detective* RL environment. An LLM agent reverse-engineers an unknown black-box Python function by probing it and then submitting a Python replica. The env fuzz-tests the submission against the hidden reference (with a complexity penalty) and returns a scalar reward.
+
+This Space hosts the **minimum reproducible Colab notebook** for training an
+agent against the live env Space using **HF TRL's `GRPOTrainer`** + **bnb-4bit**
++ **LoRA** on a free-tier Colab T4. End-to-end runtime: ~15–25 minutes.
+
+### One-click training
+
+1. Click the **Open in Colab** badge above (or grab `{NOTEBOOK_PATH}` from the **Files** tab and upload it to Colab manually).
+2. In Colab: `Runtime → Change runtime type → GPU → T4`.
+3. `Runtime → Run all`.
+
+### Defaults
+
+| Knob | Value |
+|------|-------|
+| Model | `Qwen/Qwen2.5-0.5B-Instruct` |
+| Quant | bnb-4bit (nf4 + double-quant) |
+| LoRA | r=16, alpha=32, q/k/v/o |
+| Tasks | all 15 from `anugrah55/opensleuth-tasks` |
+| GRPO `num_generations` | 4 |
+| Epochs | 1 |
+
+### Links
+
+- **Env Space (REST API the notebook calls):** https://huggingface.co/spaces/anugrah55/opensleuth-env-gemini-cli
+- **Training Space (full 3B retrain):** https://huggingface.co/spaces/anugrah55/opensleuth-training-gemini-cli
+- **Open-ended task catalog:** https://huggingface.co/datasets/anugrah55/opensleuth-tasks
+"""
+
+
+def _open_colab() -> str:
+    return f"Opening Colab: {COLAB_URL}"
+
+
+with gr.Blocks(title="OpenSleuth — Colab quickstart") as demo:
+    gr.Markdown(LANDING_MD)
+    with gr.Row():
+        gr.Button(
+            value="Open in Google Colab",
+            link=COLAB_URL,
+            variant="primary",
+        )
+        gr.Button(
+            value="View notebook in Files tab",
+            link=f"https://huggingface.co/spaces/{SPACE_ID}/blob/main/{NOTEBOOK_PATH}",
+        )
+
+
+if __name__ == "__main__":
+    demo.launch()
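The `COLAB_URL` literal in app.py hand-encodes the `:` of the notebook's Hub blob URL as `%3A` for Colab's `#fileId=` fragment. The same link can be derived with the standard library, which is less error-prone if the Space ID or notebook path ever changes (a sketch, not part of the app):

```python
from urllib.parse import quote

SPACE_ID = "anugrah55/opensleuth-colab"
NOTEBOOK_PATH = "train_opensleuth_grpo.ipynb"

blob_url = f"https://huggingface.co/spaces/{SPACE_ID}/blob/main/{NOTEBOOK_PATH}"
# safe="/" leaves the path slashes alone but escapes ":" -> "%3A",
# reproducing the hand-written literal in app.py exactly.
colab_url = "https://colab.research.google.com/#fileId=" + quote(blob_url, safe="/")
```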
requirements.txt ADDED
@@ -0,0 +1 @@
+gradio==4.44.0
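The notebook that follows wraps the env Space in a small `EnvClient`; underneath, an episode is just JSON over two endpoints, `POST /reset` and `POST /step`. The payload shapes below mirror that client code (the episode id and probe input are hypothetical placeholders, and nothing is sent over the network here):

```python
ENV_URL = "https://anugrah55-opensleuth-env-gemini-cli.hf.space"

def reset_payload(target_name: str, seed: int = 0, max_steps: int = 25) -> dict:
    # Body for POST {ENV_URL}/reset: opens an episode against one hidden task.
    return {"target_name": target_name, "seed": seed, "max_steps": max_steps}

def step_payload(episode_id: str, action: dict) -> dict:
    # Body for POST {ENV_URL}/step: the action is either a probe or a submit.
    return {"episode_id": episode_id, "action": action}

# Hypothetical episode id; the real one is returned by /reset.
probe = step_payload("ep-123", {"action_type": "probe", "input_repr": "(5,)"})
submit = step_payload("ep-123", {"action_type": "submit",
                                 "code": "def fibonacci(n):\n    ..."})
```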
train_opensleuth_grpo.ipynb
ADDED
|
@@ -0,0 +1,821 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "7086c037",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# OpenSleuth — GRPO training on a free-tier Colab T4\n",
|
| 9 |
+
"\n",
|
| 10 |
+
"[](https://colab.research.google.com/github/anugrah55/opensleuth/blob/main/colab/train_opensleuth_grpo.ipynb)\n",
|
| 11 |
+
"\n",
|
| 12 |
+
"**OpenSleuth** is an *Algorithmic Detective* RL environment. An LLM agent reverse-engineers an unknown black-box Python function by **probing** it with inputs and then **submitting** a Python replica. The environment scores submissions by domain-aware fuzz-testing against the hidden reference, with a complexity penalty so the agent can't just memorise its probes inside a giant `if/else`.\n",
|
| 13 |
+
"\n",
|
| 14 |
+
"This notebook trains a small open-weights model with HF TRL's [`GRPOTrainer`](https://huggingface.co/docs/trl/en/grpo_trainer) against the **live** OpenSleuth environment Space. It is sized to complete end-to-end on a **free-tier Colab T4** (16 GB GPU) in roughly **15 – 25 minutes**.\n",
|
| 15 |
+
"\n",
|
| 16 |
+
"### Links\n",
|
| 17 |
+
"\n",
|
| 18 |
+
"- **Live env Space (this notebook calls it directly):** https://huggingface.co/spaces/anugrah55/opensleuth-env-gemini-cli — REST API at `https://anugrah55-opensleuth-env-gemini-cli.hf.space`\n",
|
| 19 |
+
"- **Open-ended task catalog (Hub dataset, 15 tasks):** https://huggingface.co/datasets/anugrah55/opensleuth-tasks\n",
|
| 20 |
+
"- **Repo / Spaces:** training Space `anugrah55/opensleuth-training-gemini-cli`, env Space `anugrah55/opensleuth-env-gemini-cli`\n",
|
| 21 |
+
"- **Blog (if/when published):** https://huggingface.co/blog/anugrah55/opensleuth\n",
|
| 22 |
+
"\n",
|
| 23 |
+
"### What this notebook does\n",
|
| 24 |
+
"\n",
|
| 25 |
+
"1. Installs pinned versions of `transformers`, `trl`, `peft`, `bitsandbytes`, `accelerate`, `datasets`.\n",
|
| 26 |
+
"2. Hits the env's `/tasks` endpoint to discover all 15 tasks (9 builtins + 6 Hub-driven, both open-ended).\n",
|
| 27 |
+
"3. Builds a synthesis dataset where each row is `(signature + observed probes) → expected python implementation`.\n",
|
| 28 |
+
"4. Loads **Qwen2.5-0.5B-Instruct** in 4-bit + LoRA so it fits comfortably on a T4.\n",
|
| 29 |
+
"5. Trains with GRPO using a **two-part reward**: env-verifier score (real fuzz-tested correctness, capped by complexity) plus a tiny formatting shaping reward.\n",
|
| 30 |
+
"6. Optionally pushes the trained adapter to the Hub.\n",
|
| 31 |
+
"7. Runs a 3-episode smoke eval against the live env and prints the agent's emitted code.\n",
|
| 32 |
+
"\n",
|
| 33 |
+
"> The full 3B retrain runs separately on a Hugging Face Space; this notebook is the **minimum reproducible Colab** required by the hackathon spec."
|
| 34 |
+
]
|
| 35 |
+
},
|
| 36 |
+
{
|
| 37 |
+
"cell_type": "code",
|
| 38 |
+
"execution_count": null,
|
| 39 |
+
"id": "9307eb3f",
|
| 40 |
+
"metadata": {},
|
| 41 |
+
"outputs": [],
|
| 42 |
+
"source": [
|
| 43 |
+
"# Pinned to /training/requirements.txt so the env-side reward stays in lockstep.\n",
|
| 44 |
+
"# trl 0.16.x is required for the modern GRPOTrainer / GRPOConfig API.\n",
|
| 45 |
+
"!pip install --quiet \\\n",
|
| 46 |
+
" \"transformers==4.51.3\" \\\n",
|
| 47 |
+
" \"trl==0.16.1\" \\\n",
|
| 48 |
+
" \"peft==0.14.0\" \\\n",
|
| 49 |
+
" \"accelerate==1.4.0\" \\\n",
|
| 50 |
+
" \"bitsandbytes==0.45.5\" \\\n",
|
| 51 |
+
" \"datasets==3.3.2\" \\\n",
|
| 52 |
+
" \"huggingface_hub>=0.30.2,<1.0\" \\\n",
|
| 53 |
+
" \"requests>=2.32.3\"\n",
|
| 54 |
+
"print(\"deps installed\")"
|
| 55 |
+
]
|
| 56 |
+
},
|
| 57 |
+
{
|
| 58 |
+
"cell_type": "code",
|
| 59 |
+
"execution_count": null,
|
| 60 |
+
"id": "bb6ecbad",
|
| 61 |
+
"metadata": {},
|
| 62 |
+
"outputs": [],
|
| 63 |
+
"source": [
|
| 64 |
+
"# OPTIONAL: log in so you can push the trained adapter to your own HF account.\n",
|
| 65 |
+
"# Skip this cell entirely if you only want to train + smoke-eval locally in Colab.\n",
|
| 66 |
+
"from huggingface_hub import notebook_login\n",
|
| 67 |
+
"notebook_login()"
|
| 68 |
+
]
|
| 69 |
+
},
|
| 70 |
+
{
|
| 71 |
+
"cell_type": "code",
|
| 72 |
+
"execution_count": null,
|
| 73 |
+
"id": "6c81d26f",
|
| 74 |
+
"metadata": {},
|
| 75 |
+
"outputs": [],
|
| 76 |
+
"source": [
|
| 77 |
+
"import logging\n",
|
| 78 |
+
"import os\n",
|
| 79 |
+
"import random\n",
|
| 80 |
+
"import re\n",
|
| 81 |
+
"import sys\n",
|
| 82 |
+
"import time\n",
|
| 83 |
+
"from typing import Any, Dict, Iterable, List, Optional, Sequence\n",
|
| 84 |
+
"\n",
|
| 85 |
+
"import requests\n",
|
| 86 |
+
"import torch\n",
|
| 87 |
+
"from datasets import Dataset\n",
|
| 88 |
+
"\n",
|
| 89 |
+
"logging.basicConfig(\n",
|
| 90 |
+
" level=logging.INFO,\n",
|
| 91 |
+
" format=\"%(asctime)s %(levelname)s %(name)s: %(message)s\",\n",
|
| 92 |
+
" stream=sys.stdout,\n",
|
| 93 |
+
")\n",
|
| 94 |
+
"log = logging.getLogger(\"opensleuth.colab\")\n",
|
| 95 |
+
"\n",
|
| 96 |
+
"# ---------------------------------------------------------------------------\n",
|
| 97 |
+
"# Constants you might tweak\n",
|
| 98 |
+
"# ---------------------------------------------------------------------------\n",
|
| 99 |
+
"\n",
|
| 100 |
+
"# Live OpenSleuth env Space.\n",
|
| 101 |
+
"ENV_URL = \"https://anugrah55-opensleuth-env-gemini-cli.hf.space\"\n",
|
| 102 |
+
"\n",
|
| 103 |
+
"# Default to the 0.5B Qwen so this completes in ~15-25 min on a free-tier T4.\n",
|
| 104 |
+
"# Bump to \"Qwen/Qwen2.5-1.5B-Instruct\" or \"Qwen/Qwen2.5-3B-Instruct\" if you have\n",
|
| 105 |
+
"# Colab Pro / a beefier GPU. The reward + dataset code is model-agnostic.\n",
|
| 106 |
+
"MODEL_NAME = \"Qwen/Qwen2.5-0.5B-Instruct\"\n",
|
| 107 |
+
"\n",
|
| 108 |
+
"# Where to dump checkpoints + the final adapter.\n",
|
| 109 |
+
"OUTPUT_DIR = \"./opensleuth-grpo-colab\"\n",
|
| 110 |
+
"\n",
|
| 111 |
+
"# Set to your own Hub repo to push (e.g. \"your-username/opensleuth-grpo-colab\").\n",
|
| 112 |
+
"# Leave as None to skip pushing. Pushing requires a write token from the login\n",
|
| 113 |
+
"# cell above.\n",
|
| 114 |
+
"PUSH_TO_HUB_REPO: Optional[str] = None # e.g. \"anugrah55/opensleuth-grpo-colab\"\n",
|
| 115 |
+
"\n",
|
| 116 |
+
"# ---------------------------------------------------------------------------\n",
|
| 117 |
+
"# Hyperparameters (sized for free-tier T4: 16GB GPU, ~12hr session)\n",
|
| 118 |
+
"# ---------------------------------------------------------------------------\n",
|
| 119 |
+
"\n",
|
| 120 |
+
"# Per-task rollouts. Small so the dataset-build phase (which has to call the env\n",
|
| 121 |
+
"# /probe endpoint many times) finishes in a few minutes.\n",
|
| 122 |
+
"N_PER_FUNCTION = 8\n",
|
| 123 |
+
"N_PROBES = 6\n",
|
| 124 |
+
"\n",
|
| 125 |
+
"# GRPO knobs. num_generations=4 + per_device_batch_size=4 means each optimisation\n",
|
| 126 |
+
"# step uses one prompt and 4 sampled completions for the relative advantage,\n",
|
| 127 |
+
"# which is the minimum sensible GRPO batch.\n",
|
| 128 |
+
"NUM_GENERATIONS = 4\n",
|
| 129 |
+
"PER_DEVICE_BATCH_SIZE = 4\n",
|
| 130 |
+
"GRADIENT_ACCUMULATION_STEPS = 2\n",
|
| 131 |
+
"NUM_TRAIN_EPOCHS = 1.0\n",
|
| 132 |
+
"LEARNING_RATE = 1e-5\n",
|
| 133 |
+
"MAX_PROMPT_LENGTH = 1024\n",
|
| 134 |
+
"MAX_COMPLETION_LENGTH = 384\n",
|
| 135 |
+
"\n",
|
| 136 |
+
"SEED = 42\n",
|
| 137 |
+
"random.seed(SEED)\n",
|
| 138 |
+
"\n",
|
| 139 |
+
"print(\"env_url =\", ENV_URL)\n",
|
| 140 |
+
"print(\"model_name =\", MODEL_NAME)\n",
|
| 141 |
+
"print(\"output_dir =\", OUTPUT_DIR)\n",
|
| 142 |
+
"print(\"push_to_hub =\", PUSH_TO_HUB_REPO)\n",
|
| 143 |
+
"print(\"cuda_available =\", torch.cuda.is_available())\n",
|
| 144 |
+
"if torch.cuda.is_available():\n",
|
| 145 |
+
" print(\"gpu =\", torch.cuda.get_device_name(0))"
|
| 146 |
+
]
|
| 147 |
+
},
|
| 148 |
+
{
|
| 149 |
+
"cell_type": "code",
|
| 150 |
+
"execution_count": null,
|
| 151 |
+
"id": "fdd9c63b",
|
| 152 |
+
"metadata": {},
|
| 153 |
+
"outputs": [],
|
| 154 |
+
"source": [
|
| 155 |
+
"# Self-contained copy of /training/opensleuth_train/client.py so this notebook\n",
|
| 156 |
+
"# does not depend on pip-installing the trainer package.\n",
|
| 157 |
+
"\n",
|
| 158 |
+
"class EnvClient:\n",
|
| 159 |
+
" \"\"\"Thin HTTP client for the live OpenSleuth env Space.\"\"\"\n",
|
| 160 |
+
"\n",
|
| 161 |
+
" def __init__(self, base_url: str = ENV_URL, timeout: float = 60.0, retries: int = 4):\n",
|
| 162 |
+
" self.base_url = base_url.rstrip(\"/\")\n",
|
| 163 |
+
" self.timeout = timeout\n",
|
| 164 |
+
" self.retries = retries\n",
|
| 165 |
+
"\n",
|
| 166 |
+
" def _post(self, path: str, payload: Dict[str, Any]) -> Dict[str, Any]:\n",
|
| 167 |
+
" last_exc: Optional[Exception] = None\n",
|
| 168 |
+
" for attempt in range(self.retries):\n",
|
| 169 |
+
" try:\n",
|
| 170 |
+
" r = requests.post(\n",
|
| 171 |
+
" f\"{self.base_url}{path}\", json=payload, timeout=self.timeout\n",
|
| 172 |
+
" )\n",
|
| 173 |
+
" r.raise_for_status()\n",
|
| 174 |
+
" return r.json()\n",
|
| 175 |
+
" except (requests.RequestException, ValueError) as e:\n",
|
| 176 |
+
" last_exc = e\n",
|
| 177 |
+
" wait = 0.5 * (2 ** attempt)\n",
|
| 178 |
+
" log.warning(\"env POST %s failed (%s); retrying in %.1fs\", path, e, wait)\n",
|
| 179 |
+
" time.sleep(wait)\n",
|
| 180 |
+
" raise RuntimeError(\n",
|
| 181 |
+
" f\"env POST {path} failed after {self.retries} retries: {last_exc}\"\n",
|
| 182 |
+
" )\n",
|
| 183 |
+
"\n",
|
| 184 |
+
" def _get(self, path: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:\n",
|
| 185 |
+
" last_exc: Optional[Exception] = None\n",
|
| 186 |
+
" for attempt in range(self.retries):\n",
|
| 187 |
+
" try:\n",
|
| 188 |
+
" r = requests.get(\n",
|
| 189 |
+
" f\"{self.base_url}{path}\", params=params, timeout=self.timeout\n",
|
| 190 |
+
" )\n",
|
| 191 |
+
" r.raise_for_status()\n",
|
| 192 |
+
" return r.json()\n",
|
| 193 |
+
" except (requests.RequestException, ValueError) as e:\n",
|
| 194 |
+
" last_exc = e\n",
|
| 195 |
+
" wait = 0.5 * (2 ** attempt)\n",
|
| 196 |
+
" log.warning(\"env GET %s failed (%s); retrying in %.1fs\", path, e, wait)\n",
|
| 197 |
+
" time.sleep(wait)\n",
|
| 198 |
+
" raise RuntimeError(\n",
|
| 199 |
+
" f\"env GET {path} failed after {self.retries} retries: {last_exc}\"\n",
|
| 200 |
+
" )\n",
|
| 201 |
+
"\n",
|
| 202 |
+
" def health(self) -> Dict[str, Any]:\n",
|
| 203 |
+
" r = requests.get(f\"{self.base_url}/health\", timeout=self.timeout)\n",
|
| 204 |
+
" r.raise_for_status()\n",
|
| 205 |
+
" return r.json()\n",
|
| 206 |
+
"\n",
|
| 207 |
+
" def list_functions(self) -> List[Dict[str, str]]:\n",
|
| 208 |
+
" \"\"\"Legacy /functions endpoint -- only the 9 builtin functions.\"\"\"\n",
|
| 209 |
+
" r = requests.get(f\"{self.base_url}/functions\", timeout=self.timeout)\n",
|
| 210 |
+
" r.raise_for_status()\n",
|
| 211 |
+
" return r.json()[\"functions\"]\n",
|
| 212 |
+
"\n",
|
| 213 |
+
" def list_tasks(\n",
|
| 214 |
+
" self,\n",
|
| 215 |
+
" source: str = \"all\",\n",
|
| 216 |
+
" difficulty: Optional[str] = None,\n",
|
| 217 |
+
" ) -> List[Dict[str, Any]]:\n",
|
| 218 |
+
" \"\"\"Live catalog: builtins + Hub-driven tasks.\"\"\"\n",
|
| 219 |
+
" params: Dict[str, Any] = {\"source\": source}\n",
|
| 220 |
+
" if difficulty:\n",
|
| 221 |
+
" params[\"difficulty\"] = difficulty\n",
|
| 222 |
+
" return self._get(\"/tasks\", params=params)[\"tasks\"]\n",
|
| 223 |
+
"\n",
|
| 224 |
+
" def sample_inputs(self, target_name: str, n: int = 8, seed: int = 0) -> List[str]:\n",
|
| 225 |
+
" \"\"\"Pull `n` ready-to-probe input_repr strings from the env's auto-fuzzer.\"\"\"\n",
|
| 226 |
+
" resp = self._get(\n",
|
| 227 |
+
" f\"/tasks/{target_name}/sample_inputs\",\n",
|
| 228 |
+
" params={\"n\": n, \"seed\": seed},\n",
|
| 229 |
+
" )\n",
|
| 230 |
+
" return list(resp[\"inputs\"])\n",
|
| 231 |
+
"\n",
|
| 232 |
+
" def reset(self, target_name: str, seed: int = 0, max_steps: int = 25) -> Dict[str, Any]:\n",
|
| 233 |
+
" return self._post(\n",
|
| 234 |
+
" \"/reset\",\n",
|
| 235 |
+
" {\"target_name\": target_name, \"seed\": seed, \"max_steps\": max_steps},\n",
|
| 236 |
+
" )\n",
|
| 237 |
+
"\n",
|
| 238 |
+
" def step(self, episode_id: str, action: Dict[str, Any]) -> Dict[str, Any]:\n",
|
| 239 |
+
" return self._post(\n",
|
| 240 |
+
" \"/step\", {\"episode_id\": episode_id, \"action\": action}\n",
|
| 241 |
+
" )\n",
|
| 242 |
+
"\n",
|
| 243 |
+
" def submit(self, episode_id: str, code: str) -> Dict[str, Any]:\n",
|
| 244 |
+
" return self.step(episode_id, {\"action_type\": \"submit\", \"code\": code})\n",
|
| 245 |
+
"\n",
|
| 246 |
+
" def probe(self, episode_id: str, input_repr: str) -> Dict[str, Any]:\n",
|
| 247 |
+
" return self.step(episode_id, {\"action_type\": \"probe\", \"input_repr\": input_repr})\n",
|
| 248 |
+
"\n",
|
| 249 |
+
" def score_submission(self, target_name: str, code: str, seed: int = 0) -> float:\n",
|
| 250 |
+
" \"\"\"One-shot: open an episode, submit the code, return total reward.\"\"\"\n",
|
| 251 |
+
" ep = self.reset(target_name=target_name, seed=seed, max_steps=2)\n",
|
| 252 |
+
" resp = self.submit(ep[\"episode_id\"], code)\n",
|
| 253 |
+
" return float(resp[\"reward\"])\n",
|
| 254 |
+
"\n",
|
| 255 |
+
"\n",
|
| 256 |
+
"client = EnvClient(base_url=ENV_URL)\n",
|
| 257 |
+
"print(\"EnvClient ready ->\", client.base_url)"
|
| 258 |
+
]
|
| 259 |
+
},
|
| 260 |
+
{
|
| 261 |
+
"cell_type": "code",
|
| 262 |
+
"execution_count": null,
|
| 263 |
+
"id": "c2e1c7e5",
|
| 264 |
+
"metadata": {},
|
| 265 |
+
"outputs": [],
|
| 266 |
+
"source": [
|
| 267 |
+
"# Self-contained, minimal copies of:\n",
|
| 268 |
+
"# /training/opensleuth_train/prompt.py\n",
|
| 269 |
+
"# /training/opensleuth_train/reward.py\n",
|
| 270 |
+
"# /training/opensleuth_train/dataset.py\n",
|
| 271 |
+
"# (kept in lockstep with the original modules; the env is the source of truth\n",
|
| 272 |
+
"# for which tasks exist and what probe inputs to use.)\n",
|
| 273 |
+
"\n",
|
| 274 |
+
"# --- prompt --------------------------------------------------------------\n",
|
| 275 |
+
"\n",
|
| 276 |
+
"SYSTEM_PROMPT = (\n",
|
| 277 |
+
" \"You are an algorithmic detective. You are given the public signature of a \"\n",
|
| 278 |
+
" \"hidden Python function plus several (input, output) examples observed by \"\n",
|
| 279 |
+
" \"probing it. Your job is to write a Python function that *exactly* \"\n",
|
| 280 |
+
" \"reproduces the hidden function's behavior on all valid inputs. Match its \"\n",
|
| 281 |
+
" \"return values AND its exception types on invalid inputs. Keep your \"\n",
|
| 282 |
+
" \"implementation as simple and clean as possible (it is penalised for \"\n",
|
| 283 |
+
" \"being needlessly branchy). Return ONLY the function definition wrapped \"\n",
|
| 284 |
+
" \"in a single ```python ... ``` code block.\"\n",
|
| 285 |
+
")\n",
|
| 286 |
+
"\n",
|
| 287 |
+
"\n",
|
| 288 |
+
"def build_prompt(target_name: str, signature: str, probes: Iterable[tuple]) -> str:\n",
|
| 289 |
+
" lines = [\n",
|
| 290 |
+
" f\"## Hidden function: {target_name}\",\n",
|
| 291 |
+
" \"\",\n",
|
| 292 |
+
" \"### Public signature & docstring\",\n",
|
| 293 |
+
" signature.strip() or \"(no signature provided)\",\n",
|
| 294 |
+
" \"\",\n",
|
| 295 |
+
" \"### Observed probes\",\n",
|
| 296 |
+
" ]\n",
|
| 297 |
+
" probe_list = list(probes)\n",
|
| 298 |
+
" if not probe_list:\n",
|
| 299 |
+
" lines.append(\"(none)\")\n",
|
| 300 |
+
" else:\n",
|
| 301 |
+
" for inp, out, is_err in probe_list:\n",
|
| 302 |
+
" tag = \"raises\" if is_err else \"returns\"\n",
|
| 303 |
+
" lines.append(f\"- input={inp} -> {tag} {out}\")\n",
|
| 304 |
+
" lines += [\n",
|
| 305 |
+
" \"\",\n",
|
| 306 |
+
" \"### Task\",\n",
|
| 307 |
+
" f\"Write a Python function named `{target_name}` that reproduces the hidden \"\n",
|
| 308 |
+
" \"function's behaviour. Return ONLY the function definition in a single \"\n",
|
| 309 |
+
" \"```python ... ``` code block. Do not add explanations.\",\n",
|
| 310 |
+
" ]\n",
|
| 311 |
+
" return \"\\n\".join(lines)\n",
|
| 312 |
+
"\n",
|
| 313 |
+
"\n",
|
| 314 |
+
"_CODE_RE = re.compile(r\"```(?:python)?\\s*(.*?)```\", re.DOTALL | re.IGNORECASE)\n",
|
| 315 |
+
"\n",
|
| 316 |
+
"\n",
|
| 317 |
+
"def extract_code(completion: str) -> str:\n",
|
| 318 |
+
" m = _CODE_RE.search(completion)\n",
|
| 319 |
+
" if m:\n",
|
| 320 |
+
" return m.group(1).strip()\n",
|
| 321 |
+
" return completion.strip()\n",
|
| 322 |
+
"\n",
|
| 323 |
+
"\n",
|
| 324 |
+
"# --- reward --------------------------------------------------------------\n",
|
| 325 |
+
"\n",
|
| 326 |
+
"_FUNC_RE = re.compile(r\"^def\\s+(\\w+)\\s*\\(\", re.MULTILINE)\n",
|
| 327 |
+
"\n",
|
| 328 |
+
"\n",
|
| 329 |
+
"def _extract_text(completion):\n",
|
| 330 |
+
" if isinstance(completion, str):\n",
|
| 331 |
+
" return completion\n",
|
| 332 |
+
" if isinstance(completion, list):\n",
|
| 333 |
+
" parts = []\n",
|
| 334 |
+
" for msg in completion:\n",
|
| 335 |
+
" if isinstance(msg, dict) and \"content\" in msg:\n",
|
| 336 |
+
" parts.append(str(msg[\"content\"]))\n",
|
| 337 |
+
" else:\n",
|
| 338 |
+
" parts.append(str(msg))\n",
|
| 339 |
+
" return \"\\n\".join(parts)\n",
|
| 340 |
+
" return str(completion)\n",
|
| 341 |
+
"\n",
|
| 342 |
+
"\n",
|
| 343 |
+
"def _index(value, i: int, default):\n",
|
| 344 |
+
" if value is None:\n",
|
| 345 |
+
" return default\n",
|
| 346 |
+
" if isinstance(value, list):\n",
|
| 347 |
+
" return value[i] if i < len(value) else default\n",
|
| 348 |
+
" return value\n",
|
| 349 |
+
"\n",
|
| 350 |
+
"\n",
|
| 351 |
+
"def make_env_reward(client: EnvClient, *, scale: float = 1.0 / 100.0):\n",
|
| 352 |
+
" \"\"\"Verifier-backed reward. Calls /step submit on the env and returns the\n",
|
| 353 |
+
" env's reward divided by `scale` (so a perfect submission ~= +1.5 and a bad\n",
|
| 354 |
+
" one ~= -0.5; keeps GRPO advantages well-behaved without normalisation).\"\"\"\n",
|
| 355 |
+
"\n",
|
| 356 |
+
" def env_reward(completions, target_function_name=None, row_seed=None, **kwargs):\n",
|
| 357 |
+
" rewards: List[float] = []\n",
|
| 358 |
+
" for i, completion in enumerate(completions):\n",
|
| 359 |
+
" text = _extract_text(completion)\n",
|
| 360 |
+
" code = extract_code(text)\n",
|
| 361 |
+
" tname = _index(target_function_name, i, default=\"fibonacci\")\n",
|
| 362 |
+
" seed = _index(row_seed, i, default=0)\n",
|
| 363 |
+
" try:\n",
|
| 364 |
+
" env_reward_value = client.score_submission(tname, code, seed=seed)\n",
|
| 365 |
+
" except Exception as e:\n",
|
| 366 |
+
" log.warning(\"env scoring failed for %s: %s\", tname, e)\n",
|
| 367 |
+
" env_reward_value = -50.0\n",
|
| 368 |
+
" rewards.append(env_reward_value * scale)\n",
|
| 369 |
+
" return rewards\n",
|
| 370 |
+
"\n",
|
| 371 |
+
" env_reward.__name__ = \"env_verifier_reward\"\n",
|
| 372 |
+
" return env_reward\n",
|
| 373 |
+
"\n",
|
| 374 |
+
"\n",
"def format_reward(completions, target_function_name=None, **kwargs):\n",
"    \"\"\"Cheap shaping reward: +0.1 if the output has a fenced python block,\n",
"    +0.1 more if it defines the right function name. Encourages the model to\n",
"    converge on the expected output format quickly so the env-verifier reward\n",
"    becomes informative early in training.\"\"\"\n",
"    rewards: List[float] = []\n",
"    for i, completion in enumerate(completions):\n",
"        text = _extract_text(completion)\n",
"        score = 0.0\n",
"        if \"```python\" in text or \"```\\n\" in text:\n",
"            score += 0.1\n",
"        code = extract_code(text)\n",
"        m = _FUNC_RE.search(code)\n",
"        tname = _index(target_function_name, i, default=None)\n",
"        if m and (tname is None or m.group(1) == tname):\n",
"            score += 0.1\n",
"        rewards.append(score)\n",
"    return rewards\n",
"\n",
"\n",
"format_reward.__name__ = \"format_reward\"\n",
"\n",
"\n",
"# --- dataset -------------------------------------------------------------\n",
"\n",
"\n",
"def discover_functions(\n",
"    client: EnvClient,\n",
"    *,\n",
"    source: str = \"all\",\n",
"    include: Optional[Sequence[str]] = None,\n",
"    difficulty: Optional[str] = None,\n",
") -> List[dict]:\n",
"    \"\"\"Live task catalog from the env. `include` filters by name;\n",
"    `difficulty` filters by easy/medium/hard.\"\"\"\n",
"    tasks = client.list_tasks(source=source)\n",
"    if difficulty and difficulty.lower() != \"all\":\n",
"        tasks = [t for t in tasks if (t.get(\"difficulty\") or \"\").lower() == difficulty.lower()]\n",
"    if include:\n",
"        wanted = {n.strip() for n in include if n and n.strip()}\n",
"        if wanted:\n",
"            tasks = [t for t in tasks if t[\"name\"] in wanted]\n",
"    if not tasks:\n",
"        raise RuntimeError(\n",
"            f\"discover_functions filtered to 0 tasks \"\n",
"            f\"(source={source!r}, include={include!r}, difficulty={difficulty!r}).\"\n",
"        )\n",
"    return tasks\n",
"\n",
"\n",
"def _make_probe_inputs(\n",
"    target_name: str,\n",
"    rng: random.Random,\n",
"    n: int,\n",
"    *,\n",
"    client: EnvClient,\n",
"    seed: int,\n",
") -> List[str]:\n",
"    \"\"\"Preferred path: ask the env for `n` ready-to-probe inputs via its\n",
"    auto-fuzzer. Fallback (if the endpoint hiccups): submit literal \"1\" probes\n",
"    so we at least populate `n` rows.\"\"\"\n",
"    try:\n",
"        return client.sample_inputs(target_name=target_name, n=n, seed=seed)\n",
"    except Exception as e:\n",
"        log.warning(\n",
"            \"env sample_inputs(%s, n=%d, seed=%s) failed: %s; falling back to literals\",\n",
"            target_name, n, seed, e,\n",
"        )\n",
"        return [\"1\"] * n\n",
"\n",
"\n",
"def _sample_probes(\n",
"    client: EnvClient,\n",
"    target_name: str,\n",
"    seed: int,\n",
"    n_probes: int,\n",
") -> tuple:\n",
"    \"\"\"Open one episode and feed it `n_probes` random valid inputs sourced\n",
"    from the env's own auto-fuzzer. Returns `(signature, history)`.\"\"\"\n",
"    rng = random.Random(seed)\n",
"    ep = client.reset(target_name=target_name, seed=seed, max_steps=n_probes + 5)\n",
"    sig = ep[\"target_function_signature\"]\n",
"    eid = ep[\"episode_id\"]\n",
"\n",
"    inputs = _make_probe_inputs(\n",
"        target_name, rng, n_probes, client=client, seed=seed,\n",
"    )\n",
"    history: List[tuple] = []\n",
"    for inp_repr in inputs:\n",
"        try:\n",
"            resp = client.probe(eid, inp_repr)\n",
"        except Exception as e:\n",
"            log.warning(\"probe failed for %s with %r: %s\", target_name, inp_repr, e)\n",
"            continue\n",
"        last = resp[\"observation\"][\"probe_history\"][-1]\n",
"        history.append(\n",
"            (last[\"input_repr\"], last[\"output_repr\"], bool(last[\"is_error\"]))\n",
"        )\n",
"    return sig, history\n",
"\n",
"\n",
"def build_synthesis_dataset(\n",
"    client: EnvClient,\n",
"    *,\n",
"    n_per_function: int,\n",
"    n_probes: int = 6,\n",
"    seed: int = 0,\n",
"    include: Optional[Sequence[str]] = None,\n",
"    difficulty: Optional[str] = None,\n",
"    tasks: Optional[Iterable[dict]] = None,\n",
") -> Dataset:\n",
"    \"\"\"Build a HuggingFace Dataset of {prompt, target_function_name} rows.\n",
"\n",
"    Uniform-N variant of /training/opensleuth_train/dataset.py: every task\n",
"    gets the same `n_per_function` rollouts. (The full trainer uses a\n",
"    difficulty-weighted schedule; we keep the Colab variant simple so the\n",
"    dataset-build phase fits in the free-tier session.)\"\"\"\n",
"    if tasks is None:\n",
"        tasks = discover_functions(\n",
"            client, include=include, difficulty=difficulty,\n",
"        )\n",
"    tasks = list(tasks)\n",
"    rows = []\n",
"    rng = random.Random(seed)\n",
"    log.info(\n",
"        \"building dataset over %d task(s); n_per_function=%d n_probes=%d\",\n",
"        len(tasks), n_per_function, n_probes,\n",
"    )\n",
"    for task in tasks:\n",
"        fn_name = task[\"name\"]\n",
"        diff = (task.get(\"difficulty\") or \"\").lower() or \"?\"\n",
"        log.info(\n",
"            \"  %-22s difficulty=%-8s rollouts=%d source=%s\",\n",
"            fn_name, diff, n_per_function, task.get(\"source\", \"?\"),\n",
"        )\n",
"        for _ in range(n_per_function):\n",
"            row_seed = rng.randrange(0, 2 ** 31)\n",
"            try:\n",
"                sig, probes = _sample_probes(client, fn_name, row_seed, n_probes)\n",
"            except Exception as e:\n",
"                log.warning(\n",
"                    \"rollout build failed for %s seed=%d: %s; skipping row\",\n",
"                    fn_name, row_seed, e,\n",
"                )\n",
"                continue\n",
"            prompt = build_prompt(fn_name, sig, probes)\n",
"            rows.append(\n",
"                {\n",
"                    \"prompt\": prompt,\n",
"                    \"target_function_name\": fn_name,\n",
"                    \"row_seed\": row_seed,\n",
"                    \"difficulty\": diff,\n",
"                }\n",
"            )\n",
"    rng.shuffle(rows)\n",
"    log.info(\"built dataset: %d rows total\", len(rows))\n",
"    return Dataset.from_list(rows)\n",
"\n",
"\n",
"print(\"prompt + reward + dataset helpers loaded\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "88230844",
"metadata": {},
"outputs": [],
"source": [
"# Sanity-check the env, list every task it exposes, and build the dataset.\n",
"print(\"--- env health ---\")\n",
"print(client.health())\n",
"\n",
"print(\"\\n--- legacy /functions (9 builtins) ---\")\n",
"for f in client.list_functions():\n",
"    print(\" -\", f.get(\"name\"), \"::\", f.get(\"signature\", \"\")[:60])\n",
"\n",
"print(\"\\n--- /tasks (full open-ended catalog) ---\")\n",
"all_tasks = client.list_tasks()\n",
"print(f\"total tasks: {len(all_tasks)}\")\n",
"for t in all_tasks:\n",
"    print(f\"  - {t['name']:<22} difficulty={t.get('difficulty', '?'):<6} source={t.get('source', '?')}\")\n",
"\n",
"print(\"\\n--- building synthesis dataset ---\")\n",
"dataset_raw = build_synthesis_dataset(\n",
"    client,\n",
"    n_per_function=N_PER_FUNCTION,\n",
"    n_probes=N_PROBES,\n",
"    seed=SEED,\n",
")\n",
"print(f\"\\ndataset rows: {len(dataset_raw)}\")\n",
"print(\"\\nsample row 0:\")\n",
"print(\"  target_function_name =\", dataset_raw[0][\"target_function_name\"])\n",
"print(\"  difficulty           =\", dataset_raw[0][\"difficulty\"])\n",
"print(\"  row_seed             =\", dataset_raw[0][\"row_seed\"])\n",
"print(\"  prompt:\")\n",
"print(\"  \" + dataset_raw[0][\"prompt\"].replace(\"\\n\", \"\\n  \"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "14ca2743",
"metadata": {},
"outputs": [],
"source": [
"from peft import LoraConfig\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig\n",
"\n",
"# T4 GPUs have no native bfloat16 support; fall back to float16 there.\n",
"compute_dtype = (\n",
"    torch.bfloat16\n",
"    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()\n",
"    else torch.float16\n",
")\n",
"\n",
"bnb_config = BitsAndBytesConfig(\n",
"    load_in_4bit=True,\n",
"    bnb_4bit_compute_dtype=compute_dtype,\n",
"    bnb_4bit_use_double_quant=True,\n",
"    bnb_4bit_quant_type=\"nf4\",\n",
")\n",
"\n",
"print(f\"loading tokenizer for {MODEL_NAME} ...\")\n",
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)\n",
"if tokenizer.pad_token is None:\n",
"    tokenizer.pad_token = tokenizer.eos_token\n",
"\n",
"print(f\"loading model {MODEL_NAME} in 4-bit ...\")\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
"    MODEL_NAME,\n",
"    quantization_config=bnb_config,\n",
"    torch_dtype=compute_dtype,\n",
"    trust_remote_code=True,\n",
"    device_map=\"auto\",\n",
")\n",
"\n",
"peft_config = LoraConfig(\n",
"    r=16,\n",
"    lora_alpha=32,\n",
"    lora_dropout=0.05,\n",
"    target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\"],\n",
"    task_type=\"CAUSAL_LM\",\n",
"    bias=\"none\",\n",
")\n",
"print(\"model + LoRA config ready\")\n",
"print(\"model device map:\", {k: str(v) for k, v in (model.hf_device_map or {}).items()} if hasattr(model, \"hf_device_map\") and model.hf_device_map else \"single-device\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "202de2fb",
"metadata": {},
"outputs": [],
"source": [
"from trl import GRPOConfig, GRPOTrainer\n",
"\n",
"# Wrap each row as a chat-template prompt list. GRPOTrainer applies the chat\n",
"# template under the hood when \"prompt\" is a list of messages.\n",
"def to_chat(row):\n",
"    return {\n",
"        \"prompt\": [\n",
"            {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
"            {\"role\": \"user\", \"content\": row[\"prompt\"]},\n",
"        ],\n",
"        \"target_function_name\": row[\"target_function_name\"],\n",
"        \"row_seed\": row[\"row_seed\"],\n",
"    }\n",
"\n",
"drop_cols = [c for c in (\"prompt\", \"difficulty\") if c in dataset_raw.column_names]\n",
"dataset = dataset_raw.map(to_chat, remove_columns=drop_cols)\n",
"print(\"dataset columns after chat-format:\", dataset.column_names)\n",
"print(\"rows:\", len(dataset))\n",
"\n",
"# GRPO requires per_device_train_batch_size to be a multiple of num_generations\n",
"# (one prompt is repeated num_generations times in the same forward pass).\n",
"assert PER_DEVICE_BATCH_SIZE % NUM_GENERATIONS == 0, (\n",
"    f\"PER_DEVICE_BATCH_SIZE ({PER_DEVICE_BATCH_SIZE}) must be a multiple of \"\n",
"    f\"NUM_GENERATIONS ({NUM_GENERATIONS}).\"\n",
")\n",
"\n",
"grpo_config = GRPOConfig(\n",
"    output_dir=OUTPUT_DIR,\n",
"    per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,\n",
"    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,\n",
"    learning_rate=LEARNING_RATE,\n",
"    num_train_epochs=NUM_TRAIN_EPOCHS,\n",
"    max_prompt_length=MAX_PROMPT_LENGTH,\n",
"    max_completion_length=MAX_COMPLETION_LENGTH,\n",
"    num_generations=NUM_GENERATIONS,\n",
"    beta=0.04,\n",
"    bf16=torch.cuda.is_bf16_supported() if torch.cuda.is_available() else False,\n",
"    fp16=False,\n",
"    logging_steps=1,\n",
"    save_steps=50,\n",
"    save_total_limit=2,\n",
"    report_to=[],\n",
"    seed=SEED,\n",
"    push_to_hub=bool(PUSH_TO_HUB_REPO),\n",
"    hub_model_id=PUSH_TO_HUB_REPO,\n",
"    hub_strategy=\"end\",\n",
"    gradient_checkpointing=True,\n",
")\n",
"\n",
"env_reward_fn = make_env_reward(client)\n",
"\n",
"trainer = GRPOTrainer(\n",
"    model=model,\n",
"    reward_funcs=[env_reward_fn, format_reward],\n",
"    args=grpo_config,\n",
"    train_dataset=dataset,\n",
"    peft_config=peft_config,\n",
"    processing_class=tokenizer,\n",
")\n",
"print(\"GRPOTrainer ready. Steps per epoch (approx):\",\n",
"      max(1, len(dataset) // (PER_DEVICE_BATCH_SIZE // NUM_GENERATIONS) // GRADIENT_ACCUMULATION_STEPS))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "03875ee7",
"metadata": {},
"outputs": [],
"source": [
"# Kick off training. On a free-tier T4 with the defaults above this should\n",
"# take roughly 15-25 minutes for one epoch over the 15-task catalog.\n",
"# You'll see GRPO logging every step: reward/env_verifier_reward, reward/format_reward,\n",
"# rewards/std, kl, loss, etc.\n",
"trainer.train()\n",
"print(\"training complete.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7bd608a9",
"metadata": {},
"outputs": [],
"source": [
"trainer.save_model(OUTPUT_DIR)\n",
"print(f\"adapter saved to {OUTPUT_DIR}\")\n",
"\n",
"# Optional push. To enable, set PUSH_TO_HUB_REPO above to e.g.\n",
"# \"your-username/opensleuth-grpo-colab\"\n",
"# and re-run this cell after a successful notebook_login() above.\n",
"if PUSH_TO_HUB_REPO:\n",
"    print(f\"pushing to hub: {PUSH_TO_HUB_REPO}\")\n",
"    trainer.push_to_hub()\n",
"    print(\"push complete.\")\n",
"else:\n",
"    print(\"PUSH_TO_HUB_REPO is None -- skipping hub push. \"\n",
"          \"Set PUSH_TO_HUB_REPO at the top of the notebook to push.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a5ab224e",
"metadata": {},
"outputs": [],
"source": [
"# Smoke-eval: run 3 episodes against the live env using the just-trained\n",
"# adapter. Each episode probes a fresh function, generates a candidate\n",
"# implementation, submits it, and prints the env's reward + the emitted code.\n",
"\n",
"EVAL_TASKS = [\"fibonacci\", \"is_palindrome\", \"digit_sum\"]\n",
"EVAL_PROBES = 6\n",
"EVAL_MAX_NEW_TOKENS = 384\n",
"\n",
"\n",
"def _gen(prompt_text: str) -> str:\n",
"    msgs = [\n",
"        {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
"        {\"role\": \"user\", \"content\": prompt_text},\n",
"    ]\n",
"    inputs = tokenizer.apply_chat_template(\n",
"        msgs, return_tensors=\"pt\", add_generation_prompt=True\n",
"    ).to(model.device)\n",
"    with torch.no_grad():\n",
"        out = model.generate(\n",
"            inputs,\n",
"            max_new_tokens=EVAL_MAX_NEW_TOKENS,\n",
"            do_sample=False,\n",
"            pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,\n",
"        )\n",
"    completion_ids = out[0, inputs.shape[1]:]\n",
"    return tokenizer.decode(completion_ids, skip_special_tokens=True)\n",
"\n",
"\n",
"for task_name in EVAL_TASKS:\n",
"    print(\"\\n\" + \"=\" * 70)\n",
"    print(f\"=== task: {task_name} ===\")\n",
"    # hash() is salted per interpreter run; derive a deterministic per-task seed instead\n",
"    task_seed = SEED + sum(map(ord, task_name)) % 1000\n",
"    sig, probes = _sample_probes(client, task_name, seed=task_seed, n_probes=EVAL_PROBES)\n",
"    user_prompt = build_prompt(task_name, sig, probes)\n",
"    completion = _gen(user_prompt)\n",
"    code = extract_code(completion)\n",
"    try:\n",
"        reward = client.score_submission(task_name, code, seed=SEED)\n",
"    except Exception as e:\n",
"        reward = float(\"nan\")\n",
"        print(f\"score_submission failed: {e}\")\n",
"    print(f\"env reward: {reward:.3f}\")\n",
"    print(\"--- emitted code ---\")\n",
"    print(code)"
]
},
{
"cell_type": "markdown",
"id": "728aaee9",
"metadata": {},
"source": [
"## Next steps\n",
"\n",
"You just trained a tiny LoRA adapter on top of `Qwen2.5-0.5B-Instruct` against the live OpenSleuth env. Some things to try next:\n",
"\n",
"- **Push to the Hub.** Set `PUSH_TO_HUB_REPO = \"your-username/opensleuth-grpo-colab\"` in the constants cell, then re-run the login and save/push cells. The adapter is tiny (LoRA on q/k/v/o), so it pushes in seconds.\n",
"- **Train longer.** Bump `N_PER_FUNCTION` to `16-24` and `NUM_TRAIN_EPOCHS` to `2-3`. On a T4 this still fits inside one Colab session.\n",
"- **Step up to 3B.** Set `MODEL_NAME = \"Qwen/Qwen2.5-3B-Instruct\"` and drop `PER_DEVICE_BATCH_SIZE` back to `2` (with `NUM_GENERATIONS=2`). You'll need Colab Pro / an A100, or just use the dedicated training Space (`anugrah55/opensleuth-training-gemini-cli`), which is configured to retrain the 3B model end-to-end.\n",
"- **Curriculum.** Pass `difficulty=\"easy\"` to `build_synthesis_dataset(...)` for an easier warm-up, then re-run with `difficulty=\"hard\"` once the format reward saturates.\n",
"- **Add tasks.** Push a row to the [`anugrah55/opensleuth-tasks`](https://huggingface.co/datasets/anugrah55/opensleuth-tasks) Hub dataset; the env hot-reloads it on its next boot (no redeploy needed), and this notebook's `discover_functions(client)` will pick the new tasks up automatically.\n",
"- **Eval externally.** The repo's `eval/run_eval.py` runs the same fuzz-tested verification headlessly; point it at your pushed adapter and the live env Space to get an apples-to-apples score against the baseline.\n",
"\n",
"### Links again\n",
"\n",
"- Env Space: https://huggingface.co/spaces/anugrah55/opensleuth-env-gemini-cli\n",
"- Training Space (full 3B retrain): https://huggingface.co/spaces/anugrah55/opensleuth-training-gemini-cli\n",
"- Task dataset (open-ended): https://huggingface.co/datasets/anugrah55/opensleuth-tasks\n",
"- Trained adapter (after you push): `https://huggingface.co/<your-username>/opensleuth-grpo-colab`"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"name": "train_opensleuth_grpo.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}