v3: multi-turn env, thinking tokens, cross-family Qwen->Llama, multi-step GRPO
Multi-turn:
- GolfObservation gains turn_number / turn_limit / prior_attempts.
- env.reset(turn_limit=N) splits the 6 held-out test examples into a
2-example feedback slice (revealed between turns) and a 4-example
scoring slice (only the FINAL turn is scored against these).
- build_agent_user_message folds prior attempts (prompt + score +
sample target outputs) into the agent's user message so it can
refine across turns.
Thinking tokens (Qwen3 only):
- --enable-thinking / --no-enable-thinking CLI flag on both train
and eval. Default ON (was OFF in v2). Llama models silently fall
back via the chat-template TypeError path.
- max_completion_length default 256 -> 768, max_new_tokens (eval)
256 -> 768 to fit the <think>...</think> block plus the final
prompt.
- extract_prompt already strips <think>...</think> defensively;
works regardless of mode.
Cross-family targeting:
- Default target flipped Qwen/Qwen3-1.7B -> meta-llama/Llama-3.2-3B-Instruct
across every training/eval/profile script.
- Agent stays Qwen3-1.7B (preserves thinking).
- Judge stays Qwen3-8B 8-bit (judge identity matters less).
Multi-step GRPO trainer (training/train_grpo_multistep.py):
- Hand-rolled trajectory-level GRPO mirroring the proven recipe in
spaces_pipeline_env/local_training/grpo_multistep.py. TRL's
GRPOTrainer is single-step; multi-turn needs custom rollouts.
- Rollout: model in the env loop at every turn, collecting per-turn
(prompt_ids, action_ids).
- REINFORCE + KL vs LoRA snapshot, group-relative advantages with
STD_FLOOR=0.1 / ADV_CLAMP=3.0.
- --sft-adapter warmstart support recommended (start from the
baseline single-step adapter).
Eval default seeds-per-task dropped from 3 to 1 — at temperature=0.0
the agent is deterministic so seeds>1 was producing bit-identical
duplicate rows.
README updated to document all of the above.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- README.md +129 -53
- models.py +39 -0
- server/prompt_golf_environment.py +112 -11
- training/build_before_after_csv.py +1 -1
- training/eval_before_after.py +17 -5
- training/hf_job_eval.sh +16 -3
- training/hf_job_profile.sh +1 -1
- training/hf_job_train.sh +1 -1
- training/hf_job_train_multistep.sh +107 -0
- training/profile_baseline.py +1 -1
- training/train_grpo.py +67 -12
- training/train_grpo_multistep.py +585 -0
|
@@ -25,14 +25,18 @@ Prompt Golf is the missing environment.
|
|
| 25 |
|
| 26 |
## How It Works
|
| 27 |
|
| 28 |
-
Each episode is one task
|
| 29 |
|
| 30 |
-
|
| 31 |
-
2. The agent outputs a **prompt string** as its action.
|
| 32 |
-
3. The env prepends that prompt to each of ~6 held-out test inputs, runs the **frozen target LLM** (greedy decoding), and scores each output with a task-specific scorer.
|
| 33 |
-
4. `reward = raw_task_score × length_factor × leakage_penalty + 0.3 × max(0, gain_over_baseline) × length_factor`, clipped to [0, 1.3].
|
| 34 |
|
| 35 |
-
|
| 36 |
|
| 37 |
## Quick Start
|
| 38 |
|
|
@@ -70,80 +74,152 @@ uvicorn server.app:app --port 8000
|
|
| 70 |
|
| 71 |
## Task Bank
|
| 72 |
|
| 73 |
-
Ships with **
|
| 74 |
|
| 75 |
-
|
|
| 76 |
-
|---|---|---|
|
| 77 |
-
|
|
| 78 |
-
|
|
| 79 |
-
|
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
| style | formal rewrite, concise rewrite | keyword coverage |
|
| 83 |
-
| reasoning | quantity comparison, event ordering | exact_label |
|
| 84 |
-
| refusal | make target decline unsafe requests | refusal detector |
|
| 85 |
|
| 86 |
Each task has:
|
| 87 |
-
- 2–3 visible train examples in the observation
|
| 88 |
-
- 6 hidden test examples used for scoring
|
| 89 |
-
- A per-task token budget (
|
| 90 |
|
| 91 |
-
|
|
|
|
|
|
|
| 92 |
|
| 93 |
## Reward Components
|
| 94 |
|
| 95 |
-
|
| 96 |
-
|---|---|---|
|
| 97 |
-
| `raw_task_score` | [0, 1] | Mean scorer output on held-out test set |
|
| 98 |
-
| `length_factor` | (0, 1] | 1.0 within budget, decays exponentially past it |
|
| 99 |
-
| `leakage_penalty` | [0, 1] | Scales toward 0 when prompt leaks held-out n-grams |
|
| 100 |
-
| `gain_over_baseline` | [-baseline, 1-baseline] | Delta vs. target's zero-shot score |
|
| 101 |
|
| 102 |
-
Final reward:
|
| 103 |
```
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
```
|
| 108 |
|
| 109 |
-
|
| 110 |
|
| 111 |
-
|
| 112 |
|
| 113 |
-
|
| 114 |
|
| 115 |
## Training
|
| 116 |
|
| 117 |
-
|
| 118 |
|
| 119 |
-
|
| 120 |
-
- Target: Qwen/Qwen2.5-0.5B-Instruct (frozen, smaller than agent so reward signal is informative)
|
| 121 |
-
- `num_generations=8`, `learning_rate=5e-6`, `beta=0.04`
|
| 122 |
-
- 500–1000 steps with a budget curriculum (start loose, tighten over training)
|
| 123 |
|
| 124 |
-
|
| 125 |
-
- **Mean
|
| 126 |
-
- **
|
| 127 |
-
- **
|
| 128 |
-
- **
|
| 129 |
|
| 130 |
## Files
|
| 131 |
|
| 132 |
```
|
| 133 |
prompt_golf_env/
|
| 134 |
-
openenv.yaml
|
| 135 |
-
models.py
|
| 136 |
-
|
|
|
|
| 137 |
pyproject.toml
|
| 138 |
server/
|
| 139 |
-
app.py
|
| 140 |
-
prompt_golf_environment.py
|
| 141 |
-
target_model.py
|
| 142 |
-
scorer.py
|
| 143 |
-
|
| 144 |
-
|
| 145 |
Dockerfile
|
| 146 |
requirements.txt
|
| 147 |
```
|
| 148 |
|
| 149 |
## Why This Environment
|
|
|
|
| 25 |
|
| 26 |
## How It Works
|
| 27 |
|
| 28 |
+
Each episode is one task. By default it's one step (single-turn). With `turn_limit > 1` it becomes multi-turn — the agent submits a prompt, sees how it performed on a feedback slice, and refines.
|
| 29 |
|
| 30 |
+
**Single-turn (default):**
|
| 31 |
|
| 32 |
+
1. `reset(task="sentiment_basic")` → env returns task description, 3 visible train examples, token budget, and target's empty-prompt baseline.
|
| 33 |
+
2. Agent outputs a **prompt string** as its action.
|
| 34 |
+
3. Env prepends the prompt to each of 6 held-out test inputs, runs the **frozen target LLM**, scores each output with the task scorer.
|
| 35 |
+
4. `reward = raw_task_score − 0.5·baseline_zero_shot − 0.002·tokens − leakage_overlap²`, clipped to `[−0.5, 1.3]`.
|
| 36 |
+
|
| 37 |
+
**Multi-turn (`turn_limit > 1`):** the 6 held-out examples are split into `feedback_ex` (2 examples revealed to the agent between turns with the target's actual output) and `scoring_ex` (4 examples that only the **final-turn** prompt is scored against). This lets the agent debug its own prompt across turns without leaking the inputs that ultimately judge it.
|
| 38 |
+
|
| 39 |
+
The test inputs are **never shown to the agent** in single-turn mode; in multi-turn the agent sees only the feedback slice's inputs/outputs. An n-gram leakage detector subtracts `LAMBDA_LEAK · leakage_overlap²` from the reward if the agent pastes held-out inputs into its prompt.
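
The feedback/scoring split described above can be sketched in a few lines. This is an illustrative sketch, not the environment's actual code: `split_test_examples` and `slice_for_turn` are hypothetical helpers, and the two constants simply restate the 2 + 4 split quoted above.

```python
from typing import List, Tuple

Example = Tuple[str, str]  # (input, expected_output)

# Split sizes taken from the description above (2 feedback + 4 scoring).
FEEDBACK_EXAMPLES = 2
SCORING_EXAMPLES = 4


def split_test_examples(
    test_ex: List[Example], turn_limit: int
) -> Tuple[List[Example], List[Example]]:
    """Return (feedback_slice, scoring_slice) for one episode.

    Hypothetical helper: single-turn episodes get no feedback slice and
    are scored on every held-out example.
    """
    if turn_limit <= 1:
        return [], list(test_ex)
    feedback = test_ex[:FEEDBACK_EXAMPLES]
    scoring = test_ex[FEEDBACK_EXAMPLES:FEEDBACK_EXAMPLES + SCORING_EXAMPLES]
    return feedback, scoring


def slice_for_turn(
    turn: int, turn_limit: int, feedback: List[Example], scoring: List[Example]
) -> List[Example]:
    """Non-final turns run on the feedback slice; the final turn is scored."""
    return scoring if turn >= turn_limit else feedback
```

Because the two slices are disjoint, anything the agent learns from feedback outputs cannot be copied directly into a prompt that is then judged on the same inputs.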
|
| 40 |
|
| 41 |
## Quick Start
|
| 42 |
|
|
|
|
| 74 |
|
| 75 |
## Task Bank
|
| 76 |
|
| 77 |
+
Ships with **87 tasks** across three banks:
|
| 78 |
|
| 79 |
+
| Bank | Count | Task types | Difficulty |
|
| 80 |
+
|---|---|---|---|
|
| 81 |
+
| v1 (`tasks.py`) | 20 | classification, extraction, format, arithmetic, translation, style, reasoning, refusal | easy / medium |
|
| 82 |
+
| v2 (`tasks_v2.py`) | 15 | acrostic, no-letter-e, yaml depth, json key order, pirate persona, Shakespearean, terminal output, etc. | hard |
|
| 83 |
+
| tough (`tasks_tough.py`) | 52 | classification_tough (10), extraction_tough (10), format_tough (8), persona_tough (8), reasoning_tough (10), adversarial_tough (6) | hard |
|
| 84 |
+
|
| 85 |
+
The "tough" bank was hand-crafted so the **minimum effective prompt is non-obvious**: the verbose hand-written prompt for each tough task is 200-300 tokens, but the target can be steered into the right format with a much shorter compressed prompt — that gap is what training is supposed to close.
|
| 86 |
|
| 87 |
Each task has:
|
| 88 |
+
- 2–3 visible train examples shown to the agent in the observation
|
| 89 |
+
- 6 hidden test examples used for scoring (split into 2 feedback + 4 scoring in multi-turn mode)
|
| 90 |
+
- A per-task token budget (60–250 tokens depending on difficulty)
|
| 91 |
|
| 92 |
+
Scorers: `exact_label`, `contains_all_substrings`, `numeric_match`, `json_contains_fields`, `valid_json_object`, `valid_yaml_depth`, `acrostic_match`, `avoid_letter`, `three_bullets`, `word_count_exact`, `stepwise_math`, `terminal_output_pattern`, `judge_criteria` (Qwen3-8B 8-bit judge), `judge_vs_expected`, `refusal_score`, etc. — all in `server/scorer.py`.
|
| 93 |
+
|
| 94 |
+
New tasks drop into the appropriate bank file.
|
| 95 |
|
| 96 |
## Reward Components
|
| 97 |
|
| 98 |
+
The rubric is **additive** (v3) for smoother gradients than the original multiplicative form:
|
| 99 |
|
|
|
|
| 100 |
```
|
| 101 |
+
reward = raw_task_score
|
| 102 |
+
− BASELINE_SUBTRACT · baseline_zero_shot_score
|
| 103 |
+
− LAMBDA_LEN · submitted_tokens
|
| 104 |
+
− LAMBDA_LEAK · leakage_overlap²
|
| 105 |
+
− short_penalty (if tokens < MIN_TOKENS_FLOOR)
|
| 106 |
+
|
| 107 |
+
clipped to [REWARD_CLIP_LOW, REWARD_CLIP_HIGH] = [-0.5, 1.3]
|
| 108 |
```
|
| 109 |
|
| 110 |
+
Defaults (`server/rubrics.py`):
|
| 111 |
+
- `LAMBDA_LEN = 0.002` — soft length penalty; ~0.1 cost on a 50-token prompt
|
| 112 |
+
- `LAMBDA_LEAK = 1.0` — full reward wiped at saturation overlap
|
| 113 |
+
- `BASELINE_SUBTRACT = 0.5` — partially normalize against the target's natural ability
|
| 114 |
+
- `MIN_TOKENS_FLOOR = 5`, `MIN_TOKENS_PENALTY = 0.25` — anti-collapse guard against degenerate 1-token prompts
|
| 115 |
+
|
| 116 |
+
Legacy `length_factor` and `leakage_penalty` fields are still emitted on the observation for plot continuity but are no longer multiplicatively composed.
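
A minimal sketch of the additive composition, using the defaults listed above. It is not the code in `server/rubrics.py`: the function name `additive_reward` is hypothetical, and treating `short_penalty` as a flat `MIN_TOKENS_PENALTY` deduction is an assumption.

```python
# Defaults quoted from the list above.
LAMBDA_LEN = 0.002
LAMBDA_LEAK = 1.0
BASELINE_SUBTRACT = 0.5
MIN_TOKENS_FLOOR = 5
MIN_TOKENS_PENALTY = 0.25
REWARD_CLIP_LOW, REWARD_CLIP_HIGH = -0.5, 1.3


def additive_reward(
    raw_task_score: float,
    baseline_zero_shot_score: float,
    submitted_tokens: int,
    leakage_overlap: float,
) -> float:
    """Compose the additive reward and clip it to the documented range."""
    reward = (
        raw_task_score
        - BASELINE_SUBTRACT * baseline_zero_shot_score
        - LAMBDA_LEN * submitted_tokens
        - LAMBDA_LEAK * leakage_overlap ** 2
    )
    # Assumption: the short-prompt guard is a flat deduction.
    if submitted_tokens < MIN_TOKENS_FLOOR:
        reward -= MIN_TOKENS_PENALTY
    return max(REWARD_CLIP_LOW, min(REWARD_CLIP_HIGH, reward))
```

For example, a 50-token prompt with a raw score of 0.9 against a 0.2 baseline and no leakage comes out at about 0.7 (0.9 minus 0.1 for the baseline minus 0.1 for length).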
|
| 117 |
|
| 118 |
+
## Models (Cross-Family Setup)
|
| 119 |
|
| 120 |
+
We deliberately pair a **Qwen agent** with a **Llama target** — testing whether prompt golf transfers across model families:
|
| 121 |
+
|
| 122 |
+
| Role | Default | Why |
|
| 123 |
+
|---|---|---|
|
| 124 |
+
| Agent (trainable) | `Qwen/Qwen3-1.7B` | Preserves Qwen3's `<think>...</think>` reasoning mode — the agent gets free reasoning scratch space (only the extracted final prompt counts toward the length-budget rubric). |
|
| 125 |
+
| Target (frozen) | `meta-llama/Llama-3.2-3B-Instruct` | The model the agent's prompts must steer. Different family = the agent has to learn Llama's idiosyncrasies (chat-template quirks, format preferences, refusal patterns) rather than its own. |
|
| 126 |
+
| Judge | `Qwen/Qwen3-8B` (8-bit via bitsandbytes, ~8 GB VRAM) | Used by `judge_criteria` / `judge_vs_expected` scorers. Identity matters less; kept on Qwen to avoid re-tuning the judge prompt. |
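
The agent row above says only the extracted final prompt counts toward the length budget. One way to drop the reasoning block before counting tokens is sketched below. `strip_think_blocks` is a hypothetical name, and this is not the repository's `extract_prompt`; it only assumes the `<think>...</think>` tag format named in the table.

```python
import re

# Non-greedy match so several think blocks are removed independently;
# DOTALL lets the block span multiple lines.
_THINK_BLOCK = re.compile(r"<think>.*?</think>", flags=re.DOTALL)


def strip_think_blocks(completion: str) -> str:
    """Return the completion with any <think>...</think> blocks removed."""
    text = _THINK_BLOCK.sub("", completion)
    # Defensive: if the opening tag was cut off, keep only the text
    # after the last dangling closing tag.
    if "</think>" in text:
        text = text.rsplit("</think>", 1)[1]
    return text.strip()
```

A completion without any tags passes through unchanged apart from whitespace trimming, so the same helper works with thinking on or off.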
|
| 127 |
+
|
| 128 |
+
Override with `PROMPT_GOLF_TARGET_MODEL`, `PROMPT_GOLF_JUDGE_MODEL`. Disable judge quantization with `PROMPT_GOLF_JUDGE_NO_QUANT=1`. CPU/CI: `PROMPT_GOLF_TARGET_BACKEND=mock` and `PROMPT_GOLF_JUDGE_BACKEND=mock`.
|
| 129 |
+
|
| 130 |
+
> **Note:** Llama-3.2 requires accepting the license on HuggingFace. Make sure your `HF_TOKEN` has access before launching.
|
| 131 |
|
| 132 |
## Training
|
| 133 |
|
| 134 |
+
Two trainers ship in `training/`:
|
| 135 |
+
|
| 136 |
+
### Single-step GRPO (`train_grpo.py`)
|
| 137 |
+
|
| 138 |
+
Standard TRL GRPOTrainer. Treats each task as a single decision (one prompt → one reward). Recommended starting config:
|
| 139 |
+
|
| 140 |
+
- Agent: `Qwen/Qwen3-1.7B` (trainable, LoRA)
|
| 141 |
+
- Target: `meta-llama/Llama-3.2-3B-Instruct` (frozen)
|
| 142 |
+
- `num_generations=8`, `learning_rate=5e-6`, `beta=0.04`, `temperature=0.9`
|
| 143 |
+
- `max_completion_length=768` (Qwen3 thinking ON by default; pass `--no-enable-thinking` to drop back to 256)
|
| 144 |
+
- 500 steps × 87 tasks × 4 seeds takes roughly 140–200 min on an L40S with the judge co-resident
|
| 145 |
+
|
| 146 |
+
Launch via `training/hf_job_train.sh` for HuggingFace Jobs.
|
| 147 |
+
|
| 148 |
+
### Multi-step GRPO (`train_grpo_multistep.py`)
|
| 149 |
+
|
| 150 |
+
Hand-rolled trajectory-level GRPO (mirrors the proven recipe from `spaces_pipeline_env/local_training/grpo_multistep.py`). Required when `turn_limit > 1` because TRL's GRPOTrainer doesn't natively support multi-step rollouts.
|
| 151 |
+
|
| 152 |
+
- Custom rollout: model generates at every env turn, collecting `(prompt_ids, action_ids)` per step
|
| 153 |
+
- Group-relative advantages with `STD_FLOOR=0.1`, `ADV_CLAMP=3.0`
|
| 154 |
+
- REINFORCE + KL vs frozen LoRA snapshot (snapshotted at start, swapped in for ref logp computation)
|
| 155 |
+
- Recommended: `--sft-adapter` warmstart from the single-step adapter — RL on a fresh policy diverges easily
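
The group-relative advantage step in the list above can be sketched as follows. This is a reading of the two constants, not the trainer's code: flooring the standard deviation at `STD_FLOOR`, clamping symmetrically at `ADV_CLAMP`, and using the population standard deviation are all assumptions, and `group_relative_advantages` is a hypothetical name.

```python
import statistics
from typing import List

STD_FLOOR = 0.1   # from the list above
ADV_CLAMP = 3.0   # from the list above


def group_relative_advantages(rewards: List[float]) -> List[float]:
    """Normalize trajectory rewards within one group of rollouts.

    Each rollout in the group shares the same task, so subtracting the
    group mean removes task difficulty from the learning signal.
    """
    mean = statistics.fmean(rewards)
    # Floor the std so near-identical rewards do not blow up the ratio.
    denom = max(statistics.pstdev(rewards), STD_FLOOR)
    return [
        max(-ADV_CLAMP, min(ADV_CLAMP, (r - mean) / denom))
        for r in rewards
    ]
```

With a floor in place, a group whose rewards are all equal yields all-zero advantages rather than a division by zero, which is the same situation the `frac_reward_zero_std` diagnostic counts.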
|
| 156 |
+
|
| 157 |
+
Launch via `training/hf_job_train_multistep.sh`.
|
| 158 |
+
|
| 159 |
+
### Pre-flight: capability profiling
|
| 160 |
+
|
| 161 |
+
Before committing GPU hours to a 500-step run, verify the target is capable on each task:
|
| 162 |
+
|
| 163 |
+
```bash
|
| 164 |
+
TARGET_MODEL=Qwen/Qwen3-1.7B bash training/hf_job_profile.sh
|
| 165 |
+
```
|
| 166 |
+
|
| 167 |
+
This runs the target with each task's verbose hand-written description and dumps `description_baseline` per task. Use the output to decide whether to keep the target, bump to a larger one, or filter dead-baseline tasks.
|
| 168 |
+
|
| 169 |
+
### Eval + demo CSV
|
| 170 |
+
|
| 171 |
+
After training, generate the side-by-side demo CSV with `verbose_prompt`, `base_prompt` (untrained), `trained_prompt` columns plus per-row accuracy/reward:
|
| 172 |
+
|
| 173 |
+
```bash
|
| 174 |
+
python training/eval_before_after.py --label base --output-json outputs/eval_base.jsonl
|
| 175 |
+
python training/eval_before_after.py --label trained --adapter <repo>/adapter_final \
|
| 176 |
+
--output-json outputs/eval_trained.jsonl
|
| 177 |
+
|
| 178 |
+
python training/build_before_after_csv.py \
|
| 179 |
+
--base-jsonl outputs/eval_base.jsonl \
|
| 180 |
+
--trained-jsonl outputs/eval_trained.jsonl \
|
| 181 |
+
--verbose-profile-csv outputs/baseline_profile.csv \
|
| 182 |
+
--output-csv outputs/before_after_prompts.csv
|
| 183 |
+
```
|
| 184 |
|
| 185 |
+
### Plots to watch
|
| 186 |
|
| 187 |
+
- **Mean reward per step** — should drift up; typical 500-step run reaches +0.3–0.5
|
| 188 |
+
- **Mean prompt tokens** — the compression story; drops from hundreds to tens
|
| 189 |
+
- **Per-category accuracy** — generalization across task families
|
| 190 |
+
- **Length factor / leakage penalty** — diagnostic signals (legacy multiplicative form)
|
| 191 |
+
- **`frac_reward_zero_std`** — fraction of GRPO groups with no intra-group variance; high means many tasks have flat baselines and contribute no gradient
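
To make the last diagnostic concrete, here is a small sketch of what it measures. It is not TRL's implementation: the function below is a hypothetical stand-in that takes one list of rewards per GRPO group, and the `eps` tolerance is an assumption.

```python
import statistics
from typing import Sequence


def frac_reward_zero_std(
    groups: Sequence[Sequence[float]], eps: float = 1e-8
) -> float:
    """Fraction of groups whose rewards have (near-)zero spread.

    Such groups produce zero group-relative advantages and therefore
    contribute no gradient.
    """
    if not groups:
        return 0.0
    flat = sum(1 for g in groups if statistics.pstdev(g) <= eps)
    return flat / len(groups)
```

A value that stays high across training suggests filtering flat tasks with the capability profiler or raising the sampling temperature.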
|
| 192 |
|
| 193 |
## Files
|
| 194 |
|
| 195 |
```
|
| 196 |
prompt_golf_env/
|
| 197 |
+
openenv.yaml # spec manifest
|
| 198 |
+
models.py # GolfAction, GolfObservation, constants
|
| 199 |
+
# (turn_limit, prior_attempts, multi-turn split sizes)
|
| 200 |
+
client.py # PromptGolfEnv (EnvClient subclass)
|
| 201 |
pyproject.toml
|
| 202 |
server/
|
| 203 |
+
app.py # FastAPI app
|
| 204 |
+
prompt_golf_environment.py # core Env: reset/step (single + multi-turn)
|
| 205 |
+
target_model.py # frozen-target wrapper (HF + mock backends)
|
| 206 |
+
scorer.py # 21+ scorers (structural + LLM judge)
|
| 207 |
+
judge.py # Qwen3-8B 8-bit judge backend
|
| 208 |
+
tasks.py # 20-task v1 bank
|
| 209 |
+
tasks_v2.py # 15-task v2 hard bank
|
| 210 |
+
tasks_tough.py # 52-task tough bank (6 categories)
|
| 211 |
+
rubrics.py # additive reward composition
|
| 212 |
Dockerfile
|
| 213 |
requirements.txt
|
| 214 |
+
training/
|
| 215 |
+
train_grpo.py # single-step TRL GRPO
|
| 216 |
+
train_grpo_multistep.py # trajectory-level GRPO (multi-turn)
|
| 217 |
+
eval_before_after.py # base + trained eval JSONL writer
|
| 218 |
+
profile_baseline.py # per-task target capability profiler
|
| 219 |
+
build_before_after_csv.py # demo CSV merger (verbose / base / trained)
|
| 220 |
+
hf_job_train.sh # single-step trainer launcher
|
| 221 |
+
hf_job_train_multistep.sh # multi-step trainer launcher
|
| 222 |
+
hf_job_profile.sh # profile launcher
|
| 223 |
```
|
| 224 |
|
| 225 |
## Why This Environment
|
|
@@ -85,6 +85,18 @@ TEST_EXAMPLES_PER_EPISODE: int = 6
|
|
| 85 |
# Number of visible train examples shown to the agent in the observation.
|
| 86 |
TRAIN_EXAMPLES_VISIBLE: int = 3
|
| 87 |
|
| 88 |
|
| 89 |
# ---------------------------------------------------------------------------
|
| 90 |
# Action
|
|
@@ -212,3 +224,30 @@ class GolfObservation(Observation):
|
|
| 212 |
"held-out set, for debugging / demo. Only populated at step."
|
| 213 |
),
|
| 214 |
)
|
| 85 |
# Number of visible train examples shown to the agent in the observation.
|
| 86 |
TRAIN_EXAMPLES_VISIBLE: int = 3
|
| 87 |
|
| 88 |
+
# --- Multi-turn ---
|
| 89 |
+
# When turn_limit > 1, the test pool is split:
|
| 90 |
+
# - first MULTITURN_FEEDBACK_EXAMPLES are shown to the agent between
|
| 91 |
+
# turns (target outputs revealed) so it can refine its prompt
|
| 92 |
+
# - the remaining MULTITURN_SCORING_EXAMPLES score ONLY the final turn
|
| 93 |
+
# This prevents the agent from overfitting its prompt to outputs it will
|
| 94 |
+
# also be scored on. Single-turn (default) skips the split and scores on
|
| 95 |
+
# the full TEST_EXAMPLES_PER_EPISODE slice, preserving v2 behavior.
|
| 96 |
+
MULTITURN_FEEDBACK_EXAMPLES: int = 2
|
| 97 |
+
MULTITURN_SCORING_EXAMPLES: int = 4
|
| 98 |
+
DEFAULT_TURN_LIMIT: int = 1
|
| 99 |
+
|
| 100 |
|
| 101 |
# ---------------------------------------------------------------------------
|
| 102 |
# Action
|
|
|
|
| 224 |
"held-out set, for debugging / demo. Only populated at step."
|
| 225 |
),
|
| 226 |
)
|
| 227 |
+
|
| 228 |
+
# --- Multi-turn fields (single-turn episodes leave these at defaults) ---
|
| 229 |
+
turn_number: int = Field(
|
| 230 |
+
default=1,
|
| 231 |
+
description=(
|
| 232 |
+
"1-indexed current turn within the episode. Always 1 for "
|
| 233 |
+
"single-turn (turn_limit=1) episodes."
|
| 234 |
+
),
|
| 235 |
+
)
|
| 236 |
+
turn_limit: int = Field(
|
| 237 |
+
default=DEFAULT_TURN_LIMIT,
|
| 238 |
+
description=(
|
| 239 |
+
"Total turns the agent has in this episode. Set via "
|
| 240 |
+
"reset(turn_limit=N). When turn_number==turn_limit, the "
|
| 241 |
+
"next step() will be terminal and scored on the held-out "
|
| 242 |
+
"scoring slice."
|
| 243 |
+
),
|
| 244 |
+
)
|
| 245 |
+
prior_attempts: List[Dict[str, Any]] = Field(
|
| 246 |
+
default_factory=list,
|
| 247 |
+
description=(
|
| 248 |
+
"History of attempts in this episode (populated only in "
|
| 249 |
+
"multi-turn episodes; empty in single-turn). Each entry: "
|
| 250 |
+
"{turn, prompt, tokens, feedback_score, sample_generations}. The "
|
| 251 |
+
"agent uses these to refine its prompt for the next turn."
|
| 252 |
+
),
|
| 253 |
+
)
|
|
@@ -37,7 +37,10 @@ from openenv.core.env_server.types import State
|
|
| 37 |
try:
|
| 38 |
from ..models import (
|
| 39 |
DEFAULT_PROMPT_BUDGET,
|
|
|
|
| 40 |
MAX_TARGET_OUTPUT_TOKENS,
|
|
|
|
|
|
|
| 41 |
TEST_EXAMPLES_PER_EPISODE,
|
| 42 |
TRAIN_EXAMPLES_VISIBLE,
|
| 43 |
GolfAction,
|
|
@@ -52,7 +55,10 @@ try:
|
|
| 52 |
except ImportError:
|
| 53 |
from models import (
|
| 54 |
DEFAULT_PROMPT_BUDGET,
|
|
|
|
| 55 |
MAX_TARGET_OUTPUT_TOKENS,
|
|
|
|
|
|
|
| 56 |
TEST_EXAMPLES_PER_EPISODE,
|
| 57 |
TRAIN_EXAMPLES_VISIBLE,
|
| 58 |
GolfAction,
|
|
@@ -97,9 +103,17 @@ class PromptGolfEnvironment(Environment):
|
|
| 97 |
# Resampled every reset
|
| 98 |
self._train_ex: List[tuple[str, str]] = []
|
| 99 |
self._test_ex: List[tuple[str, str]] = []
|
| 100 |
# Cached per-episode baseline (target with empty prompt)
|
| 101 |
self._baseline_zero_shot: float = 0.0
|
| 102 |
|
| 103 |
# Reward rubric (stateless per episode)
|
| 104 |
self._rubric = PromptGolfRubric()
|
| 105 |
|
|
@@ -115,12 +129,18 @@ class PromptGolfEnvironment(Environment):
|
|
| 115 |
seed: Optional[int] = None,
|
| 116 |
episode_id: Optional[str] = None,
|
| 117 |
task: Optional[str] = None,
|
|
|
|
| 118 |
) -> GolfObservation:
|
| 119 |
self._state = State(episode_id=episode_id or str(uuid4()), step_count=0)
|
| 120 |
self._rng = random.Random(seed) if seed is not None else random.Random()
|
| 121 |
|
| 122 |
self._task = self._choose_task(task)
|
| 123 |
|
| 124 |
# Sample visible train examples (stable for this episode)
|
| 125 |
train_pool = list(self._task.train_examples)
|
| 126 |
self._rng.shuffle(train_pool)
|
|
@@ -133,6 +153,23 @@ class PromptGolfEnvironment(Environment):
|
|
| 133 |
self._rng.shuffle(test_pool)
|
| 134 |
self._test_ex = test_pool[:TEST_EXAMPLES_PER_EPISODE]
|
| 135 |
|
| 136 |
# Compute (or reuse) baseline for this task with empty prompt
|
| 137 |
cache_key = (self._target.model_id, self._task.task_id)
|
| 138 |
if cache_key not in _BASELINE_CACHE:
|
|
@@ -146,7 +183,10 @@ class PromptGolfEnvironment(Environment):
|
|
| 146 |
target_model_id=self._target.model_id,
|
| 147 |
prompt_budget_tokens=self._task.budget_tokens or DEFAULT_PROMPT_BUDGET,
|
| 148 |
max_target_output_tokens=MAX_TARGET_OUTPUT_TOKENS,
|
| 149 |
-
num_test_examples=
|
| 150 |
train_examples=[
|
| 151 |
{"input": x, "expected": y} for (x, y) in self._train_ex
|
| 152 |
],
|
|
@@ -154,6 +194,9 @@ class PromptGolfEnvironment(Environment):
|
|
| 154 |
baseline_zero_shot_score=round(self._baseline_zero_shot, 4),
|
| 155 |
done=False,
|
| 156 |
reward=0.0,
|
| 157 |
metadata={
|
| 158 |
"task_difficulty": self._task.difficulty,
|
| 159 |
"task_tags": list(self._task.tags),
|
|
@@ -170,18 +213,67 @@ class PromptGolfEnvironment(Environment):
|
|
| 170 |
if self._task is None:
|
| 171 |
raise RuntimeError("step() called before reset()")
|
| 172 |
|
| 173 |
# Truncate prompt to the task's budget (in target tokens).
|
| 174 |
budget = self._task.budget_tokens or DEFAULT_PROMPT_BUDGET
|
| 175 |
truncated_prompt = self._target.truncate_to_tokens(action.prompt, budget)
|
| 176 |
submitted_tokens = self._target.count_prompt_tokens(truncated_prompt)
|
| 177 |
|
| 178 |
-
#
|
| 179 |
raw_task_score, sample_gens = self._score_prompt(
|
| 180 |
-
truncated_prompt, return_samples=True
|
| 181 |
)
|
| 182 |
|
| 183 |
-
#
|
| 184 |
-
|
| 185 |
result = self._rubric.grade(
|
| 186 |
raw_task_score=raw_task_score,
|
| 187 |
baseline_zero_shot_score=self._baseline_zero_shot,
|
|
@@ -192,8 +284,6 @@ class PromptGolfEnvironment(Environment):
|
|
| 192 |
)
|
| 193 |
details = grade_details_dict(result, task_id=self._task.task_id)
|
| 194 |
|
| 195 |
-
# Build terminal observation. We re-emit the task framing so the
|
| 196 |
-
# agent/trainer has a self-contained record of the episode.
|
| 197 |
return GolfObservation(
|
| 198 |
task_id=self._task.task_id,
|
| 199 |
task_category=self._task.category,
|
|
@@ -201,7 +291,7 @@ class PromptGolfEnvironment(Environment):
|
|
| 201 |
target_model_id=self._target.model_id,
|
| 202 |
prompt_budget_tokens=budget,
|
| 203 |
max_target_output_tokens=MAX_TARGET_OUTPUT_TOKENS,
|
| 204 |
-
num_test_examples=len(
|
| 205 |
train_examples=[
|
| 206 |
{"input": x, "expected": y} for (x, y) in self._train_ex
|
| 207 |
],
|
|
@@ -216,6 +306,9 @@ class PromptGolfEnvironment(Environment):
|
|
| 216 |
sample_generations=sample_gens,
|
| 217 |
done=True,
|
| 218 |
reward=round(result.reward, 4),
|
| 219 |
metadata={
|
| 220 |
"task_difficulty": self._task.difficulty,
|
| 221 |
"task_tags": list(self._task.tags),
|
|
@@ -244,15 +337,23 @@ class PromptGolfEnvironment(Environment):
|
|
| 244 |
return _ALL_TASKS[task_id]
|
| 245 |
|
| 246 |
def _score_prompt(
|
| 247 |
-
self,
|
| 248 |
) -> float | tuple[float, list]:
|
| 249 |
"""Run target on test inputs with `prompt`, score each output,
|
| 250 |
return mean score. Optionally also return up to 2 sample triples
|
| 251 |
for debugging.
|
| 252 |
"""
|
| 253 |
assert self._task is not None
|
| 254 |
-
|
| 255 |
-
|
|
|
|
| 256 |
|
| 257 |
generations: List[TargetGeneration] = self._target.generate_batch(
|
| 258 |
prompt=prompt,
|
|
|
|
| 37 |
try:
|
| 38 |
from ..models import (
|
| 39 |
DEFAULT_PROMPT_BUDGET,
|
| 40 |
+
DEFAULT_TURN_LIMIT,
|
| 41 |
MAX_TARGET_OUTPUT_TOKENS,
|
| 42 |
+
MULTITURN_FEEDBACK_EXAMPLES,
|
| 43 |
+
MULTITURN_SCORING_EXAMPLES,
|
| 44 |
TEST_EXAMPLES_PER_EPISODE,
|
| 45 |
TRAIN_EXAMPLES_VISIBLE,
|
| 46 |
GolfAction,
|
|
|
|
| 55 |
except ImportError:
|
| 56 |
from models import (
|
| 57 |
DEFAULT_PROMPT_BUDGET,
|
| 58 |
+
DEFAULT_TURN_LIMIT,
|
| 59 |
MAX_TARGET_OUTPUT_TOKENS,
|
| 60 |
+
MULTITURN_FEEDBACK_EXAMPLES,
|
| 61 |
+
MULTITURN_SCORING_EXAMPLES,
|
| 62 |
TEST_EXAMPLES_PER_EPISODE,
|
| 63 |
TRAIN_EXAMPLES_VISIBLE,
|
| 64 |
GolfAction,
|
|
|
|
| 103 |
# Resampled every reset
|
| 104 |
self._train_ex: List[tuple[str, str]] = []
|
| 105 |
self._test_ex: List[tuple[str, str]] = []
|
| 106 |
+
# Multi-turn slices (only populated when turn_limit > 1)
|
| 107 |
+
self._feedback_ex: List[tuple[str, str]] = []
|
| 108 |
+
self._scoring_ex: List[tuple[str, str]] = []
|
| 109 |
# Cached per-episode baseline (target with empty prompt)
|
| 110 |
self._baseline_zero_shot: float = 0.0
|
| 111 |
|
| 112 |
+
# Multi-turn state (single-turn defaults preserve v2 behavior)
|
| 113 |
+
self._turn_count: int = 0
|
| 114 |
+
self._turn_limit: int = DEFAULT_TURN_LIMIT
|
| 115 |
+
self._prior_attempts: List[dict] = []
|
| 116 |
+
|
| 117 |
# Reward rubric (stateless per episode)
|
| 118 |
self._rubric = PromptGolfRubric()
|
| 119 |
|
|
|
|
| 129 |
seed: Optional[int] = None,
|
| 130 |
episode_id: Optional[str] = None,
|
| 131 |
task: Optional[str] = None,
|
| 132 |
+
turn_limit: int = DEFAULT_TURN_LIMIT,
|
| 133 |
) -> GolfObservation:
|
| 134 |
self._state = State(episode_id=episode_id or str(uuid4()), step_count=0)
|
| 135 |
self._rng = random.Random(seed) if seed is not None else random.Random()
|
| 136 |
|
| 137 |
self._task = self._choose_task(task)
|
| 138 |
|
| 139 |
+
# Reset multi-turn state
|
| 140 |
+
self._turn_count = 0
|
| 141 |
+
self._turn_limit = max(1, int(turn_limit))
|
| 142 |
+
self._prior_attempts = []
|
| 143 |
+
|
| 144 |
# Sample visible train examples (stable for this episode)
|
| 145 |
train_pool = list(self._task.train_examples)
|
| 146 |
self._rng.shuffle(train_pool)
|
|
|
|
| 153 |
self._rng.shuffle(test_pool)
|
| 154 |
self._test_ex = test_pool[:TEST_EXAMPLES_PER_EPISODE]
|
| 155 |
|
| 156 |
+
# Multi-turn split: feedback slice (revealed between turns) vs
|
| 157 |
+
# scoring slice (only ever scored on the FINAL turn). Single-turn
|
| 158 |
+
# episodes leave both empty and use _test_ex as before.
|
| 159 |
+
if self._turn_limit > 1:
|
| 160 |
+
self._feedback_ex = self._test_ex[:MULTITURN_FEEDBACK_EXAMPLES]
|
| 161 |
+
self._scoring_ex = self._test_ex[
|
| 162 |
+
MULTITURN_FEEDBACK_EXAMPLES:
|
| 163 |
+
MULTITURN_FEEDBACK_EXAMPLES + MULTITURN_SCORING_EXAMPLES
|
| 164 |
+
]
|
| 165 |
+
# Guarantee a non-empty scoring slice even on tasks with few
|
| 166 |
+
# test examples — fall back to the full slice.
|
| 167 |
+
if not self._scoring_ex:
|
| 168 |
+
self._scoring_ex = list(self._test_ex)
|
| 169 |
+
else:
|
| 170 |
+
self._feedback_ex = []
|
| 171 |
+
self._scoring_ex = []
|
| 172 |
+
|
| 173 |
# Compute (or reuse) baseline for this task with empty prompt
|
| 174 |
cache_key = (self._target.model_id, self._task.task_id)
|
| 175 |
if cache_key not in _BASELINE_CACHE:
|
|
|
|
| 183 |
target_model_id=self._target.model_id,
|
| 184 |
prompt_budget_tokens=self._task.budget_tokens or DEFAULT_PROMPT_BUDGET,
|
| 185 |
max_target_output_tokens=MAX_TARGET_OUTPUT_TOKENS,
|
| 186 |
+
num_test_examples=(
|
| 187 |
+
len(self._scoring_ex) if self._turn_limit > 1
|
| 188 |
+
else len(self._test_ex)
|
| 189 |
+
),
|
| 190 |
train_examples=[
|
| 191 |
{"input": x, "expected": y} for (x, y) in self._train_ex
|
| 192 |
],
|
|
|
|
| 194 |
baseline_zero_shot_score=round(self._baseline_zero_shot, 4),
|
| 195 |
done=False,
|
| 196 |
reward=0.0,
|
| 197 |
+
turn_number=1,
|
| 198 |
+
turn_limit=self._turn_limit,
|
| 199 |
+
prior_attempts=[],
|
| 200 |
metadata={
|
| 201 |
"task_difficulty": self._task.difficulty,
|
| 202 |
"task_tags": list(self._task.tags),
|
|
|
|
| 213 |
if self._task is None:
|
| 214 |
raise RuntimeError("step() called before reset()")
|
| 215 |
|
| 216 |
+
# Bump turn counter; `is_final_turn` decides scoring slice + done-flag.
|
| 217 |
+
self._turn_count += 1
|
| 218 |
+
is_final_turn = self._turn_count >= self._turn_limit
|
| 219 |
+
|
| 220 |
# Truncate prompt to the task's budget (in target tokens).
|
| 221 |
budget = self._task.budget_tokens or DEFAULT_PROMPT_BUDGET
|
| 222 |
truncated_prompt = self._target.truncate_to_tokens(action.prompt, budget)
|
| 223 |
submitted_tokens = self._target.count_prompt_tokens(truncated_prompt)
|
| 224 |
|
| 225 |
+
# Pick the scoring slice for THIS turn:
|
| 226 |
+
# - single-turn (turn_limit=1): score on the full _test_ex (v2 behavior)
|
| 227 |
+
# - multi-turn non-final: score on _feedback_ex (cheap, revealed to agent)
|
| 228 |
+
# - multi-turn final: score on _scoring_ex (held-out, drives reward)
|
| 229 |
+
if self._turn_limit > 1:
|
| 230 |
+
scoring_slice = self._scoring_ex if is_final_turn else self._feedback_ex
|
| 231 |
+
else:
|
| 232 |
+
scoring_slice = self._test_ex
|
| 233 |
+
|
| 234 |
raw_task_score, sample_gens = self._score_prompt(
|
| 235 |
+
truncated_prompt, return_samples=True, examples=scoring_slice,
|
| 236 |
)
|
| 237 |
|
| 238 |
+
# ----- Non-final turn in multi-turn: return feedback obs (done=False) -----
|
| 239 |
+
if not is_final_turn:
|
| 240 |
+
self._prior_attempts.append({
|
| 241 |
+
"turn": self._turn_count,
|
| 242 |
+
"prompt": truncated_prompt,
|
| 243 |
+
"tokens": submitted_tokens,
|
| 244 |
+
"feedback_score": round(raw_task_score, 4),
|
| 245 |
+
"sample_generations": sample_gens,
|
| 246 |
+
})
|
| 247 |
+
return GolfObservation(
|
| 248 |
+
task_id=self._task.task_id,
|
| 249 |
+
task_category=self._task.category,
|
| 250 |
+
task_description=self._task.description,
|
| 251 |
+
target_model_id=self._target.model_id,
|
| 252 |
+
prompt_budget_tokens=budget,
|
| 253 |
+
max_target_output_tokens=MAX_TARGET_OUTPUT_TOKENS,
|
| 254 |
+
num_test_examples=len(self._scoring_ex),
|
| 255 |
+
train_examples=[
|
| 256 |
+
{"input": x, "expected": y} for (x, y) in self._train_ex
|
| 257 |
+
],
|
| 258 |
+
scorer_name=self._task.scorer,
|
| 259 |
+
baseline_zero_shot_score=round(self._baseline_zero_shot, 4),
|
| 260 |
+
submitted_prompt_tokens=submitted_tokens,
|
| 261 |
+
raw_task_score=round(raw_task_score, 4), # on feedback slice
|
| 262 |
+
sample_generations=sample_gens,
|
| 263 |
+
done=False,
|
| 264 |
+
reward=0.0, # no reward until terminal
|
| 265 |
+
turn_number=self._turn_count + 1, # next turn
|
| 266 |
+
turn_limit=self._turn_limit,
|
| 267 |
+
prior_attempts=list(self._prior_attempts),
|
| 268 |
+
metadata={
|
| 269 |
+
"task_difficulty": self._task.difficulty,
|
| 270 |
+
"task_tags": list(self._task.tags),
|
| 271 |
+
"is_intermediate_feedback": True,
|
| 272 |
+
},
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
# ----- Final (or single-turn): apply rubric, return terminal obs -----
|
| 276 |
+
held_out_inputs = [x for x, _ in scoring_slice]
|
| 277 |
result = self._rubric.grade(
|
| 278 |
raw_task_score=raw_task_score,
|
| 279 |
baseline_zero_shot_score=self._baseline_zero_shot,
|
|
|
|
| 284 |
)
|
| 285 |
details = grade_details_dict(result, task_id=self._task.task_id)
|
| 286 |
|
|
|
|
|
|
|
| 287 |
return GolfObservation(
|
| 288 |
task_id=self._task.task_id,
|
| 289 |
task_category=self._task.category,
|
|
|
|
| 291 |
target_model_id=self._target.model_id,
|
| 292 |
prompt_budget_tokens=budget,
|
| 293 |
max_target_output_tokens=MAX_TARGET_OUTPUT_TOKENS,
|
| 294 |
+
num_test_examples=len(scoring_slice),
|
| 295 |
train_examples=[
|
| 296 |
{"input": x, "expected": y} for (x, y) in self._train_ex
|
| 297 |
],
|
|
|
|
| 306 |
sample_generations=sample_gens,
|
| 307 |
done=True,
|
| 308 |
reward=round(result.reward, 4),
|
| 309 |
+
turn_number=self._turn_count,
|
| 310 |
+
turn_limit=self._turn_limit,
|
| 311 |
+
prior_attempts=list(self._prior_attempts),
|
| 312 |
metadata={
|
| 313 |
"task_difficulty": self._task.difficulty,
|
| 314 |
"task_tags": list(self._task.tags),
|
|
|
|
| 337 |
return _ALL_TASKS[task_id]
|
| 338 |
|
| 339 |
def _score_prompt(
|
| 340 |
+
self,
|
| 341 |
+
prompt: str,
|
| 342 |
+
return_samples: bool = False,
|
| 343 |
+
examples: Optional[List[tuple[str, str]]] = None,
|
| 344 |
) -> float | tuple[float, list]:
|
| 345 |
"""Run target on test inputs with `prompt`, score each output,
|
| 346 |
return mean score. Optionally also return up to 2 sample triples
|
| 347 |
for debugging.
|
| 348 |
+
|
| 349 |
+
`examples` overrides the default `self._test_ex` slice — used by
|
| 350 |
+
multi-turn step() to score against the feedback or scoring slice
|
| 351 |
+
rather than the full pool.
|
| 352 |
"""
|
| 353 |
assert self._task is not None
|
| 354 |
+
ex_pool = examples if examples is not None else self._test_ex
|
| 355 |
+
test_inputs = [x for x, _ in ex_pool]
|
| 356 |
+
test_expected = [y for _, y in ex_pool]
|
| 357 |
|
| 358 |
generations: List[TargetGeneration] = self._target.generate_batch(
|
| 359 |
prompt=prompt,
|
|
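The slice selection in `step()` above can be sketched in isolation. This is a minimal illustration, not the environment's actual code: `split_held_out` and `pick_scoring_slice` are hypothetical helper names, and taking the first two held-out examples as the feedback slice is an assumption (the real `reset()` may choose them differently).

```python
from typing import List, Sequence, Tuple

Example = Tuple[str, str]  # (input, expected)


def split_held_out(
    test_ex: Sequence[Example], n_feedback: int = 2
) -> Tuple[List[Example], List[Example]]:
    """Split held-out examples into (feedback, scoring) slices.

    Assumption: the first `n_feedback` examples become the feedback slice.
    """
    if not 0 < n_feedback < len(test_ex):
        raise ValueError("n_feedback must leave at least one scoring example")
    return list(test_ex[:n_feedback]), list(test_ex[n_feedback:])


def pick_scoring_slice(
    turn_limit: int,
    is_final_turn: bool,
    test_ex: List[Example],
    feedback_ex: List[Example],
    scoring_ex: List[Example],
) -> List[Example]:
    """Mirror the three cases documented in the step() comments."""
    if turn_limit > 1:
        # Multi-turn: intermediate turns see the revealed feedback slice,
        # only the final turn is scored on the held-out scoring slice.
        return scoring_ex if is_final_turn else feedback_ex
    # Single-turn: score on the full held-out pool (v2 behavior).
    return test_ex


held_out = [(f"in{i}", f"out{i}") for i in range(6)]
feedback, scoring = split_held_out(held_out)
final_slice = pick_scoring_slice(3, True, held_out, feedback, scoring)
```

Keeping the feedback and scoring slices disjoint means the reward-driving examples are never shown to the agent between turns.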
@@ -48,7 +48,7 @@ def parse_args() -> argparse.Namespace:
|
|
| 48 |
"verbose_accuracy (target's accuracy when given "
|
| 49 |
"the hand-written description as the prompt). "
|
| 50 |
"If omitted, verbose_accuracy is left blank.")
|
| 51 |
-
p.add_argument("--target-model", default="
|
| 52 |
help="Used to count tokens of the verbose description.")
|
| 53 |
p.add_argument("--output-csv", default="outputs/before_after_prompts.csv")
|
| 54 |
p.add_argument("--push-to-hub", default=None,
|
|
|
|
| 48 |
"verbose_accuracy (target's accuracy when given "
|
| 49 |
"the hand-written description as the prompt). "
|
| 50 |
"If omitted, verbose_accuracy is left blank.")
|
| 51 |
+
p.add_argument("--target-model", default="meta-llama/Llama-3.2-3B-Instruct",
|
| 52 |
help="Used to count tokens of the verbose description.")
|
| 53 |
p.add_argument("--output-csv", default="outputs/before_after_prompts.csv")
|
| 54 |
p.add_argument("--push-to-hub", default=None,
|
|
@@ -42,7 +42,7 @@ def parse_args() -> argparse.Namespace:
|
|
| 42 |
p.add_argument("--agent-model", default="Qwen/Qwen3-1.7B")
|
| 43 |
p.add_argument("--adapter", default=None,
|
| 44 |
help="Optional LoRA adapter dir or HF repo id.")
|
| 45 |
-
p.add_argument("--target-model", default="
|
| 46 |
p.add_argument("--tasks", default="all",
|
| 47 |
help="'all' or comma-separated task ids.")
|
| 48 |
p.add_argument("--seeds-per-task", type=int, default=1,
|
|
@@ -53,7 +53,16 @@ def parse_args() -> argparse.Namespace:
|
|
| 53 |
p.add_argument("--output-json", default="outputs/eval_results.jsonl")
|
| 54 |
p.add_argument("--label", default="base",
|
| 55 |
help="Label to tag this eval run (e.g. 'base', 'trained').")
|
| 56 |
-
p.add_argument("--max-new-tokens", type=int, default=
|
| 57 |
p.add_argument("--temperature", type=float, default=0.0)
|
| 58 |
p.add_argument("--push-to-hub", default=None,
|
| 59 |
help="HF model repo id to upload the eval JSONL under evals/eval_<label>.jsonl.")
|
|
@@ -86,16 +95,19 @@ def load_agent(agent_model: str, adapter: str | None):
|
|
| 86 |
return model, tok
|
| 87 |
|
| 88 |
|
| 89 |
-
def build_chat_string(tok, obs) -> str:
|
| 90 |
messages = [
|
| 91 |
{"role": "system", "content": SYSTEM_PROMPT},
|
| 92 |
{"role": "user", "content": build_agent_user_message(obs)},
|
| 93 |
]
|
| 94 |
if getattr(tok, "chat_template", None):
|
| 95 |
try:
|
|
|
|
|
|
|
|
|
|
| 96 |
return tok.apply_chat_template(
|
| 97 |
messages, tokenize=False, add_generation_prompt=True,
|
| 98 |
-
enable_thinking=
|
| 99 |
)
|
| 100 |
except TypeError:
|
| 101 |
return tok.apply_chat_template(
|
|
@@ -157,7 +169,7 @@ def main() -> None:
|
|
| 157 |
for task_id in task_ids:
|
| 158 |
for seed in range(args.seeds_per_task):
|
| 159 |
obs = env.reset(task=task_id, seed=seed)
|
| 160 |
-
chat_str = build_chat_string(tok, obs)
|
| 161 |
agent_prompt = generate_prompt(
|
| 162 |
model, tok, chat_str,
|
| 163 |
max_new_tokens=args.max_new_tokens,
|
|
|
|
| 42 |
p.add_argument("--agent-model", default="Qwen/Qwen3-1.7B")
|
| 43 |
p.add_argument("--adapter", default=None,
|
| 44 |
help="Optional LoRA adapter dir or HF repo id.")
|
| 45 |
+
p.add_argument("--target-model", default="meta-llama/Llama-3.2-3B-Instruct")
|
| 46 |
p.add_argument("--tasks", default="all",
|
| 47 |
help="'all' or comma-separated task ids.")
|
| 48 |
p.add_argument("--seeds-per-task", type=int, default=1,
|
|
|
|
| 53 |
p.add_argument("--output-json", default="outputs/eval_results.jsonl")
|
| 54 |
p.add_argument("--label", default="base",
|
| 55 |
help="Label to tag this eval run (e.g. 'base', 'trained').")
|
| 56 |
+
p.add_argument("--max-new-tokens", type=int, default=768,
|
| 57 |
+
help="Bumped from 256 to fit Qwen3's <think>...</think> "
|
| 58 |
+
"block (200-600 tokens) plus the final prompt. "
|
| 59 |
+
"Drop back to 256 if running with thinking=OFF.")
|
| 60 |
+
p.add_argument("--enable-thinking", action="store_true", default=True,
|
| 61 |
+
help="Apply Qwen3 chat template with thinking ON. "
|
| 62 |
+
"Default. Use --no-enable-thinking when evaluating "
|
| 63 |
+
"an adapter that was TRAINED with thinking=False.")
|
| 64 |
+
p.add_argument("--no-enable-thinking", dest="enable_thinking",
|
| 65 |
+
action="store_false")
|
| 66 |
p.add_argument("--temperature", type=float, default=0.0)
|
| 67 |
p.add_argument("--push-to-hub", default=None,
|
| 68 |
help="HF model repo id to upload the eval JSONL under evals/eval_<label>.jsonl.")
|
|
|
|
| 95 |
return model, tok
|
| 96 |
|
| 97 |
|
| 98 |
+
def build_chat_string(tok, obs, enable_thinking: bool = True) -> str:
|
| 99 |
messages = [
|
| 100 |
{"role": "system", "content": SYSTEM_PROMPT},
|
| 101 |
{"role": "user", "content": build_agent_user_message(obs)},
|
| 102 |
]
|
| 103 |
if getattr(tok, "chat_template", None):
|
| 104 |
try:
|
| 105 |
+
# Mirror the chat template the adapter was trained against.
|
| 106 |
+
# Pass --no-enable-thinking when evaluating a thinking=False
|
| 107 |
+
# adapter to keep eval-time inputs in-distribution.
|
| 108 |
return tok.apply_chat_template(
|
| 109 |
messages, tokenize=False, add_generation_prompt=True,
|
| 110 |
+
enable_thinking=enable_thinking,
|
| 111 |
)
|
| 112 |
except TypeError:
|
| 113 |
return tok.apply_chat_template(
|
|
|
|
| 169 |
for task_id in task_ids:
|
| 170 |
for seed in range(args.seeds_per_task):
|
| 171 |
obs = env.reset(task=task_id, seed=seed)
|
| 172 |
+
chat_str = build_chat_string(tok, obs, enable_thinking=args.enable_thinking)
|
| 173 |
agent_prompt = generate_prompt(
|
| 174 |
model, tok, chat_str,
|
| 175 |
max_new_tokens=args.max_new_tokens,
|
|
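The paired `--enable-thinking` / `--no-enable-thinking` flags above are declared as two separate arguments sharing one `dest`. A possible alternative, shown here only as a sketch and not as what the script does, is `argparse.BooleanOptionalAction` (Python 3.9+), which generates both spellings from a single declaration. The parser below is a stripped-down stand-in for the eval script's `parse_args`.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Stand-in parser carrying only the two flags discussed above."""
    p = argparse.ArgumentParser(description="thinking-flag sketch")
    # One declaration yields both --enable-thinking and --no-enable-thinking.
    p.add_argument(
        "--enable-thinking",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Apply the chat template with thinking ON (default).",
    )
    p.add_argument("--max-new-tokens", type=int, default=768)
    return p


parser = build_parser()
default_args = parser.parse_args([])
off_args = parser.parse_args(["--no-enable-thinking", "--max-new-tokens", "256"])
```

Both forms expose the same command-line surface, so launcher scripts that pass `--enable-thinking` or `--no-enable-thinking` would work unchanged.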
@@ -16,14 +16,26 @@ REPO_URL="${REPO_URL:-https://huggingface.co/spaces/rishabh16196/prompt_golf_env
|
|
| 16 |
REPO_REF="${REPO_REF:-main}"
|
| 17 |
ADAPTER_REPO="${ADAPTER_REPO:-rishabh16196/prompt-golf-grpo-1.5b}"
|
| 18 |
|
| 19 |
-
AGENT_MODEL="${AGENT_MODEL:-Qwen/
|
| 20 |
-
TARGET_MODEL="${TARGET_MODEL:-Qwen/
|
| 21 |
-
|
| 22 |
|
| 23 |
FLAVOR="${FLAVOR:-l40sx1}"
|
| 24 |
TIMEOUT="${TIMEOUT:-1h}"
|
| 25 |
IMAGE="${IMAGE:-pytorch/pytorch:2.4.0-cuda12.4-cudnn9-runtime}"
|
| 26 |
|
| 27 |
run_eval() {
|
| 28 |
local LABEL=$1
|
| 29 |
local EXTRA_FLAGS=$2
|
|
@@ -51,6 +63,7 @@ python -u training/eval_before_after.py \
|
|
| 51 |
--agent-model ${AGENT_MODEL} \
|
| 52 |
--target-model ${TARGET_MODEL} \
|
| 53 |
--seeds-per-task ${SEEDS_PER_TASK} \
|
|
|
|
| 54 |
--label ${LABEL} \
|
| 55 |
--output-json /app/outputs/eval_${LABEL}.jsonl \
|
| 56 |
--push-to-hub ${ADAPTER_REPO} \
|
|
|
|
| 16 |
REPO_REF="${REPO_REF:-main}"
|
| 17 |
ADAPTER_REPO="${ADAPTER_REPO:-rishabh16196/prompt-golf-grpo-1.5b}"
|
| 18 |
|
| 19 |
+
AGENT_MODEL="${AGENT_MODEL:-Qwen/Qwen3-1.7B}"
|
| 20 |
+
TARGET_MODEL="${TARGET_MODEL:-Qwen/Qwen3-1.7B}"
|
| 21 |
+
# Eval is deterministic at temperature=0; seeds>1 produces bit-identical
|
| 22 |
+
# duplicate rows. Override only when running with temperature>0.
|
| 23 |
+
SEEDS_PER_TASK="${SEEDS_PER_TASK:-1}"
|
| 24 |
+
# Match the chat template the adapter was TRAINED against. The
|
| 25 |
+
# in-flight v2 adapter trained with thinking=OFF; v3 cross-family runs
|
| 26 |
+
# will train with thinking=ON. Override accordingly.
|
| 27 |
+
ENABLE_THINKING="${ENABLE_THINKING:-false}"
|
| 28 |
|
| 29 |
FLAVOR="${FLAVOR:-l40sx1}"
|
| 30 |
TIMEOUT="${TIMEOUT:-1h}"
|
| 31 |
IMAGE="${IMAGE:-pytorch/pytorch:2.4.0-cuda12.4-cudnn9-runtime}"
|
| 32 |
|
| 33 |
+
# Build conditional thinking flag
|
| 34 |
+
THINKING_FLAG="--no-enable-thinking"
|
| 35 |
+
if [[ "${ENABLE_THINKING}" == "true" || "${ENABLE_THINKING}" == "True" ]]; then
|
| 36 |
+
THINKING_FLAG="--enable-thinking"
|
| 37 |
+
fi
|
| 38 |
+
|
| 39 |
run_eval() {
|
| 40 |
local LABEL=$1
|
| 41 |
local EXTRA_FLAGS=$2
|
|
|
|
| 63 |
--agent-model ${AGENT_MODEL} \
|
| 64 |
--target-model ${TARGET_MODEL} \
|
| 65 |
--seeds-per-task ${SEEDS_PER_TASK} \
|
| 66 |
+
${THINKING_FLAG} \
|
| 67 |
--label ${LABEL} \
|
| 68 |
--output-json /app/outputs/eval_${LABEL}.jsonl \
|
| 69 |
--push-to-hub ${ADAPTER_REPO} \
|
|
@@ -14,7 +14,7 @@ REPO_URL="${REPO_URL:-https://huggingface.co/spaces/rishabh16196/prompt_golf_env
|
|
| 14 |
REPO_REF="${REPO_REF:-main}"
|
| 15 |
PUSH_TO_HUB="${PUSH_TO_HUB:-rishabh16196/prompt-golf-grpo-1.5b}"
|
| 16 |
|
| 17 |
-
TARGET_MODEL="${TARGET_MODEL:-
|
| 18 |
TASKS="${TASKS:-all}"
|
| 19 |
|
| 20 |
FLAVOR="${FLAVOR:-l4x1}" # smaller flavor — no agent, no judge, no GRPO
|
|
|
|
| 14 |
REPO_REF="${REPO_REF:-main}"
|
| 15 |
PUSH_TO_HUB="${PUSH_TO_HUB:-rishabh16196/prompt-golf-grpo-1.5b}"
|
| 16 |
|
| 17 |
+
TARGET_MODEL="${TARGET_MODEL:-meta-llama/Llama-3.2-3B-Instruct}"
|
| 18 |
TASKS="${TASKS:-all}"
|
| 19 |
|
| 20 |
FLAVOR="${FLAVOR:-l4x1}" # smaller flavor — no agent, no judge, no GRPO
|
|
@@ -24,7 +24,7 @@ PUSH_TO_HUB="${PUSH_TO_HUB:-rishabh16196/prompt-golf-grpo-1.5b}"
|
|
| 24 |
# hard dep via TRL's newer import path; installing vllm on top of the
|
| 25 |
# current image is flaky. Revisit for v3.
|
| 26 |
AGENT_MODEL="${AGENT_MODEL:-Qwen/Qwen3-1.7B}"
|
| 27 |
-
TARGET_MODEL="${TARGET_MODEL:-
|
| 28 |
JUDGE_MODEL="${JUDGE_MODEL:-Qwen/Qwen3-8B}"
|
| 29 |
|
| 30 |
MAX_STEPS="${MAX_STEPS:-500}"
|
|
|
|
| 24 |
# hard dep via TRL's newer import path; installing vllm on top of the
|
| 25 |
# current image is flaky. Revisit for v3.
|
| 26 |
AGENT_MODEL="${AGENT_MODEL:-Qwen/Qwen3-1.7B}"
|
| 27 |
+
TARGET_MODEL="${TARGET_MODEL:-meta-llama/Llama-3.2-3B-Instruct}"
|
| 28 |
JUDGE_MODEL="${JUDGE_MODEL:-Qwen/Qwen3-8B}"
|
| 29 |
|
| 30 |
MAX_STEPS="${MAX_STEPS:-500}"
|
|
@@ -0,0 +1,107 @@
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
#
|
| 3 |
+
# Launch multi-step GRPO training on HuggingFace Jobs. Hand-rolled
|
| 4 |
+
# trajectory-level GRPO loop (custom rollout + REINFORCE + KL); used
|
| 5 |
+
# when turn_limit > 1 and TRL's single-step GRPOTrainer cannot do
|
| 6 |
+
# the job.
|
| 7 |
+
#
|
| 8 |
+
# Mirrors hf_job_train.sh's install pattern verbatim — same OpenEnv-
|
| 9 |
+
# official torch/transformers/trl pin so the env loads identically.
|
| 10 |
+
|
| 11 |
+
set -euo pipefail
|
| 12 |
+
|
| 13 |
+
# -------- Configuration --------
|
| 14 |
+
REPO_URL="${REPO_URL:-https://huggingface.co/spaces/rishabh16196/prompt_golf_env}"
|
| 15 |
+
REPO_REF="${REPO_REF:-main}"
|
| 16 |
+
PUSH_TO_HUB="${PUSH_TO_HUB:-rishabh16196/prompt-golf-grpo-multistep}"
|
| 17 |
+
|
| 18 |
+
AGENT_MODEL="${AGENT_MODEL:-Qwen/Qwen3-1.7B}"
|
| 19 |
+
TARGET_MODEL="${TARGET_MODEL:-meta-llama/Llama-3.2-3B-Instruct}"
|
| 20 |
+
JUDGE_MODEL="${JUDGE_MODEL:-Qwen/Qwen3-8B}"
|
| 21 |
+
SFT_ADAPTER="${SFT_ADAPTER:-}" # optional warmstart from a single-step adapter
|
| 22 |
+
|
| 23 |
+
# Multi-step GRPO knobs (smaller defaults than train.sh because
|
| 24 |
+
# trajectories cost ~turn_limit× more per step).
|
| 25 |
+
MAX_STEPS="${MAX_STEPS:-200}"
|
| 26 |
+
NUM_GENS="${NUM_GENS:-4}"
|
| 27 |
+
BATCH_SIZE="${BATCH_SIZE:-2}"
|
| 28 |
+
LR="${LR:-3e-6}"
|
| 29 |
+
BETA="${BETA:-0.04}"
|
| 30 |
+
TURN_LIMIT="${TURN_LIMIT:-3}"
|
| 31 |
+
ENABLE_THINKING="${ENABLE_THINKING:-true}"
|
| 32 |
+
|
| 33 |
+
FLAVOR="${FLAVOR:-l40sx1}"
|
| 34 |
+
TIMEOUT="${TIMEOUT:-5h}"
|
| 35 |
+
IMAGE="${IMAGE:-pytorch/pytorch:2.4.0-cuda12.4-cudnn9-runtime}"
|
| 36 |
+
|
| 37 |
+
echo "[hf-jobs] repo=$REPO_URL@$REPO_REF"
|
| 38 |
+
echo "[hf-jobs] agent=$AGENT_MODEL target=$TARGET_MODEL judge=$JUDGE_MODEL"
|
| 39 |
+
echo "[hf-jobs] sft_adapter=${SFT_ADAPTER:-(none)}"
|
| 40 |
+
echo "[hf-jobs] turn_limit=$TURN_LIMIT enable_thinking=$ENABLE_THINKING"
|
| 41 |
+
echo "[hf-jobs] steps=$MAX_STEPS gens=$NUM_GENS B=$BATCH_SIZE lr=$LR beta=$BETA"
|
| 42 |
+
echo "[hf-jobs] flavor=$FLAVOR timeout=$TIMEOUT push_to_hub=$PUSH_TO_HUB"
|
| 43 |
+
|
| 44 |
+
# Build CLI tail conditionally (--no-enable-thinking when ENABLE_THINKING=false,
|
| 45 |
+
# --sft-adapter only when set).
|
| 46 |
+
THINKING_FLAG="--enable-thinking"
|
| 47 |
+
if [[ "${ENABLE_THINKING}" == "false" || "${ENABLE_THINKING}" == "False" ]]; then
|
| 48 |
+
THINKING_FLAG="--no-enable-thinking"
|
| 49 |
+
fi
|
| 50 |
+
SFT_FLAG=""
|
| 51 |
+
if [[ -n "${SFT_ADAPTER}" ]]; then
|
| 52 |
+
SFT_FLAG="--sft-adapter ${SFT_ADAPTER}"
|
| 53 |
+
fi
|
| 54 |
+
|
| 55 |
+
read -r -d '' JOB_CMD <<EOF || true
|
| 56 |
+
set -euo pipefail
|
| 57 |
+
|
| 58 |
+
apt-get update -qq
|
| 59 |
+
apt-get install -y -qq git curl build-essential
|
| 60 |
+
|
| 61 |
+
pip install --upgrade -q uv
|
| 62 |
+
|
| 63 |
+
uv pip install --system -q \\
|
| 64 |
+
"torch>=2.8.0" "torchvision>=0.25.0" "triton>=3.4.0" bitsandbytes \\
|
| 65 |
+
"transformers==4.56.2" \\
|
| 66 |
+
"unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo" \\
|
| 67 |
+
"unsloth[base] @ git+https://github.com/unslothai/unsloth"
|
| 68 |
+
|
| 69 |
+
uv pip install --system --upgrade --no-deps -q \\
|
| 70 |
+
"transformers==4.56.2" tokenizers "trl==0.22.2" unsloth unsloth_zoo
|
| 71 |
+
|
| 72 |
+
git clone --depth 1 --branch ${REPO_REF} ${REPO_URL} /app
|
| 73 |
+
cd /app
|
| 74 |
+
pip install -q --no-deps -e .
|
| 75 |
+
|
| 76 |
+
pip install -q 'openenv-core[core]>=0.2.2' \\
|
| 77 |
+
'peft>=0.13.0' 'datasets>=3.0.0' 'accelerate>=0.34.0' \\
|
| 78 |
+
'huggingface_hub>=0.26.0' 'safetensors>=0.4.0' matplotlib
|
| 79 |
+
|
| 80 |
+
python -c "import torch; print('torch:', torch.__version__, '| cuda:', torch.cuda.is_available())"
|
| 81 |
+
|
| 82 |
+
python -u training/train_grpo_multistep.py \\
|
| 83 |
+
--agent-model ${AGENT_MODEL} \\
|
| 84 |
+
--target-model ${TARGET_MODEL} \\
|
| 85 |
+
--judge-model ${JUDGE_MODEL} \\
|
| 86 |
+
--turn-limit ${TURN_LIMIT} \\
|
| 87 |
+
${THINKING_FLAG} \\
|
| 88 |
+
--max-steps ${MAX_STEPS} \\
|
| 89 |
+
--num-gens ${NUM_GENS} \\
|
| 90 |
+
--batch-size ${BATCH_SIZE} \\
|
| 91 |
+
--lr ${LR} \\
|
| 92 |
+
--beta ${BETA} \\
|
| 93 |
+
--output-dir /app/outputs/grpo_multistep \\
|
| 94 |
+
${SFT_FLAG} \\
|
| 95 |
+
${PUSH_TO_HUB:+--push-to-hub ${PUSH_TO_HUB}}
|
| 96 |
+
echo "[hf-jobs] done."
|
| 97 |
+
EOF
|
| 98 |
+
|
| 99 |
+
hf jobs run \
|
| 100 |
+
--flavor "${FLAVOR}" \
|
| 101 |
+
--timeout "${TIMEOUT}" \
|
| 102 |
+
--detach \
|
| 103 |
+
--secrets HF_TOKEN \
|
| 104 |
+
--env HF_HUB_ENABLE_HF_TRANSFER=1 \
|
| 105 |
+
--env TRANSFORMERS_VERBOSITY=warning \
|
| 106 |
+
"${IMAGE}" \
|
| 107 |
+
-- bash -c "${JOB_CMD}"
|
|
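The launcher above drives a trajectory-level GRPO loop whose group-relative advantages use `STD_FLOOR=0.1` and `ADV_CLAMP=3.0` according to the commit message. The exact formula is not visible in this diff, so the following is a sketch under stated assumptions: advantages are computed as `(reward - group_mean) / max(group_std, STD_FLOOR)` with the population standard deviation, then clipped to `[-ADV_CLAMP, ADV_CLAMP]`. The function name is hypothetical.

```python
import statistics
from typing import List

STD_FLOOR = 0.1   # constant named in the commit message
ADV_CLAMP = 3.0   # constant named in the commit message


def group_relative_advantages(rewards: List[float]) -> List[float]:
    """Normalize one group's trajectory rewards into clamped advantages.

    Assumed formula (not confirmed by the source): mean-centre, divide by
    the floored population std, then clip symmetrically.
    """
    if not rewards:
        return []
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    # The floor stops near-identical rewards from producing huge advantages.
    denom = max(std, STD_FLOOR)
    return [
        max(-ADV_CLAMP, min(ADV_CLAMP, (r - mean) / denom))
        for r in rewards
    ]


# One group of NUM_GENS=4 trajectories for the same task, final-turn rewards.
advantages = group_relative_advantages([0.0, 0.0, 0.0, 1.0])
```

With this normalization a group whose trajectories all earn the same reward contributes zero advantage, so it produces no policy-gradient signal for that step.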
@@ -34,7 +34,7 @@ sys.path.insert(0, str(_REPO_ROOT))
|
|
| 34 |
|
| 35 |
def parse_args() -> argparse.Namespace:
|
| 36 |
p = argparse.ArgumentParser(description="Per-task target-capability profiler")
|
| 37 |
-
p.add_argument("--target-model", default="
|
| 38 |
p.add_argument("--target-backend", default="hf",
|
| 39 |
help="hf | mock (mock for local dev only)")
|
| 40 |
p.add_argument("--tasks", default="all",
|
|
|
|
| 34 |
|
| 35 |
def parse_args() -> argparse.Namespace:
|
| 36 |
p = argparse.ArgumentParser(description="Per-task target-capability profiler")
|
| 37 |
+
p.add_argument("--target-model", default="meta-llama/Llama-3.2-3B-Instruct")
|
| 38 |
p.add_argument("--target-backend", default="hf",
|
| 39 |
help="hf | mock (mock for local dev only)")
|
| 40 |
p.add_argument("--tasks", default="all",
|
|
@@ -72,6 +72,33 @@ def build_agent_user_message(obs) -> str:
|
|
| 72 |
f"- input: {ex.get('input','')!r} expected: {ex.get('expected','')!r}"
|
| 73 |
for ex in (obs.train_examples or [])
|
| 74 |
)
|
| 75 |
return textwrap.dedent(
|
| 76 |
f"""
|
| 77 |
TASK: {obs.task_id} (category: {obs.task_category})
|
|
@@ -81,20 +108,23 @@ def build_agent_user_message(obs) -> str:
|
|
| 81 |
BASELINE (empty prompt) SCORE: {obs.baseline_zero_shot_score:.2f}
|
| 82 |
|
| 83 |
Visible train examples (do not copy verbatim):
|
| 84 |
-
{examples_block}
|
| 85 |
|
| 86 |
Write your prompt inside <prompt>...</prompt>.
|
| 87 |
"""
|
| 88 |
).strip()
|
| 89 |
|
| 90 |
|
| 91 |
-
def build_chat_prompt(tokenizer, obs) -> str:
|
| 92 |
"""Apply chat template → single string the agent's tokenizer will see.
|
| 93 |
|
| 94 |
Passes enable_thinking=False for Qwen3 models so the agent emits its
|
| 95 |
prompt directly instead of a <think>...</think> reasoning trace
|
| 96 |
-
followed by output.
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
| 98 |
"""
|
| 99 |
messages = [
|
| 100 |
{"role": "system", "content": SYSTEM_PROMPT},
|
|
@@ -105,7 +135,7 @@ def build_chat_prompt(tokenizer, obs) -> str:
|
|
| 105 |
# Qwen3 / Qwen3.5 support this kwarg; other models ignore it.
|
| 106 |
return tokenizer.apply_chat_template(
|
| 107 |
messages, tokenize=False, add_generation_prompt=True,
|
| 108 |
-
enable_thinking=
|
| 109 |
)
|
| 110 |
except TypeError:
|
| 111 |
return tokenizer.apply_chat_template(
|
|
@@ -114,17 +144,20 @@ def build_chat_prompt(tokenizer, obs) -> str:
|
|
| 114 |
return f"{SYSTEM_PROMPT}\n\n{build_agent_user_message(obs)}\n\nAssistant:"
|
| 115 |
|
| 116 |
|
| 117 |
-
def build_prompt_dataset(
|
|
|
|
|
|
|
|
|
|
| 118 |
"""Build a HF Dataset where each row is (chat-formatted prompt, task_id, seed)."""
|
| 119 |
from datasets import Dataset
|
| 120 |
|
| 121 |
rows: List[Dict] = []
|
| 122 |
for task_id in task_ids:
|
| 123 |
for seed in range(seeds_per_task):
|
| 124 |
-
obs = env.reset(task=task_id, seed=seed)
|
| 125 |
rows.append(
|
| 126 |
{
|
| 127 |
-
"prompt": build_chat_prompt(tokenizer, obs),
|
| 128 |
"task_id": task_id,
|
| 129 |
"seed": seed,
|
| 130 |
}
|
|
@@ -259,7 +292,7 @@ def make_callback(log_state: Dict, output_dir: Path):
|
|
| 259 |
def parse_args() -> argparse.Namespace:
|
| 260 |
p = argparse.ArgumentParser(description="GRPO training for Prompt Golf")
|
| 261 |
p.add_argument("--agent-model", default="Qwen/Qwen3-1.7B")
|
| 262 |
-
p.add_argument("--target-model", default="
|
| 263 |
p.add_argument("--output-dir", default="outputs/grpo")
|
| 264 |
|
| 265 |
# Task split — held out spans v1 AND v2 for honest generalization eval
|
|
@@ -282,7 +315,23 @@ def parse_args() -> argparse.Namespace:
|
|
| 282 |
p.add_argument("--gradient-accumulation-steps", type=int, default=4)
|
| 283 |
p.add_argument("--learning-rate", type=float, default=5e-6)
|
| 284 |
p.add_argument("--beta", type=float, default=0.04, help="KL penalty")
|
| 285 |
-
p.add_argument("--max-completion-length", type=int, default=
|
| 286 |
p.add_argument("--max-prompt-length", type=int, default=1024)
|
| 287 |
|
| 288 |
# Rollout sampling — explicit so we don't silently inherit Qwen3's
|
|
@@ -350,8 +399,14 @@ def main() -> None:
|
|
| 350 |
print(f"[setup] tasks total={len(all_tasks)} train={len(train_tasks)} held_out={len(held_out)}", flush=True)
|
| 351 |
|
| 352 |
# ----- dataset -----
|
| 353 |
-
train_ds = build_prompt_dataset(
|
| 354 |
-
|
| 355 |
print(f"[setup] train rows={len(train_ds)} eval rows={len(eval_ds) if eval_ds else 0}", flush=True)
|
| 356 |
|
| 357 |
# ----- reward + callback -----
|
|
|
|
| 72 |
f"- input: {ex.get('input','')!r} expected: {ex.get('expected','')!r}"
|
| 73 |
for ex in (obs.train_examples or [])
|
| 74 |
)
|
| 75 |
+
|
| 76 |
+
# When the env runs in multi-turn mode and a prior attempt has been
|
| 77 |
+
# scored, fold the per-attempt feedback into the user message so the
|
| 78 |
+
# agent can see what its earlier prompts produced and refine.
|
| 79 |
+
prior = list(getattr(obs, "prior_attempts", None) or [])
|
| 80 |
+
prior_block = ""
|
| 81 |
+
if prior:
|
| 82 |
+
chunks = []
|
| 83 |
+
for att in prior:
|
| 84 |
+
sg = att.get("sample_generations") or []
|
| 85 |
+
sg_lines = "\n".join(
|
| 86 |
+
f" input: {g.get('input','')!r} "
|
| 87 |
+
f"target_said: {g.get('target_output','')!r} "
|
| 88 |
+
f"expected: {g.get('expected','')!r}"
|
| 89 |
+
for g in sg[:2]
|
| 90 |
+
)
|
| 91 |
+
chunks.append(
|
| 92 |
+
f" Turn {att.get('turn','?')}: prompt={att.get('prompt','')!r} "
|
| 93 |
+
f"(tokens={att.get('tokens','?')}, score={att.get('feedback_score',0):.2f})"
|
| 94 |
+
+ (f"\n{sg_lines}" if sg_lines else "")
|
| 95 |
+
)
|
| 96 |
+
prior_block = (
|
| 97 |
+
"\n\nPRIOR ATTEMPTS (refine your prompt to score higher on the "
|
| 98 |
+
"scoring slice — note where the target's wording missed the "
|
| 99 |
+
"expected format):\n" + "\n".join(chunks)
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
return textwrap.dedent(
|
| 103 |
f"""
|
| 104 |
TASK: {obs.task_id} (category: {obs.task_category})
|
|
|
|
| 108 |
BASELINE (empty prompt) SCORE: {obs.baseline_zero_shot_score:.2f}
|
| 109 |
|
| 110 |
Visible train examples (do not copy verbatim):
|
| 111 |
+
{examples_block}{prior_block}
|
| 112 |
|
| 113 |
Write your prompt inside <prompt>...</prompt>.
|
| 114 |
"""
|
| 115 |
).strip()
|
| 116 |
|
| 117 |
|
| 118 |
+
def build_chat_prompt(tokenizer, obs, enable_thinking: bool = True) -> str:
|
| 119 |
"""Apply chat template → single string the agent's tokenizer will see.
|
| 120 |
|
| 121 |
With enable_thinking=False (Qwen3 only), the agent emits its
|
| 122 |
prompt directly instead of a <think>...</think> reasoning trace
|
| 123 |
+
followed by output. With thinking ON the agent gets reasoning scratch
|
| 124 |
+
space "for free" — only the final extracted prompt is counted in the
|
| 125 |
+
length-budget rubric, so think tokens don't hurt reward. The cost is
|
| 126 |
+
longer generations, addressed by raising --max-completion-length.
|
| 127 |
+
extract_prompt() already strips <think>...</think> blocks defensively.
|
| 128 |
"""
|
| 129 |
messages = [
|
| 130 |
{"role": "system", "content": SYSTEM_PROMPT},
|
|
|
|
| 135 |
# Qwen3 / Qwen3.5 support this kwarg; other models ignore it.
|
| 136 |
return tokenizer.apply_chat_template(
|
| 137 |
messages, tokenize=False, add_generation_prompt=True,
|
| 138 |
+
enable_thinking=enable_thinking,
|
| 139 |
)
|
| 140 |
except TypeError:
|
| 141 |
return tokenizer.apply_chat_template(
|
|
|
|
| 144 |
return f"{SYSTEM_PROMPT}\n\n{build_agent_user_message(obs)}\n\nAssistant:"
|
| 145 |
|
| 146 |
|
| 147 |
+
def build_prompt_dataset(
|
| 148 |
+
env, tokenizer, task_ids: List[str], seeds_per_task: int,
|
| 149 |
+
enable_thinking: bool = True,
|
| 150 |
+
):
|
| 151 |
"""Build a HF Dataset where each row is (chat-formatted prompt, task_id, seed)."""
|
| 152 |
from datasets import Dataset
|
| 153 |
|
| 154 |
rows: List[Dict] = []
|
| 155 |
for task_id in task_ids:
|
| 156 |
for seed in range(seeds_per_task):
|
| 157 |
+
obs = env.reset(task=task_id, seed=seed) # turn_limit=1 (training fixed single-turn)
|
| 158 |
rows.append(
|
| 159 |
{
|
| 160 |
+
"prompt": build_chat_prompt(tokenizer, obs, enable_thinking=enable_thinking),
|
| 161 |
"task_id": task_id,
|
| 162 |
"seed": seed,
|
| 163 |
}
|
|
|
|
| 292 |
def parse_args() -> argparse.Namespace:
|
| 293 |
p = argparse.ArgumentParser(description="GRPO training for Prompt Golf")
|
| 294 |
p.add_argument("--agent-model", default="Qwen/Qwen3-1.7B")
|
| 295 |
+
p.add_argument("--target-model", default="meta-llama/Llama-3.2-3B-Instruct")
|
| 296 |
p.add_argument("--output-dir", default="outputs/grpo")
|
| 297 |
|
| 298 |
# Task split — held out spans v1 AND v2 for honest generalization eval
|
|
|
|
| 315 |
p.add_argument("--gradient-accumulation-steps", type=int, default=4)
|
| 316 |
p.add_argument("--learning-rate", type=float, default=5e-6)
|
| 317 |
p.add_argument("--beta", type=float, default=0.04, help="KL penalty")
|
| 318 |
+
p.add_argument("--max-completion-length", type=int, default=768,
|
| 319 |
+
help="With enable_thinking=True (Qwen3), generations "
|
| 320 |
+
"include a <think>...</think> reasoning block "
|
| 321 |
+
"before the final prompt — typically 200-600 "
|
| 322 |
+
"tokens. 768 leaves room for both. Drop to 256 "
|
| 323 |
+
"if running thinking=OFF.")
|
| 324 |
+
p.add_argument("--enable-thinking", action="store_true", default=True,
|
| 325 |
+
help="Apply Qwen3 chat template with thinking ON. "
|
| 326 |
+
"Default. Use --no-enable-thinking to train a "
|
| 327 |
+
"thinking=False adapter (matches v2 behavior).")
|
| 328 |
+
p.add_argument("--no-enable-thinking", dest="enable_thinking",
|
| 329 |
+
action="store_false")
|
| 330 |
+
# NOTE: training is fixed at turn_limit=1 because GRPO is a
|
| 331 |
+
# single-decision algorithm (one prompt -> one reward). Multi-turn
|
| 332 |
+
# at training time lives in training/train_grpo_multistep.py.
|
| 333 |
+
# Multi-turn IS supported at inference / eval time (see
|
| 334 |
+
# eval_before_after.py --turn-limit).
|
| 335 |
p.add_argument("--max-prompt-length", type=int, default=1024)
|
| 336 |
|
| 337 |
# Rollout sampling — explicit so we don't silently inherit Qwen3's
|
|
|
|
| 399 |
print(f"[setup] tasks total={len(all_tasks)} train={len(train_tasks)} held_out={len(held_out)}", flush=True)
|
| 400 |
|
| 401 |
# ----- dataset -----
|
| 402 |
+
train_ds = build_prompt_dataset(
|
| 403 |
+
env, tokenizer, train_tasks, args.seeds_per_task,
|
| 404 |
+
enable_thinking=args.enable_thinking,
|
| 405 |
+
)
|
| 406 |
+
eval_ds = build_prompt_dataset(
|
| 407 |
+
env, tokenizer, sorted(held_out), seeds_per_task=2,
|
| 408 |
+
enable_thinking=args.enable_thinking,
|
| 409 |
+
) if held_out else None
|
| 410 |
print(f"[setup] train rows={len(train_ds)} eval rows={len(eval_ds) if eval_ds else 0}", flush=True)
|
| 411 |
|
| 412 |
# ----- reward + callback -----
|
|
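The `enable_thinking` handling in `build_chat_prompt` above relies on a try/except fallback. The sketch below isolates that pattern using two hypothetical stand-in tokenizer classes, so it runs without downloading any model. Whether a given real tokenizer raises `TypeError` or simply ignores the extra keyword is not established by this diff, so the stand-ins only model the two possible behaviors.

```python
from typing import Dict, List

Message = Dict[str, str]


class ThinkingTokenizer:
    """Hypothetical stand-in: template accepts enable_thinking."""
    chat_template = "stub"

    def apply_chat_template(self, messages: List[Message], tokenize: bool = False,
                            add_generation_prompt: bool = True,
                            enable_thinking: bool = True) -> str:
        body = "|".join(m["content"] for m in messages)
        return body + "|assistant:" + ("<think>" if enable_thinking else "")


class PlainTokenizer:
    """Hypothetical stand-in: passing enable_thinking raises TypeError."""
    chat_template = "stub"

    def apply_chat_template(self, messages: List[Message], tokenize: bool = False,
                            add_generation_prompt: bool = True) -> str:
        return "|".join(m["content"] for m in messages) + "|assistant:"


def render_chat(tok, messages: List[Message], enable_thinking: bool = True) -> str:
    """Try the thinking-aware call first, then fall back without the kwarg."""
    if getattr(tok, "chat_template", None):
        try:
            return tok.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True,
                enable_thinking=enable_thinking,
            )
        except TypeError:
            return tok.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True,
            )
    # No chat template at all: plain-text concatenation.
    return "\n\n".join(m["content"] for m in messages) + "\n\nAssistant:"


msgs = [{"role": "system", "content": "sys"}, {"role": "user", "content": "task"}]
with_thinking = render_chat(ThinkingTokenizer(), msgs, enable_thinking=True)
fallback = render_chat(PlainTokenizer(), msgs, enable_thinking=True)
```

One caveat of this pattern is that any `TypeError` raised inside the first call, including one unrelated to the keyword, silently triggers the fallback.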
@@ -0,0 +1,585 @@
|
|
"""
Multi-step GRPO for Prompt Golf: model in the env loop at every turn.

Adapted from spaces_pipeline_env/local_training/grpo_multistep.py (the
proven trajectory-level GRPO recipe used in the Spaces env). Differences
for Prompt Golf:

  - Action is a free-form prompt string (not a JSON action).
  - Trajectory length = `turn_limit` (typically 2 or 3).
  - Trajectory grade = final-turn reward (`obs.reward` after the step where
    `obs.done == True`). Intermediate turns are unrewarded; the agent
    only sees feedback in the next observation's `prior_attempts`.

Why this exists: TRL's GRPOTrainer treats one prompt -> one completion.
For multi-turn we need the model to generate at every env step, observe
the resulting feedback, and refine. This script runs a custom
trajectory-level GRPO loop (REINFORCE + KL vs frozen LoRA snapshot).

Memory cost: trainable LoRA + a snapshot dict of those LoRA weights as
the reference. Both fit easily on L40S (48 GB) alongside the default
Llama-3.2-3B-Instruct target + Qwen3-8B 8-bit judge.

Usage:
    python -u training/train_grpo_multistep.py \
        --max-steps 200 --num-gens 4 --batch-size 2 \
        --turn-limit 3 \
        --enable-thinking \
        --push-to-hub rishabh16196/prompt-golf-grpo-multistep
"""

from __future__ import annotations

import argparse
import json
import os
import random
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

import torch
import torch.nn.functional as F

_HERE = Path(__file__).resolve().parent
_REPO_ROOT = _HERE.parent
sys.path.insert(0, str(_REPO_ROOT))

# Reuse the prompt format + extract_prompt from the single-step trainer
# so the multi-step rollouts match the agent's training distribution
# bit-for-bit (same SYSTEM_PROMPT, same chat template, same parsing).
from training.train_grpo import (  # noqa: E402
    SYSTEM_PROMPT,
    build_agent_user_message,
    build_chat_prompt,
    extract_prompt,
)


# ---------------------------------------------------------------------------
# Trajectory containers
# ---------------------------------------------------------------------------

@dataclass
class StepRecord:
    prompt_ids: torch.Tensor   # [seq_len] - chat-templated prompt
    action_ids: torch.Tensor   # [act_len] - generated tokens
    action_text: str           # extracted prompt (post-extract_prompt)


@dataclass
class Trajectory:
    task_id: str
    seed: int
    steps: List[StepRecord]
    grade: float               # final-turn reward
    raw_task_score: float      # final-turn raw_task_score (accuracy)
    submitted_tokens: int      # final-turn prompt token count
    turns_taken: int


# ---------------------------------------------------------------------------
# Rollout: model in the loop at every env step
# ---------------------------------------------------------------------------

def rollout_episode(
    env, model, tokenizer, task_id: str, seed: int, *,
    turn_limit: int,
    max_new_tokens: int,
    temperature: float,
    enable_thinking: bool,
    device: str,
    max_prompt_tokens: int = 4096,
) -> Trajectory:
    """Run one episode. Model generates at every turn until env.done.

    Returns a Trajectory with per-turn (prompt_ids, action_ids) pairs
    used by the policy-gradient update.
    """
    from prompt_golf_env.models import GolfAction

    obs = env.reset(task=task_id, seed=seed, turn_limit=turn_limit)
    steps: List[StepRecord] = []
    grade: float = 0.0
    raw_task_score: float = 0.0
    submitted_tokens: int = 0

    model.eval()
    while not obs.done:
        # Build chat prompt - multi-turn obs carries prior_attempts which
        # build_agent_user_message folds into the user message.
        chat_str = build_chat_prompt(tokenizer, obs, enable_thinking=enable_thinking)
        prompt_ids = tokenizer(chat_str, return_tensors="pt").input_ids[0]
        if prompt_ids.shape[0] > max_prompt_tokens:
            # Left-truncate (preserve the tail with the "write your prompt" hint)
            prompt_ids = prompt_ids[-max_prompt_tokens:]
        prompt_ids = prompt_ids.to(device)

        with torch.no_grad():
            out = model.generate(
                prompt_ids.unsqueeze(0),
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=1.0,
                pad_token_id=tokenizer.pad_token_id,
            )
        gen_ids = out[0][prompt_ids.shape[0]:]
        gen_text = tokenizer.decode(gen_ids, skip_special_tokens=True)
        action_text = extract_prompt(gen_text)

        steps.append(StepRecord(
            prompt_ids=prompt_ids.detach().cpu(),
            action_ids=gen_ids.detach().cpu(),
            action_text=action_text,
        ))

        obs = env.step(GolfAction(prompt=action_text))
        if obs.done:
            grade = float(obs.reward or 0.0)
            raw_task_score = float(obs.raw_task_score or 0.0)
            submitted_tokens = int(obs.submitted_prompt_tokens or 0)

    return Trajectory(
        task_id=task_id,
        seed=seed,
        steps=steps,
        grade=grade,
        raw_task_score=raw_task_score,
        submitted_tokens=submitted_tokens,
        turns_taken=len(steps),
    )


# ---------------------------------------------------------------------------
# Log-prob computation (batched left-padding for memory efficiency)
# ---------------------------------------------------------------------------

def compute_logprobs_batched(
    model, records: List[Tuple[torch.Tensor, torch.Tensor]],
    device: str, pad_id: int,
) -> List[torch.Tensor]:
    """Per-record action-token logprobs in one batched forward pass.

    Records are a list of (prompt_ids, action_ids). We left-pad each
    [prompt_ids | action_ids] sequence to the max length, then read the
    a_len logits that predict each action token.
    """
    if not records:
        return []
    prompt_lens = [p.shape[0] for p, _ in records]
    action_lens = [a.shape[0] for _, a in records]
    seq_lens = [pl + al for pl, al in zip(prompt_lens, action_lens)]
    max_len = max(seq_lens)
    K = len(records)

    input_ids = torch.full((K, max_len), pad_id, dtype=torch.long, device=device)
    attn_mask = torch.zeros((K, max_len), dtype=torch.long, device=device)
    for i, (p, a) in enumerate(records):
        full = torch.cat([p.to(device), a.to(device)], dim=0)
        input_ids[i, max_len - full.shape[0]:] = full
        attn_mask[i, max_len - full.shape[0]:] = 1

    out = model(input_ids=input_ids, attention_mask=attn_mask)
    logits = out.logits  # [K, T, V]

    results: List[torch.Tensor] = []
    for i, (p, a) in enumerate(records):
        p_len, a_len = prompt_lens[i], action_lens[i]
        pad_prefix = max_len - (p_len + a_len)
        # Logits at position t predict the token at t + 1, hence the -1.
        start = pad_prefix + p_len - 1
        action_logits = logits[i, start : start + a_len]  # [a_len, V]
        logprobs = F.log_softmax(action_logits.float(), dim=-1)
        action_ids_dev = a.to(device)
        token_logp = logprobs.gather(1, action_ids_dev.unsqueeze(-1)).squeeze(-1)
        results.append(token_logp)
    return results


# ---------------------------------------------------------------------------
# Main training loop
# ---------------------------------------------------------------------------

def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(description="Multi-step GRPO for Prompt Golf")
    p.add_argument("--agent-model", default="Qwen/Qwen3-1.7B")
    p.add_argument("--target-model", default="meta-llama/Llama-3.2-3B-Instruct")
    p.add_argument("--judge-model", default="Qwen/Qwen3-8B")
    p.add_argument("--sft-adapter", default=None,
                   help="Optional LoRA adapter to warm-start from "
                        "(e.g. baseline single-turn adapter). Strongly "
                        "recommended: RL on a freshly initialized "
                        "policy diverges easily.")
    p.add_argument("--output-dir", default="outputs/grpo_multistep")
    p.add_argument("--push-to-hub", default=None,
                   help="HF model repo id; pushes adapter + metrics here.")

    # Trajectory shape
    p.add_argument("--turn-limit", type=int, default=3,
                   help="Turns per episode. >1 enables multi-turn.")
    p.add_argument("--enable-thinking", action="store_true", default=True)
    p.add_argument("--no-enable-thinking", dest="enable_thinking",
                   action="store_false")

    # GRPO knobs
    p.add_argument("--max-steps", type=int, default=200)
    p.add_argument("--num-gens", type=int, default=4,
                   help="Trajectories per task per GRPO step.")
    p.add_argument("--batch-size", type=int, default=2,
                   help="Tasks sampled per GRPO step.")
    p.add_argument("--lr", type=float, default=3e-6)
    p.add_argument("--beta", type=float, default=0.04,
                   help="KL penalty vs frozen LoRA snapshot.")
    p.add_argument("--temperature", type=float, default=0.9)
    p.add_argument("--max-new-tokens", type=int, default=768)
    p.add_argument("--max-prompt-tokens", type=int, default=4096)
    p.add_argument("--max-grad-norm", type=float, default=0.5)
    p.add_argument("--update-micro-batch", type=int, default=4,
                   help="Records per batched forward pass.")
    p.add_argument("--save-every", type=int, default=50)

    # LoRA (used when --sft-adapter is not given: fresh LoRA init)
    p.add_argument("--lora-r", type=int, default=16)
    p.add_argument("--lora-alpha", type=int, default=32)
    p.add_argument("--lora-dropout", type=float, default=0.05)

    # Task selection
    p.add_argument("--held-out-tasks", default="",
                   help="Comma-separated task ids to exclude from training.")

    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--dry-run", action="store_true",
                   help="Run one rollout and print, then exit.")
    return p.parse_args()


def main() -> None:
    args = parse_args()

    random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Env vars consumed by the env's lazy backends
    os.environ.setdefault("PROMPT_GOLF_TARGET_MODEL", args.target_model)
    os.environ.setdefault("PROMPT_GOLF_TARGET_BACKEND", "hf")
    os.environ.setdefault("PROMPT_GOLF_JUDGE_MODEL", args.judge_model)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    print("=== Multi-step GRPO (Prompt Golf, trajectory-level) ===", flush=True)
    print(f" device: {device}", flush=True)
    print(f" agent: {args.agent_model}", flush=True)
    print(f" target: {args.target_model}", flush=True)
    print(f" judge: {args.judge_model}", flush=True)
    print(f" warmstart: {args.sft_adapter or '(fresh LoRA init)'}", flush=True)
    print(f" turn_limit: {args.turn_limit}", flush=True)
    print(f" enable_thinking: {args.enable_thinking}", flush=True)
    print(f" max_steps: {args.max_steps}", flush=True)
    print(f" tasks/step (B): {args.batch_size}", flush=True)
    print(f" gens/task (G): {args.num_gens}", flush=True)
    print(f" trajectories/step:{args.batch_size * args.num_gens}", flush=True)
    print(f" lr / beta: {args.lr} / {args.beta}", flush=True)

    # ---- Env (lazy-loads target on first use) ----
    from prompt_golf_env.server.prompt_golf_environment import (
        PromptGolfEnvironment,
        _ALL_TASKS,
    )
    env = PromptGolfEnvironment()

    # ---- Tokenizer ----
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.agent_model)
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # ---- Base model + LoRA ----
    print("\nLoading agent base model (bf16)...", flush=True)
    t0 = time.time()
    from transformers import AutoModelForCausalLM
    base = AutoModelForCausalLM.from_pretrained(
        args.agent_model,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        device_map="auto" if torch.cuda.is_available() else None,
    )
    print(f" base loaded in {time.time()-t0:.1f}s", flush=True)

    if args.sft_adapter:
        print(f"Loading adapter from {args.sft_adapter} (trainable)...", flush=True)
        from peft import PeftModel
        model = PeftModel.from_pretrained(base, args.sft_adapter, is_trainable=True)
    else:
        print("Initializing fresh LoRA adapter (no warmstart)...", flush=True)
        from peft import LoraConfig, get_peft_model
        lora_cfg = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        )
        model = get_peft_model(base, lora_cfg)
    model = model.to(device) if not torch.cuda.is_available() else model
    n_tr = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f" trainable params: {n_tr:,}", flush=True)

    # ---- Snapshot trainable weights as the KL reference ----
    print("Snapshotting trainable weights as KL reference...", flush=True)
    ref_state: Dict[str, torch.Tensor] = {
        k: v.detach().clone()
        for k, v in model.named_parameters() if v.requires_grad
    }

    # ---- Training task pool ----
    held_out = {t.strip() for t in args.held_out_tasks.split(",") if t.strip()}
    train_task_ids = [tid for tid in _ALL_TASKS.keys() if tid not in held_out]
    print(f" task pool: {len(train_task_ids)} tasks "
          f"(held out: {len(held_out)})", flush=True)

    # ---- Optimizer ----
    optim = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=args.lr, betas=(0.9, 0.95), eps=1e-8,
    )

    if args.dry_run:
        print("\n[DRY-RUN] one rollout...", flush=True)
        task = train_task_ids[0]
        traj = rollout_episode(
            env, model, tokenizer, task_id=task, seed=args.seed,
            turn_limit=args.turn_limit,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            enable_thinking=args.enable_thinking,
            device=device,
            max_prompt_tokens=args.max_prompt_tokens,
        )
        print(f" task={traj.task_id} turns={traj.turns_taken} "
              f"grade={traj.grade:.3f} raw={traj.raw_task_score:.2f} "
              f"tokens={traj.submitted_tokens}", flush=True)
        for i, sr in enumerate(traj.steps):
            print(f" turn {i+1}: action_text='{sr.action_text[:80]}' "
                  f"({sr.action_ids.shape[0]} action tokens)", flush=True)
        print("[DRY-RUN] done - no training.", flush=True)
        return

    # ---- Training loop ----
    print("\n=== starting multi-step GRPO ===\n", flush=True)
    t_train = time.time()
    metrics: List[Dict[str, Any]] = []
    STD_FLOOR = 0.1
    ADV_CLAMP = 3.0

    def swap_weights(target_state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Copy target_state into trainable params; return prior snapshot."""
        old: Dict[str, torch.Tensor] = {}
        for k, v in model.named_parameters():
            if v.requires_grad and k in target_state:
                old[k] = v.detach().clone()
                with torch.no_grad():
                    v.copy_(target_state[k])
        return old

    for step in range(args.max_steps):
        step_t0 = time.time()
        tasks_this_step = random.sample(
            train_task_ids, min(args.batch_size, len(train_task_ids))
        )
        seed_base = args.seed + step * 1000

        # ---- Phase 1: rollouts (no grad) ----
        all_groups: List[List[Trajectory]] = []
        for ti, task in enumerate(tasks_this_step):
            group: List[Trajectory] = []
            for g in range(args.num_gens):
                traj = rollout_episode(
                    env, model, tokenizer,
                    task_id=task, seed=seed_base + ti * 100 + g,
                    turn_limit=args.turn_limit,
                    max_new_tokens=args.max_new_tokens,
                    temperature=args.temperature,
                    enable_thinking=args.enable_thinking,
                    device=device,
                    max_prompt_tokens=args.max_prompt_tokens,
                )
                group.append(traj)
            all_groups.append(group)

        # ---- Group-relative advantages with std floor + clamp ----
        flat_records: List[Tuple[StepRecord, float]] = []
        group_stats = []
        n_groups_skipped = 0
        for group in all_groups:
            rewards = torch.tensor([t.grade for t in group], dtype=torch.float32)
            mean_r = rewards.mean().item()
            raw_std = rewards.std(unbiased=False).item()
            if raw_std < 0.02:  # all trajectories scored equal -> no signal
                n_groups_skipped += 1
                group_stats.append((rewards.tolist(), mean_r, 0.0))
                continue
            std_r = max(raw_std, STD_FLOOR)
            group_stats.append((rewards.tolist(), mean_r, std_r))
            for traj in group:
                adv = (traj.grade - mean_r) / std_r
                adv = max(-ADV_CLAMP, min(ADV_CLAMP, adv))
                for sr in traj.steps:
                    flat_records.append((sr, adv))

        if not flat_records:
            print(f"step {step+1:3d}/{args.max_steps} all groups collapsed "
                  f"(equal rewards) - skipping update", flush=True)
            continue

        # ---- Phase 2: batched policy-gradient update ----
        model.train()
        optim.zero_grad()

        total_loss_val = 0.0
        total_kl_val = 0.0
        n_records = len(flat_records)
        MICRO = args.update_micro_batch

        for start in range(0, n_records, MICRO):
            batch = flat_records[start : start + MICRO]
            batch_records = [(sr.prompt_ids, sr.action_ids) for sr, _ in batch]
            batch_advs = [adv for _, adv in batch]

            # Reference logp (no grad)
            if args.beta > 0:
                saved = swap_weights(ref_state)
                with torch.no_grad():
                    ref_logps = compute_logprobs_batched(
                        model, batch_records, device, tokenizer.pad_token_id,
                    )
                swap_weights(saved)
                ref_logps = [r.detach() for r in ref_logps]
            else:
                ref_logps = [None] * len(batch)

            # New logp (with grad)
            new_logps = compute_logprobs_batched(
                model, batch_records, device, tokenizer.pad_token_id,
            )

            # REINFORCE + KL loss
            batch_loss_terms = []
            for new_lp, ref_lp, adv in zip(new_logps, ref_logps, batch_advs):
                if ref_lp is None:
                    ref_lp = new_lp.detach()
                kl_per_tok = new_lp - ref_lp
                pg_per_tok = -adv * new_lp
                loss_per_tok = pg_per_tok + args.beta * kl_per_tok
                batch_loss_terms.append(loss_per_tok.mean())
                total_kl_val += kl_per_tok.mean().item()

            micro_loss = torch.stack(batch_loss_terms).mean()
            scale = len(batch) / n_records
            (micro_loss * scale).backward()
            total_loss_val += micro_loss.item() * len(batch)

        total_loss_val = total_loss_val / max(1, n_records)

        torch.nn.utils.clip_grad_norm_(
            [p for p in model.parameters() if p.requires_grad],
            args.max_grad_norm,
        )
        optim.step()

        # ---- Log ----
        all_rewards = [r for g in group_stats for r in g[0]]
        avg_r = sum(all_rewards) / max(1, len(all_rewards))
        max_r = max(all_rewards)
        min_r = min(all_rewards)
        avg_loss = total_loss_val
        avg_kl = total_kl_val / max(1, n_records)
        n_traj = sum(len(g) for g in all_groups)
        n_steps_in_traj = sum(len(t.steps) for g in all_groups for t in g)
        avg_tokens = (
            sum(t.submitted_tokens for g in all_groups for t in g)
            / max(1, n_traj)
        )
        avg_raw = (
            sum(t.raw_task_score for g in all_groups for t in g)
            / max(1, n_traj)
        )
        elapsed = time.time() - step_t0
        print(
            f"step {step+1:3d}/{args.max_steps} "
            f"avg_r={avg_r:+.3f} [{min_r:+.2f}..{max_r:+.2f}] "
            f"raw={avg_raw:.2f} tokens={avg_tokens:.1f} "
            f"n_traj={n_traj} n_turns={n_steps_in_traj} "
            f"grp_skip={n_groups_skipped} "
            f"loss={avg_loss:+.4f} kl={avg_kl:+.4f} "
            f"{elapsed:.1f}s",
            flush=True,
        )
        metrics.append({
            "step": step + 1,
            "avg_reward": avg_r,
            "min_reward": min_r,
            "max_reward": max_r,
            "avg_raw_task_score": avg_raw,
            "avg_submitted_tokens": avg_tokens,
            "loss": avg_loss,
            "kl": avg_kl,
            "n_trajectories": n_traj,
            "n_turns_total": n_steps_in_traj,
            "n_groups_skipped": n_groups_skipped,
            "elapsed_s": elapsed,
        })

        if args.save_every > 0 and (step + 1) % args.save_every == 0 \
                and (step + 1) < args.max_steps:
            ckpt = out_dir / f"checkpoint-{step+1}"
            ckpt.mkdir(parents=True, exist_ok=True)
            model.save_pretrained(str(ckpt))
            (out_dir / "train_metrics.json").write_text(json.dumps(metrics, indent=2))
            print(f" ckpt -> {ckpt.name}", flush=True)

    train_elapsed = time.time() - t_train
    print(f"\n=== training done in {train_elapsed/60:.1f} min ===", flush=True)

    # ---- Save adapter + metrics ----
    final_dir = out_dir / "adapter_final"
    final_dir.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(str(final_dir))
    tokenizer.save_pretrained(str(final_dir))
    (out_dir / "train_metrics.json").write_text(json.dumps(metrics, indent=2))
    print(f" adapter -> {final_dir}", flush=True)
    print(f" metrics -> {out_dir / 'train_metrics.json'}", flush=True)

    # ---- Push to hub ----
    if args.push_to_hub:
        from huggingface_hub import HfApi
        api = HfApi()
        api.create_repo(args.push_to_hub, exist_ok=True, repo_type="model")
        api.upload_folder(
            folder_path=str(final_dir),
            repo_id=args.push_to_hub,
            repo_type="model",
            path_in_repo="adapter_final",
            commit_message=f"multi-step GRPO adapter ({args.max_steps} steps, "
                           f"turn_limit={args.turn_limit}, "
                           f"thinking={args.enable_thinking})",
        )
        api.upload_file(
            path_or_fileobj=str(out_dir / "train_metrics.json"),
            path_in_repo="metrics/train_metrics_multistep.json",
            repo_id=args.push_to_hub,
            repo_type="model",
            commit_message="multi-step GRPO metrics",
        )
        print(f"[push] uploaded to https://huggingface.co/{args.push_to_hub}",
              flush=True)


if __name__ == "__main__":
    main()
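
Reviewer note on the advantage step: the group-relative normalization in the training loop (skip near-constant groups, floor the std, clamp the result) can be checked in isolation without loading any model. The sketch below is illustrative only and is not part of the script. The function name `group_advantages` and the `skip_below` parameter are hypothetical; the constants copy `STD_FLOOR = 0.1`, `ADV_CLAMP = 3.0`, and the `0.02` skip threshold from the loop.

```python
import math
from typing import List, Optional

STD_FLOOR = 0.1   # copied from the training loop
ADV_CLAMP = 3.0   # copied from the training loop


def group_advantages(
    grades: List[float], skip_below: float = 0.02
) -> Optional[List[float]]:
    """Clamped group-relative advantages, or None when the group is skipped."""
    mean_r = sum(grades) / len(grades)
    # Population std, the same quantity as rewards.std(unbiased=False).
    raw_std = math.sqrt(sum((g - mean_r) ** 2 for g in grades) / len(grades))
    if raw_std < skip_below:
        return None  # near-identical rewards carry no learning signal
    std_r = max(raw_std, STD_FLOOR)
    return [max(-ADV_CLAMP, min(ADV_CLAMP, (g - mean_r) / std_r)) for g in grades]


skipped = group_advantages([0.2, 0.2, 0.2, 0.2])       # constant group
floored = group_advantages([0.0, 0.1, 0.0, 0.1])       # raw std 0.05 < floor
clamped = group_advantages([0.0] * 19 + [20.0])        # one large outlier
print(skipped, floored, max(clamped))
```

The floor matters for groups whose spread is small but above the skip threshold: without it, a raw std of 0.05 would double every advantage in the second example.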
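
Reviewer note on the log-prob slice: `compute_logprobs_batched` reads `logits[i, start : start + a_len]` with `start = pad_prefix + p_len - 1`. The `- 1` is there because a causal LM's logits at position `t` score the token at position `t + 1`. The pure-Python sketch below reproduces only that index arithmetic on a toy left-padded row; `action_logit_slice` and `left_pad` are hypothetical helper names, and no model or tensor library is involved.

```python
from typing import List, Tuple


def action_logit_slice(prompt_len: int, action_len: int, max_len: int) -> Tuple[int, int]:
    """Logit rows that predict the action tokens of one left-padded record."""
    pad_prefix = max_len - (prompt_len + action_len)
    # Logits at position t score the token at t + 1, so start one position
    # before the first action token.
    start = pad_prefix + prompt_len - 1
    return start, start + action_len


def left_pad(seq: List[int], max_len: int, pad_id: int) -> List[int]:
    """Left-pad a token list to max_len, as the batched forward pass does."""
    return [pad_id] * (max_len - len(seq)) + seq


prompt, action = [11, 12, 13], [21, 22]
row = left_pad(prompt + action, max_len=8, pad_id=0)
start, stop = action_logit_slice(len(prompt), len(action), max_len=8)
# The tokens scored by logits[start:stop] sit one position to the right.
predicted = row[start + 1 : stop + 1]
print(row, (start, stop), predicted)
```

The recovered `predicted` list equals the action tokens, which is the property the gather in `compute_logprobs_batched` relies on.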