Don Rishabh Claude Opus 4.7 (1M context) commited on
Commit
67509ac
·
1 Parent(s): 3a1b533

v3: multi-turn env, thinking tokens, cross-family Qwen->Llama, multi-step GRPO

Browse files

Multi-turn:
- GolfObservation gains turn_number / turn_limit / prior_attempts.
- env.reset(turn_limit=N) splits the 6 held-out test examples into a
2-example feedback slice (revealed between turns) and a 4-example
scoring slice (only the FINAL turn is scored against these).
- build_agent_user_message folds prior attempts (prompt + score +
sample target outputs) into the agent's user message so it can
refine across turns.

Thinking tokens (Qwen3 only):
- --enable-thinking / --no-enable-thinking CLI flag on both train
and eval. Default ON (was OFF in v2). Llama models silently fall
back via the chat-template TypeError path.
- max_completion_length default 256 -> 768, max_new_tokens (eval)
256 -> 768 to fit the <think>...</think> block plus the final
prompt.
- extract_prompt already strips <think>...</think> defensively;
works regardless of mode.

Cross-family targeting:
- Default target flipped Qwen/Qwen3-1.7B -> meta-llama/Llama-3.2-3B-Instruct
across every training/eval/profile script.
- Agent stays Qwen3-1.7B (preserves thinking).
- Judge stays Qwen3-8B 8-bit (judge identity matters less).

Multi-step GRPO trainer (training/train_grpo_multistep.py):
- Hand-rolled trajectory-level GRPO mirroring the proven recipe in
spaces_pipeline_env/local_training/grpo_multistep.py. TRL's
GRPOTrainer is single-step; multi-turn needs custom rollouts.
- Rollout: model in the env loop at every turn, collecting per-turn
(prompt_ids, action_ids).
- REINFORCE + KL vs LoRA snapshot, group-relative advantages with
STD_FLOOR=0.1 / ADV_CLAMP=3.0.
- --sft-adapter warmstart support recommended (start from the
baseline single-step adapter).

Eval default seeds-per-task dropped from 3 to 1 — at temperature=0.0
the agent is deterministic so seeds>1 was producing bit-identical
duplicate rows.

README updated to document all of the above.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

README.md CHANGED
@@ -25,14 +25,18 @@ Prompt Golf is the missing environment.
25
 
26
  ## How It Works
27
 
28
- Each episode is one task and one step:
29
 
30
- 1. `reset(task="sentiment_basic")` → the env returns a task description, 3 visible train examples, a token budget, and the target model's zero-shot score on the held-out set.
31
- 2. The agent outputs a **prompt string** as its action.
32
- 3. The env prepends that prompt to each of ~6 held-out test inputs, runs the **frozen target LLM** (greedy decoding), and scores each output with a task-specific scorer.
33
- 4. `reward = raw_task_score × length_factor × leakage_penalty + 0.3 × max(0, gain_over_baseline) × length_factor`, clipped to [0, 1.3].
34
 
35
- The test inputs are **never shown to the agent**. An n-gram leakage detector scales the reward toward zero if the agent tries to paste answers into its prompt.
 
 
 
 
 
 
 
36
 
37
  ## Quick Start
38
 
@@ -70,80 +74,152 @@ uvicorn server.app:app --port 8000
70
 
71
  ## Task Bank
72
 
73
- Ships with **19 tasks across 8 categories**:
74
 
75
- | Category | Tasks | Scorer |
76
- |---|---|---|
77
- | classification | sentiment, sentiment_nuanced, topic_news, toxicity_detect, intent_support | exact_label |
78
- | extraction | ner_people, json_contact, number_extract | contains / json / numeric |
79
- | format | three_bullets, uppercase, json_object | structural |
80
- | arithmetic | word problems, percent change | numeric |
81
- | translation | greetings (EN→FR), numbers (EN→ES) | token-F1 |
82
- | style | formal rewrite, concise rewrite | keyword coverage |
83
- | reasoning | quantity comparison, event ordering | exact_label |
84
- | refusal | make target decline unsafe requests | refusal detector |
85
 
86
  Each task has:
87
- - 2–3 visible train examples in the observation
88
- - 6 hidden test examples used for scoring
89
- - A per-task token budget (30100 tokens)
90
 
91
- New tasks drop into `server/tasks.py`.
 
 
92
 
93
  ## Reward Components
94
 
95
- | Component | Range | Purpose |
96
- |---|---|---|
97
- | `raw_task_score` | [0, 1] | Mean scorer output on held-out test set |
98
- | `length_factor` | (0, 1] | 1.0 within budget, decays exponentially past it |
99
- | `leakage_penalty` | [0, 1] | Scales toward 0 when prompt leaks held-out n-grams |
100
- | `gain_over_baseline` | [-baseline, 1-baseline] | Delta vs. target's zero-shot score |
101
 
102
- Final reward:
103
  ```
104
- base = raw_task_score × length_factor × leakage_penalty
105
- bonus = max(0, gain_over_baseline) × length_factor × 0.3
106
- reward = clip(base + bonus, 0.0, 1.3)
 
 
 
 
107
  ```
108
 
109
- ## Target Model
 
 
 
 
 
 
110
 
111
- Frozen for the lifetime of the process. Defaults to `Qwen/Qwen2.5-0.5B-Instruct` — small enough to run on a T4, strong enough to reward good prompting. Override with `PROMPT_GOLF_TARGET_MODEL`.
112
 
113
- For CPU / CI, set `PROMPT_GOLF_TARGET_BACKEND=mock` to use a deterministic pattern-based fake target that lets the env boot without loading a model.
 
 
 
 
 
 
 
 
 
 
114
 
115
  ## Training
116
 
117
- Designed for **TRL GRPO** out of the box. The agent's action is a free-form string, matching GRPO's typical setup. Recommended starting config:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
- - Agent: Qwen/Qwen2.5-1.5B-Instruct (trainable)
120
- - Target: Qwen/Qwen2.5-0.5B-Instruct (frozen, smaller than agent so reward signal is informative)
121
- - `num_generations=8`, `learning_rate=5e-6`, `beta=0.04`
122
- - 500–1000 steps with a budget curriculum (start loose, tighten over training)
123
 
124
- Plots to watch:
125
- - **Mean reward per step** — should climb from ~baseline toward ~1.0
126
- - **Mean prompt tokens** — the "compression" story; should drop from hundreds to tens
127
- - **Per-category accuracy** — generalization across task types
128
- - **Baseline-normalized gain** — how much the agent's prompt beats zero-shot
129
 
130
  ## Files
131
 
132
  ```
133
  prompt_golf_env/
134
- openenv.yaml # spec manifest
135
- models.py # GolfAction, GolfObservation, constants
136
- client.py # PromptGolfEnv (EnvClient subclass)
 
137
  pyproject.toml
138
  server/
139
- app.py # FastAPI app
140
- prompt_golf_environment.py # core Env: reset/step
141
- target_model.py # frozen-target wrapper (HF + mock backends)
142
- scorer.py # per-scorer implementations
143
- tasks.py # 19-task bank
144
- rubrics.py # reward composition
 
 
 
145
  Dockerfile
146
  requirements.txt
 
 
 
 
 
 
 
 
 
147
  ```
148
 
149
  ## Why This Environment
 
25
 
26
  ## How It Works
27
 
28
+ Each episode is one task. By default it's one step (single-turn). With `turn_limit > 1` it becomes multi-turn — the agent submits a prompt, sees how it performed on a feedback slice, and refines.
29
 
30
+ **Single-turn (default):**
 
 
 
31
 
32
+ 1. `reset(task="sentiment_basic")` env returns task description, 3 visible train examples, token budget, and target's empty-prompt baseline.
33
+ 2. Agent outputs a **prompt string** as its action.
34
+ 3. Env prepends the prompt to each of 6 held-out test inputs, runs the **frozen target LLM**, scores each output with the task scorer.
35
+ 4. `reward = raw_task_score − 0.5·baseline_zero_shot − 0.002·tokens − leakage_overlap²`, clipped to `[−0.5, 1.3]`.
36
+
37
+ **Multi-turn (`turn_limit > 1`):** the 6 held-out examples are split into `feedback_ex` (2 examples revealed to the agent between turns with the target's actual output) and `scoring_ex` (4 examples that only the **final-turn** prompt is scored against). This lets the agent debug its own prompt across turns without leaking the inputs that ultimately judge it.
38
+
39
+ The test inputs are **never shown to the agent** in single-turn mode; in multi-turn the agent sees only the feedback slice's inputs/outputs. An n-gram leakage detector scales the reward toward zero if the agent tries to paste held-out inputs into its prompt.
40
 
41
  ## Quick Start
42
 
 
74
 
75
  ## Task Bank
76
 
77
+ Ships with **87 tasks** across three banks:
78
 
79
+ | Bank | Count | Where | Difficulty |
80
+ |---|---|---|---|
81
+ | v1 (`tasks.py`) | 20 | classification, extraction, format, arithmetic, translation, style, reasoning, refusal | easy / medium |
82
+ | v2 (`tasks_v2.py`) | 15 | acrostic, no-letter-e, yaml depth, json key order, pirate persona, Shakespearean, terminal output, etc. | hard |
83
+ | tough (`tasks_tough.py`) | 52 | classification_tough (10), extraction_tough (10), format_tough (8), persona_tough (8), reasoning_tough (10), adversarial_tough (6) | hard |
84
+
85
+ The "tough" bank was hand-crafted so the **minimum effective prompt is non-obvious**: the verbose hand-written prompt for each tough task is 200-300 tokens, but the target can be steered into the right format with a much shorter compressed prompt — that gap is what training is supposed to close.
 
 
 
86
 
87
  Each task has:
88
+ - 2–3 visible train examples shown to the agent in the observation
89
+ - 6 hidden test examples used for scoring (split into 2 feedback + 4 scoring in multi-turn mode)
90
+ - A per-task token budget (60250 tokens depending on difficulty)
91
 
92
+ Scorers: `exact_label`, `contains_all_substrings`, `numeric_match`, `json_contains_fields`, `valid_json_object`, `valid_yaml_depth`, `acrostic_match`, `avoid_letter`, `three_bullets`, `word_count_exact`, `stepwise_math`, `terminal_output_pattern`, `judge_criteria` (Qwen3-8B 8-bit judge), `judge_vs_expected`, `refusal_score`, etc. — all in `server/scorer.py`.
93
+
94
+ New tasks drop into the appropriate bank file.
95
 
96
  ## Reward Components
97
 
98
+ The rubric is **additive** (v3) for smoother gradients than the original multiplicative form:
 
 
 
 
 
99
 
 
100
  ```
101
+ reward = raw_task_score
102
+ BASELINE_SUBTRACT · baseline_zero_shot_score
103
+ LAMBDA_LEN · submitted_tokens
104
+ − LAMBDA_LEAK · leakage_overlap²
105
+ − short_penalty (if tokens < MIN_TOKENS_FLOOR)
106
+
107
+ clipped to [REWARD_CLIP_LOW, REWARD_CLIP_HIGH] = [-0.5, 1.3]
108
  ```
109
 
110
+ Defaults (`server/rubrics.py`):
111
+ - `LAMBDA_LEN = 0.002` — soft length penalty; ~0.1 cost on a 50-token prompt
112
+ - `LAMBDA_LEAK = 1.0` — full reward wiped at saturation overlap
113
+ - `BASELINE_SUBTRACT = 0.5` — partially normalize against the target's natural ability
114
+ - `MIN_TOKENS_FLOOR = 5`, `MIN_TOKENS_PENALTY = 0.25` — anti-collapse guard against degenerate 1-token prompts
115
+
116
+ Legacy `length_factor` and `leakage_penalty` fields are still emitted on the observation for plot continuity but are no longer multiplicatively composed.
117
 
118
+ ## Models (Cross-Family Setup)
119
 
120
+ We deliberately pair a **Qwen agent** with a **Llama target** testing whether prompt golf transfers across model families:
121
+
122
+ | Role | Default | Why |
123
+ |---|---|---|
124
+ | Agent (trainable) | `Qwen/Qwen3-1.7B` | Preserves Qwen3's `<think>...</think>` reasoning mode — the agent gets free reasoning scratch space (only the extracted final prompt counts toward the length-budget rubric). |
125
+ | Target (frozen) | `meta-llama/Llama-3.2-3B-Instruct` | The model the agent's prompts must steer. Different family = the agent has to learn Llama's idiosyncrasies (chat-template quirks, format preferences, refusal patterns) rather than its own. |
126
+ | Judge | `Qwen/Qwen3-8B` (8-bit via bitsandbytes, ~8 GB VRAM) | Used by `judge_criteria` / `judge_vs_expected` scorers. Identity matters less; kept on Qwen to avoid re-tuning the judge prompt. |
127
+
128
+ Override with `PROMPT_GOLF_TARGET_MODEL`, `PROMPT_GOLF_JUDGE_MODEL`. Disable judge quantization with `PROMPT_GOLF_JUDGE_NO_QUANT=1`. CPU/CI: `PROMPT_GOLF_TARGET_BACKEND=mock` and `PROMPT_GOLF_JUDGE_BACKEND=mock`.
129
+
130
+ > **Note:** Llama-3.2 requires accepting the license on HuggingFace. Make sure your `HF_TOKEN` has access before launching.
131
 
132
  ## Training
133
 
134
+ Two trainers ship in `training/`:
135
+
136
+ ### Single-step GRPO (`train_grpo.py`)
137
+
138
+ Standard TRL GRPOTrainer. Treats each task as a single decision (one prompt → one reward). Recommended starting config:
139
+
140
+ - Agent: `Qwen/Qwen3-1.7B` (trainable, LoRA)
141
+ - Target: `meta-llama/Llama-3.2-3B-Instruct` (frozen)
142
+ - `num_generations=8`, `learning_rate=5e-6`, `beta=0.04`, `temperature=0.9`
143
+ - `max_completion_length=768` (Qwen3 thinking ON by default; pass `--no-enable-thinking` to drop back to 256)
144
+ - 500 steps × 87 tasks × 4 seeds = ~140-200 min on L40S with judge co-resident
145
+
146
+ Launch via `training/hf_job_train.sh` for HuggingFace Jobs.
147
+
148
+ ### Multi-step GRPO (`train_grpo_multistep.py`)
149
+
150
+ Hand-rolled trajectory-level GRPO (mirrors the proven recipe from `spaces_pipeline_env/local_training/grpo_multistep.py`). Required when `turn_limit > 1` because TRL's GRPOTrainer doesn't natively support multi-step rollouts.
151
+
152
+ - Custom rollout: model generates at every env turn, collecting `(prompt_ids, action_ids)` per step
153
+ - Group-relative advantages with `STD_FLOOR=0.1`, `ADV_CLAMP=3.0`
154
+ - REINFORCE + KL vs frozen LoRA snapshot (snapshotted at start, swapped in for ref logp computation)
155
+ - Recommended: `--sft-adapter` warmstart from the single-step adapter — RL on a fresh policy diverges easily
156
+
157
+ Launch via `training/hf_job_train_multistep.sh`.
158
+
159
+ ### Pre-flight: capability profiling
160
+
161
+ Before committing GPU hours to a 500-step run, verify the target is capable on each task:
162
+
163
+ ```bash
164
+ TARGET_MODEL=Qwen/Qwen3-1.7B bash training/hf_job_profile.sh
165
+ ```
166
+
167
+ This runs the target with each task's verbose hand-written description and dumps `description_baseline` per task. Use the output to decide whether to keep the target, bump to a larger one, or filter dead-baseline tasks.
168
+
169
+ ### Eval + demo CSV
170
+
171
+ After training, generate the side-by-side demo CSV with `verbose_prompt`, `base_prompt` (untrained), `trained_prompt` columns plus per-row accuracy/reward:
172
+
173
+ ```bash
174
+ python training/eval_before_after.py --label base --output-json outputs/eval_base.jsonl
175
+ python training/eval_before_after.py --label trained --adapter <repo>/adapter_final \
176
+ --output-json outputs/eval_trained.jsonl
177
+
178
+ python training/build_before_after_csv.py \
179
+ --base-jsonl outputs/eval_base.jsonl \
180
+ --trained-jsonl outputs/eval_trained.jsonl \
181
+ --verbose-profile-csv outputs/baseline_profile.csv \
182
+ --output-csv outputs/before_after_prompts.csv
183
+ ```
184
 
185
+ ### Plots to watch
 
 
 
186
 
187
+ - **Mean reward per step** — should drift up; typical 500-step run reaches +0.3–0.5
188
+ - **Mean prompt tokens** — the compression story; drops from hundreds to tens
189
+ - **Per-category accuracy** — generalization across task families
190
+ - **Length factor / leakage penalty** — diagnostic signals (legacy multiplicative form)
191
+ - **`frac_reward_zero_std`** — fraction of GRPO groups with no intra-group variance; high means many tasks have flat baselines and contribute no gradient
192
 
193
  ## Files
194
 
195
  ```
196
  prompt_golf_env/
197
+ openenv.yaml # spec manifest
198
+ models.py # GolfAction, GolfObservation, constants
199
+ # (turn_limit, prior_attempts, multi-turn split sizes)
200
+ client.py # PromptGolfEnv (EnvClient subclass)
201
  pyproject.toml
202
  server/
203
+ app.py # FastAPI app
204
+ prompt_golf_environment.py # core Env: reset/step (single + multi-turn)
205
+ target_model.py # frozen-target wrapper (HF + mock backends)
206
+ scorer.py # 21+ scorers (structural + LLM judge)
207
+ judge.py # Qwen3-8B 8-bit judge backend
208
+ tasks.py # 20-task v1 bank
209
+ tasks_v2.py # 15-task v2 hard bank
210
+ tasks_tough.py # 52-task tough bank (6 categories)
211
+ rubrics.py # additive reward composition
212
  Dockerfile
213
  requirements.txt
214
+ training/
215
+ train_grpo.py # single-step TRL GRPO
216
+ train_grpo_multistep.py # trajectory-level GRPO (multi-turn)
217
+ eval_before_after.py # base + trained eval JSONL writer
218
+ profile_baseline.py # per-task target capability profiler
219
+ build_before_after_csv.py # demo CSV merger (verbose / base / trained)
220
+ hf_job_train.sh # single-step trainer launcher
221
+ hf_job_train_multistep.sh # multi-step trainer launcher
222
+ hf_job_profile.sh # profile launcher
223
  ```
224
 
225
  ## Why This Environment
models.py CHANGED
@@ -85,6 +85,18 @@ TEST_EXAMPLES_PER_EPISODE: int = 6
85
  # Number of visible train examples shown to the agent in the observation.
86
  TRAIN_EXAMPLES_VISIBLE: int = 3
87
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
  # ---------------------------------------------------------------------------
90
  # Action
@@ -212,3 +224,30 @@ class GolfObservation(Observation):
212
  "held-out set, for debugging / demo. Only populated at step."
213
  ),
214
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  # Number of visible train examples shown to the agent in the observation.
86
  TRAIN_EXAMPLES_VISIBLE: int = 3
87
 
88
+ # --- Multi-turn ---
89
+ # When turn_limit > 1, the test pool is split:
90
+ # - first MULTITURN_FEEDBACK_EXAMPLES are shown to the agent between
91
+ # turns (target outputs revealed) so it can refine its prompt
92
+ # - the remaining MULTITURN_SCORING_EXAMPLES score ONLY the final turn
93
+ # This prevents the agent from overfitting its prompt to outputs it will
94
+ # also be scored on. Single-turn (default) skips the split and scores on
95
+ # the full TEST_EXAMPLES_PER_EPISODE slice, preserving v2 behavior.
96
+ MULTITURN_FEEDBACK_EXAMPLES: int = 2
97
+ MULTITURN_SCORING_EXAMPLES: int = 4
98
+ DEFAULT_TURN_LIMIT: int = 1
99
+
100
 
101
  # ---------------------------------------------------------------------------
102
  # Action
 
224
  "held-out set, for debugging / demo. Only populated at step."
225
  ),
226
  )
227
+
228
+ # --- Multi-turn fields (single-turn episodes leave these at defaults) ---
229
+ turn_number: int = Field(
230
+ default=1,
231
+ description=(
232
+ "1-indexed current turn within the episode. Always 1 for "
233
+ "single-turn (turn_limit=1) episodes."
234
+ ),
235
+ )
236
+ turn_limit: int = Field(
237
+ default=DEFAULT_TURN_LIMIT,
238
+ description=(
239
+ "Total turns the agent has in this episode. Set via "
240
+ "reset(turn_limit=N). When turn_number==turn_limit, the "
241
+ "next step() will be terminal and scored on the held-out "
242
+ "scoring slice."
243
+ ),
244
+ )
245
+ prior_attempts: List[Dict[str, Any]] = Field(
246
+ default_factory=list,
247
+ description=(
248
+ "History of attempts in this episode (only populated on "
249
+ "non-terminal observations during multi-turn). Each entry: "
250
+ "{prompt, tokens, feedback_score, sample_generations}. The "
251
+ "agent uses these to refine its prompt for the next turn."
252
+ ),
253
+ )
server/prompt_golf_environment.py CHANGED
@@ -37,7 +37,10 @@ from openenv.core.env_server.types import State
37
  try:
38
  from ..models import (
39
  DEFAULT_PROMPT_BUDGET,
 
40
  MAX_TARGET_OUTPUT_TOKENS,
 
 
41
  TEST_EXAMPLES_PER_EPISODE,
42
  TRAIN_EXAMPLES_VISIBLE,
43
  GolfAction,
@@ -52,7 +55,10 @@ try:
52
  except ImportError:
53
  from models import (
54
  DEFAULT_PROMPT_BUDGET,
 
55
  MAX_TARGET_OUTPUT_TOKENS,
 
 
56
  TEST_EXAMPLES_PER_EPISODE,
57
  TRAIN_EXAMPLES_VISIBLE,
58
  GolfAction,
@@ -97,9 +103,17 @@ class PromptGolfEnvironment(Environment):
97
  # Resampled every reset
98
  self._train_ex: List[tuple[str, str]] = []
99
  self._test_ex: List[tuple[str, str]] = []
 
 
 
100
  # Cached per-episode baseline (target with empty prompt)
101
  self._baseline_zero_shot: float = 0.0
102
 
 
 
 
 
 
103
  # Reward rubric (stateless per episode)
104
  self._rubric = PromptGolfRubric()
105
 
@@ -115,12 +129,18 @@ class PromptGolfEnvironment(Environment):
115
  seed: Optional[int] = None,
116
  episode_id: Optional[str] = None,
117
  task: Optional[str] = None,
 
118
  ) -> GolfObservation:
119
  self._state = State(episode_id=episode_id or str(uuid4()), step_count=0)
120
  self._rng = random.Random(seed) if seed is not None else random.Random()
121
 
122
  self._task = self._choose_task(task)
123
 
 
 
 
 
 
124
  # Sample visible train examples (stable for this episode)
125
  train_pool = list(self._task.train_examples)
126
  self._rng.shuffle(train_pool)
@@ -133,6 +153,23 @@ class PromptGolfEnvironment(Environment):
133
  self._rng.shuffle(test_pool)
134
  self._test_ex = test_pool[:TEST_EXAMPLES_PER_EPISODE]
135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  # Compute (or reuse) baseline for this task with empty prompt
137
  cache_key = (self._target.model_id, self._task.task_id)
138
  if cache_key not in _BASELINE_CACHE:
@@ -146,7 +183,10 @@ class PromptGolfEnvironment(Environment):
146
  target_model_id=self._target.model_id,
147
  prompt_budget_tokens=self._task.budget_tokens or DEFAULT_PROMPT_BUDGET,
148
  max_target_output_tokens=MAX_TARGET_OUTPUT_TOKENS,
149
- num_test_examples=len(self._test_ex),
 
 
 
150
  train_examples=[
151
  {"input": x, "expected": y} for (x, y) in self._train_ex
152
  ],
@@ -154,6 +194,9 @@ class PromptGolfEnvironment(Environment):
154
  baseline_zero_shot_score=round(self._baseline_zero_shot, 4),
155
  done=False,
156
  reward=0.0,
 
 
 
157
  metadata={
158
  "task_difficulty": self._task.difficulty,
159
  "task_tags": list(self._task.tags),
@@ -170,18 +213,67 @@ class PromptGolfEnvironment(Environment):
170
  if self._task is None:
171
  raise RuntimeError("step() called before reset()")
172
 
 
 
 
 
173
  # Truncate prompt to the task's budget (in target tokens).
174
  budget = self._task.budget_tokens or DEFAULT_PROMPT_BUDGET
175
  truncated_prompt = self._target.truncate_to_tokens(action.prompt, budget)
176
  submitted_tokens = self._target.count_prompt_tokens(truncated_prompt)
177
 
178
- # Score the prompt.
 
 
 
 
 
 
 
 
179
  raw_task_score, sample_gens = self._score_prompt(
180
- truncated_prompt, return_samples=True
181
  )
182
 
183
- # Apply rubric.
184
- held_out_inputs = [x for x, _ in self._test_ex]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  result = self._rubric.grade(
186
  raw_task_score=raw_task_score,
187
  baseline_zero_shot_score=self._baseline_zero_shot,
@@ -192,8 +284,6 @@ class PromptGolfEnvironment(Environment):
192
  )
193
  details = grade_details_dict(result, task_id=self._task.task_id)
194
 
195
- # Build terminal observation. We re-emit the task framing so the
196
- # agent/trainer has a self-contained record of the episode.
197
  return GolfObservation(
198
  task_id=self._task.task_id,
199
  task_category=self._task.category,
@@ -201,7 +291,7 @@ class PromptGolfEnvironment(Environment):
201
  target_model_id=self._target.model_id,
202
  prompt_budget_tokens=budget,
203
  max_target_output_tokens=MAX_TARGET_OUTPUT_TOKENS,
204
- num_test_examples=len(self._test_ex),
205
  train_examples=[
206
  {"input": x, "expected": y} for (x, y) in self._train_ex
207
  ],
@@ -216,6 +306,9 @@ class PromptGolfEnvironment(Environment):
216
  sample_generations=sample_gens,
217
  done=True,
218
  reward=round(result.reward, 4),
 
 
 
219
  metadata={
220
  "task_difficulty": self._task.difficulty,
221
  "task_tags": list(self._task.tags),
@@ -244,15 +337,23 @@ class PromptGolfEnvironment(Environment):
244
  return _ALL_TASKS[task_id]
245
 
246
  def _score_prompt(
247
- self, prompt: str, return_samples: bool = False
 
 
 
248
  ) -> float | tuple[float, list]:
249
  """Run target on test inputs with `prompt`, score each output,
250
  return mean score. Optionally also return up to 2 sample triples
251
  for debugging.
 
 
 
 
252
  """
253
  assert self._task is not None
254
- test_inputs = [x for x, _ in self._test_ex]
255
- test_expected = [y for _, y in self._test_ex]
 
256
 
257
  generations: List[TargetGeneration] = self._target.generate_batch(
258
  prompt=prompt,
 
37
  try:
38
  from ..models import (
39
  DEFAULT_PROMPT_BUDGET,
40
+ DEFAULT_TURN_LIMIT,
41
  MAX_TARGET_OUTPUT_TOKENS,
42
+ MULTITURN_FEEDBACK_EXAMPLES,
43
+ MULTITURN_SCORING_EXAMPLES,
44
  TEST_EXAMPLES_PER_EPISODE,
45
  TRAIN_EXAMPLES_VISIBLE,
46
  GolfAction,
 
55
  except ImportError:
56
  from models import (
57
  DEFAULT_PROMPT_BUDGET,
58
+ DEFAULT_TURN_LIMIT,
59
  MAX_TARGET_OUTPUT_TOKENS,
60
+ MULTITURN_FEEDBACK_EXAMPLES,
61
+ MULTITURN_SCORING_EXAMPLES,
62
  TEST_EXAMPLES_PER_EPISODE,
63
  TRAIN_EXAMPLES_VISIBLE,
64
  GolfAction,
 
103
  # Resampled every reset
104
  self._train_ex: List[tuple[str, str]] = []
105
  self._test_ex: List[tuple[str, str]] = []
106
+ # Multi-turn slices (only populated when turn_limit > 1)
107
+ self._feedback_ex: List[tuple[str, str]] = []
108
+ self._scoring_ex: List[tuple[str, str]] = []
109
  # Cached per-episode baseline (target with empty prompt)
110
  self._baseline_zero_shot: float = 0.0
111
 
112
+ # Multi-turn state (single-turn defaults preserve v2 behavior)
113
+ self._turn_count: int = 0
114
+ self._turn_limit: int = DEFAULT_TURN_LIMIT
115
+ self._prior_attempts: List[dict] = []
116
+
117
  # Reward rubric (stateless per episode)
118
  self._rubric = PromptGolfRubric()
119
 
 
129
  seed: Optional[int] = None,
130
  episode_id: Optional[str] = None,
131
  task: Optional[str] = None,
132
+ turn_limit: int = DEFAULT_TURN_LIMIT,
133
  ) -> GolfObservation:
134
  self._state = State(episode_id=episode_id or str(uuid4()), step_count=0)
135
  self._rng = random.Random(seed) if seed is not None else random.Random()
136
 
137
  self._task = self._choose_task(task)
138
 
139
+ # Reset multi-turn state
140
+ self._turn_count = 0
141
+ self._turn_limit = max(1, int(turn_limit))
142
+ self._prior_attempts = []
143
+
144
  # Sample visible train examples (stable for this episode)
145
  train_pool = list(self._task.train_examples)
146
  self._rng.shuffle(train_pool)
 
153
  self._rng.shuffle(test_pool)
154
  self._test_ex = test_pool[:TEST_EXAMPLES_PER_EPISODE]
155
 
156
+ # Multi-turn split: feedback slice (revealed between turns) vs
157
+ # scoring slice (only ever scored on the FINAL turn). Single-turn
158
+ # episodes leave both empty and use _test_ex as before.
159
+ if self._turn_limit > 1:
160
+ self._feedback_ex = self._test_ex[:MULTITURN_FEEDBACK_EXAMPLES]
161
+ self._scoring_ex = self._test_ex[
162
+ MULTITURN_FEEDBACK_EXAMPLES:
163
+ MULTITURN_FEEDBACK_EXAMPLES + MULTITURN_SCORING_EXAMPLES
164
+ ]
165
+ # Guarantee a non-empty scoring slice even on tasks with few
166
+ # test examples — fall back to the full slice.
167
+ if not self._scoring_ex:
168
+ self._scoring_ex = list(self._test_ex)
169
+ else:
170
+ self._feedback_ex = []
171
+ self._scoring_ex = []
172
+
173
  # Compute (or reuse) baseline for this task with empty prompt
174
  cache_key = (self._target.model_id, self._task.task_id)
175
  if cache_key not in _BASELINE_CACHE:
 
183
  target_model_id=self._target.model_id,
184
  prompt_budget_tokens=self._task.budget_tokens or DEFAULT_PROMPT_BUDGET,
185
  max_target_output_tokens=MAX_TARGET_OUTPUT_TOKENS,
186
+ num_test_examples=(
187
+ len(self._scoring_ex) if self._turn_limit > 1
188
+ else len(self._test_ex)
189
+ ),
190
  train_examples=[
191
  {"input": x, "expected": y} for (x, y) in self._train_ex
192
  ],
 
194
  baseline_zero_shot_score=round(self._baseline_zero_shot, 4),
195
  done=False,
196
  reward=0.0,
197
+ turn_number=1,
198
+ turn_limit=self._turn_limit,
199
+ prior_attempts=[],
200
  metadata={
201
  "task_difficulty": self._task.difficulty,
202
  "task_tags": list(self._task.tags),
 
213
  if self._task is None:
214
  raise RuntimeError("step() called before reset()")
215
 
216
+ # Bump turn counter; `is_final_turn` decides scoring slice + done-flag.
217
+ self._turn_count += 1
218
+ is_final_turn = self._turn_count >= self._turn_limit
219
+
220
  # Truncate prompt to the task's budget (in target tokens).
221
  budget = self._task.budget_tokens or DEFAULT_PROMPT_BUDGET
222
  truncated_prompt = self._target.truncate_to_tokens(action.prompt, budget)
223
  submitted_tokens = self._target.count_prompt_tokens(truncated_prompt)
224
 
225
+ # Pick the scoring slice for THIS turn:
226
+ # - single-turn (turn_limit=1): score on the full _test_ex (v2 behavior)
227
+ # - multi-turn non-final: score on _feedback_ex (cheap, revealed to agent)
228
+ # - multi-turn final: score on _scoring_ex (held-out, drives reward)
229
+ if self._turn_limit > 1:
230
+ scoring_slice = self._scoring_ex if is_final_turn else self._feedback_ex
231
+ else:
232
+ scoring_slice = self._test_ex
233
+
234
  raw_task_score, sample_gens = self._score_prompt(
235
+ truncated_prompt, return_samples=True, examples=scoring_slice,
236
  )
237
 
238
+ # ----- Non-final turn in multi-turn: return feedback obs (done=False) -----
239
+ if not is_final_turn:
240
+ self._prior_attempts.append({
241
+ "turn": self._turn_count,
242
+ "prompt": truncated_prompt,
243
+ "tokens": submitted_tokens,
244
+ "feedback_score": round(raw_task_score, 4),
245
+ "sample_generations": sample_gens,
246
+ })
247
+ return GolfObservation(
248
+ task_id=self._task.task_id,
249
+ task_category=self._task.category,
250
+ task_description=self._task.description,
251
+ target_model_id=self._target.model_id,
252
+ prompt_budget_tokens=budget,
253
+ max_target_output_tokens=MAX_TARGET_OUTPUT_TOKENS,
254
+ num_test_examples=len(self._scoring_ex),
255
+ train_examples=[
256
+ {"input": x, "expected": y} for (x, y) in self._train_ex
257
+ ],
258
+ scorer_name=self._task.scorer,
259
+ baseline_zero_shot_score=round(self._baseline_zero_shot, 4),
260
+ submitted_prompt_tokens=submitted_tokens,
261
+ raw_task_score=round(raw_task_score, 4), # on feedback slice
262
+ sample_generations=sample_gens,
263
+ done=False,
264
+ reward=0.0, # no reward until terminal
265
+ turn_number=self._turn_count + 1, # next turn
266
+ turn_limit=self._turn_limit,
267
+ prior_attempts=list(self._prior_attempts),
268
+ metadata={
269
+ "task_difficulty": self._task.difficulty,
270
+ "task_tags": list(self._task.tags),
271
+ "is_intermediate_feedback": True,
272
+ },
273
+ )
274
+
275
+ # ----- Final (or single-turn): apply rubric, return terminal obs -----
276
+ held_out_inputs = [x for x, _ in scoring_slice]
277
  result = self._rubric.grade(
278
  raw_task_score=raw_task_score,
279
  baseline_zero_shot_score=self._baseline_zero_shot,
 
284
  )
285
  details = grade_details_dict(result, task_id=self._task.task_id)
286
 
 
 
287
  return GolfObservation(
288
  task_id=self._task.task_id,
289
  task_category=self._task.category,
 
291
  target_model_id=self._target.model_id,
292
  prompt_budget_tokens=budget,
293
  max_target_output_tokens=MAX_TARGET_OUTPUT_TOKENS,
294
+ num_test_examples=len(scoring_slice),
295
  train_examples=[
296
  {"input": x, "expected": y} for (x, y) in self._train_ex
297
  ],
 
306
  sample_generations=sample_gens,
307
  done=True,
308
  reward=round(result.reward, 4),
309
+ turn_number=self._turn_count,
310
+ turn_limit=self._turn_limit,
311
+ prior_attempts=list(self._prior_attempts),
312
  metadata={
313
  "task_difficulty": self._task.difficulty,
314
  "task_tags": list(self._task.tags),
 
337
  return _ALL_TASKS[task_id]
338
 
339
  def _score_prompt(
340
+ self,
341
+ prompt: str,
342
+ return_samples: bool = False,
343
+ examples: Optional[List[tuple[str, str]]] = None,
344
  ) -> float | tuple[float, list]:
345
  """Run target on test inputs with `prompt`, score each output,
346
  return mean score. Optionally also return up to 2 sample triples
347
  for debugging.
348
+
349
+ `examples` overrides the default `self._test_ex` slice — used by
350
+ multi-turn step() to score against the feedback or scoring slice
351
+ rather than the full pool.
352
  """
353
  assert self._task is not None
354
+ ex_pool = examples if examples is not None else self._test_ex
355
+ test_inputs = [x for x, _ in ex_pool]
356
+ test_expected = [y for _, y in ex_pool]
357
 
358
  generations: List[TargetGeneration] = self._target.generate_batch(
359
  prompt=prompt,
training/build_before_after_csv.py CHANGED
@@ -48,7 +48,7 @@ def parse_args() -> argparse.Namespace:
48
  "verbose_accuracy (target's accuracy when given "
49
  "the hand-written description as the prompt). "
50
  "If omitted, verbose_accuracy is left blank.")
51
- p.add_argument("--target-model", default="Qwen/Qwen3-1.7B",
52
  help="Used to count tokens of the verbose description.")
53
  p.add_argument("--output-csv", default="outputs/before_after_prompts.csv")
54
  p.add_argument("--push-to-hub", default=None,
 
48
  "verbose_accuracy (target's accuracy when given "
49
  "the hand-written description as the prompt). "
50
  "If omitted, verbose_accuracy is left blank.")
51
+ p.add_argument("--target-model", default="meta-llama/Llama-3.2-3B-Instruct",
52
  help="Used to count tokens of the verbose description.")
53
  p.add_argument("--output-csv", default="outputs/before_after_prompts.csv")
54
  p.add_argument("--push-to-hub", default=None,
training/eval_before_after.py CHANGED
@@ -42,7 +42,7 @@ def parse_args() -> argparse.Namespace:
42
  p.add_argument("--agent-model", default="Qwen/Qwen3-1.7B")
43
  p.add_argument("--adapter", default=None,
44
  help="Optional LoRA adapter dir or HF repo id.")
45
- p.add_argument("--target-model", default="Qwen/Qwen3-1.7B")
46
  p.add_argument("--tasks", default="all",
47
  help="'all' or comma-separated task ids.")
48
  p.add_argument("--seeds-per-task", type=int, default=1,
@@ -53,7 +53,16 @@ def parse_args() -> argparse.Namespace:
53
  p.add_argument("--output-json", default="outputs/eval_results.jsonl")
54
  p.add_argument("--label", default="base",
55
  help="Label to tag this eval run (e.g. 'base', 'trained').")
56
- p.add_argument("--max-new-tokens", type=int, default=256)
 
 
 
 
 
 
 
 
 
57
  p.add_argument("--temperature", type=float, default=0.0)
58
  p.add_argument("--push-to-hub", default=None,
59
  help="HF model repo id to upload the eval JSONL under evals/eval_<label>.jsonl.")
@@ -86,16 +95,19 @@ def load_agent(agent_model: str, adapter: str | None):
86
  return model, tok
87
 
88
 
89
- def build_chat_string(tok, obs) -> str:
90
  messages = [
91
  {"role": "system", "content": SYSTEM_PROMPT},
92
  {"role": "user", "content": build_agent_user_message(obs)},
93
  ]
94
  if getattr(tok, "chat_template", None):
95
  try:
 
 
 
96
  return tok.apply_chat_template(
97
  messages, tokenize=False, add_generation_prompt=True,
98
- enable_thinking=False,
99
  )
100
  except TypeError:
101
  return tok.apply_chat_template(
@@ -157,7 +169,7 @@ def main() -> None:
157
  for task_id in task_ids:
158
  for seed in range(args.seeds_per_task):
159
  obs = env.reset(task=task_id, seed=seed)
160
- chat_str = build_chat_string(tok, obs)
161
  agent_prompt = generate_prompt(
162
  model, tok, chat_str,
163
  max_new_tokens=args.max_new_tokens,
 
42
  p.add_argument("--agent-model", default="Qwen/Qwen3-1.7B")
43
  p.add_argument("--adapter", default=None,
44
  help="Optional LoRA adapter dir or HF repo id.")
45
+ p.add_argument("--target-model", default="meta-llama/Llama-3.2-3B-Instruct")
46
  p.add_argument("--tasks", default="all",
47
  help="'all' or comma-separated task ids.")
48
  p.add_argument("--seeds-per-task", type=int, default=1,
 
53
  p.add_argument("--output-json", default="outputs/eval_results.jsonl")
54
  p.add_argument("--label", default="base",
55
  help="Label to tag this eval run (e.g. 'base', 'trained').")
56
+ p.add_argument("--max-new-tokens", type=int, default=768,
57
+ help="Bumped from 256 to fit Qwen3's <think>...</think> "
58
+ "block (200-600 tokens) plus the final prompt. "
59
+ "Drop back to 256 if running with thinking=OFF.")
60
+ p.add_argument("--enable-thinking", action="store_true", default=True,
61
+ help="Apply Qwen3 chat template with thinking ON. "
62
+ "Default. Use --no-enable-thinking when evaluating "
63
+ "an adapter that was TRAINED with thinking=False.")
64
+ p.add_argument("--no-enable-thinking", dest="enable_thinking",
65
+ action="store_false")
66
  p.add_argument("--temperature", type=float, default=0.0)
67
  p.add_argument("--push-to-hub", default=None,
68
  help="HF model repo id to upload the eval JSONL under evals/eval_<label>.jsonl.")
 
95
  return model, tok
96
 
97
 
98
+ def build_chat_string(tok, obs, enable_thinking: bool = True) -> str:
99
  messages = [
100
  {"role": "system", "content": SYSTEM_PROMPT},
101
  {"role": "user", "content": build_agent_user_message(obs)},
102
  ]
103
  if getattr(tok, "chat_template", None):
104
  try:
105
+ # Mirror the chat template the adapter was trained against.
106
+ # Pass --no-enable-thinking when evaluating a thinking=False
107
+ # adapter to keep eval-time inputs in-distribution.
108
  return tok.apply_chat_template(
109
  messages, tokenize=False, add_generation_prompt=True,
110
+ enable_thinking=enable_thinking,
111
  )
112
  except TypeError:
113
  return tok.apply_chat_template(
 
169
  for task_id in task_ids:
170
  for seed in range(args.seeds_per_task):
171
  obs = env.reset(task=task_id, seed=seed)
172
+ chat_str = build_chat_string(tok, obs, enable_thinking=args.enable_thinking)
173
  agent_prompt = generate_prompt(
174
  model, tok, chat_str,
175
  max_new_tokens=args.max_new_tokens,
training/hf_job_eval.sh CHANGED
@@ -16,14 +16,26 @@ REPO_URL="${REPO_URL:-https://huggingface.co/spaces/rishabh16196/prompt_golf_env
16
  REPO_REF="${REPO_REF:-main}"
17
  ADAPTER_REPO="${ADAPTER_REPO:-rishabh16196/prompt-golf-grpo-1.5b}"
18
 
19
- AGENT_MODEL="${AGENT_MODEL:-Qwen/Qwen2.5-1.5B-Instruct}"
20
- TARGET_MODEL="${TARGET_MODEL:-Qwen/Qwen2.5-0.5B-Instruct}"
21
- SEEDS_PER_TASK="${SEEDS_PER_TASK:-3}"
 
 
 
 
 
 
22
 
23
  FLAVOR="${FLAVOR:-l40sx1}"
24
  TIMEOUT="${TIMEOUT:-1h}"
25
  IMAGE="${IMAGE:-pytorch/pytorch:2.4.0-cuda12.4-cudnn9-runtime}"
26
 
 
 
 
 
 
 
27
  run_eval() {
28
  local LABEL=$1
29
  local EXTRA_FLAGS=$2
@@ -51,6 +63,7 @@ python -u training/eval_before_after.py \
51
  --agent-model ${AGENT_MODEL} \
52
  --target-model ${TARGET_MODEL} \
53
  --seeds-per-task ${SEEDS_PER_TASK} \
 
54
  --label ${LABEL} \
55
  --output-json /app/outputs/eval_${LABEL}.jsonl \
56
  --push-to-hub ${ADAPTER_REPO} \
 
16
  REPO_REF="${REPO_REF:-main}"
17
  ADAPTER_REPO="${ADAPTER_REPO:-rishabh16196/prompt-golf-grpo-1.5b}"
18
 
19
+ AGENT_MODEL="${AGENT_MODEL:-Qwen/Qwen3-1.7B}"
20
+ TARGET_MODEL="${TARGET_MODEL:-Qwen/Qwen3-1.7B}"
21
+ # Eval is deterministic at temperature=0; seeds>1 produces bit-identical
22
+ # duplicate rows. Override only when running with temperature>0.
23
+ SEEDS_PER_TASK="${SEEDS_PER_TASK:-1}"
24
+ # Match the chat template the adapter was TRAINED against. The
25
+ # in-flight v2 adapter trained with thinking=OFF; v3 cross-family runs
26
+ # will train with thinking=ON. Override accordingly.
27
+ ENABLE_THINKING="${ENABLE_THINKING:-false}"
28
 
29
  FLAVOR="${FLAVOR:-l40sx1}"
30
  TIMEOUT="${TIMEOUT:-1h}"
31
  IMAGE="${IMAGE:-pytorch/pytorch:2.4.0-cuda12.4-cudnn9-runtime}"
32
 
33
+ # Build conditional thinking flag
34
+ THINKING_FLAG="--no-enable-thinking"
35
+ if [[ "${ENABLE_THINKING}" == "true" || "${ENABLE_THINKING}" == "True" ]]; then
36
+ THINKING_FLAG="--enable-thinking"
37
+ fi
38
+
39
  run_eval() {
40
  local LABEL=$1
41
  local EXTRA_FLAGS=$2
 
63
  --agent-model ${AGENT_MODEL} \
64
  --target-model ${TARGET_MODEL} \
65
  --seeds-per-task ${SEEDS_PER_TASK} \
66
+ ${THINKING_FLAG} \
67
  --label ${LABEL} \
68
  --output-json /app/outputs/eval_${LABEL}.jsonl \
69
  --push-to-hub ${ADAPTER_REPO} \
training/hf_job_profile.sh CHANGED
@@ -14,7 +14,7 @@ REPO_URL="${REPO_URL:-https://huggingface.co/spaces/rishabh16196/prompt_golf_env
14
  REPO_REF="${REPO_REF:-main}"
15
  PUSH_TO_HUB="${PUSH_TO_HUB:-rishabh16196/prompt-golf-grpo-1.5b}"
16
 
17
- TARGET_MODEL="${TARGET_MODEL:-Qwen/Qwen3-1.7B}"
18
  TASKS="${TASKS:-all}"
19
 
20
  FLAVOR="${FLAVOR:-l4x1}" # smaller flavor — no agent, no judge, no GRPO
 
14
  REPO_REF="${REPO_REF:-main}"
15
  PUSH_TO_HUB="${PUSH_TO_HUB:-rishabh16196/prompt-golf-grpo-1.5b}"
16
 
17
+ TARGET_MODEL="${TARGET_MODEL:-meta-llama/Llama-3.2-3B-Instruct}"
18
  TASKS="${TASKS:-all}"
19
 
20
  FLAVOR="${FLAVOR:-l4x1}" # smaller flavor — no agent, no judge, no GRPO
training/hf_job_train.sh CHANGED
@@ -24,7 +24,7 @@ PUSH_TO_HUB="${PUSH_TO_HUB:-rishabh16196/prompt-golf-grpo-1.5b}"
24
  # hard dep via TRL's newer import path; installing vllm on top of the
25
  # current image is flaky. Revisit for v3.
26
  AGENT_MODEL="${AGENT_MODEL:-Qwen/Qwen3-1.7B}"
27
- TARGET_MODEL="${TARGET_MODEL:-Qwen/Qwen3-1.7B}"
28
  JUDGE_MODEL="${JUDGE_MODEL:-Qwen/Qwen3-8B}"
29
 
30
  MAX_STEPS="${MAX_STEPS:-500}"
 
24
  # hard dep via TRL's newer import path; installing vllm on top of the
25
  # current image is flaky. Revisit for v3.
26
  AGENT_MODEL="${AGENT_MODEL:-Qwen/Qwen3-1.7B}"
27
+ TARGET_MODEL="${TARGET_MODEL:-meta-llama/Llama-3.2-3B-Instruct}"
28
  JUDGE_MODEL="${JUDGE_MODEL:-Qwen/Qwen3-8B}"
29
 
30
  MAX_STEPS="${MAX_STEPS:-500}"
training/hf_job_train_multistep.sh ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ #
3
+ # Launch multi-step GRPO training on HuggingFace Jobs. Hand-rolled
4
+ # trajectory-level GRPO loop (custom rollout + REINFORCE + KL); used
5
+ # when turn_limit > 1 and TRL's single-step GRPOTrainer cannot do
6
+ # the job.
7
+ #
8
+ # Mirrors hf_job_train.sh's install pattern verbatim — same OpenEnv-
9
+ # official torch/transformers/trl pin so the env loads identically.
10
+
11
+ set -euo pipefail
12
+
13
+ # -------- Configuration --------
14
+ REPO_URL="${REPO_URL:-https://huggingface.co/spaces/rishabh16196/prompt_golf_env}"
15
+ REPO_REF="${REPO_REF:-main}"
16
+ PUSH_TO_HUB="${PUSH_TO_HUB:-rishabh16196/prompt-golf-grpo-multistep}"
17
+
18
+ AGENT_MODEL="${AGENT_MODEL:-Qwen/Qwen3-1.7B}"
19
+ TARGET_MODEL="${TARGET_MODEL:-meta-llama/Llama-3.2-3B-Instruct}"
20
+ JUDGE_MODEL="${JUDGE_MODEL:-Qwen/Qwen3-8B}"
21
+ SFT_ADAPTER="${SFT_ADAPTER:-}" # optional warmstart from a single-step adapter
22
+
23
+ # Multi-step GRPO knobs (smaller defaults than train.sh because
24
+ # trajectories cost ~turn_limit× more per step).
25
+ MAX_STEPS="${MAX_STEPS:-200}"
26
+ NUM_GENS="${NUM_GENS:-4}"
27
+ BATCH_SIZE="${BATCH_SIZE:-2}"
28
+ LR="${LR:-3e-6}"
29
+ BETA="${BETA:-0.04}"
30
+ TURN_LIMIT="${TURN_LIMIT:-3}"
31
+ ENABLE_THINKING="${ENABLE_THINKING:-true}"
32
+
33
+ FLAVOR="${FLAVOR:-l40sx1}"
34
+ TIMEOUT="${TIMEOUT:-5h}"
35
+ IMAGE="${IMAGE:-pytorch/pytorch:2.4.0-cuda12.4-cudnn9-runtime}"
36
+
37
+ echo "[hf-jobs] repo=$REPO_URL@$REPO_REF"
38
+ echo "[hf-jobs] agent=$AGENT_MODEL target=$TARGET_MODEL judge=$JUDGE_MODEL"
39
+ echo "[hf-jobs] sft_adapter=${SFT_ADAPTER:-(none)}"
40
+ echo "[hf-jobs] turn_limit=$TURN_LIMIT enable_thinking=$ENABLE_THINKING"
41
+ echo "[hf-jobs] steps=$MAX_STEPS gens=$NUM_GENS B=$BATCH_SIZE lr=$LR beta=$BETA"
42
+ echo "[hf-jobs] flavor=$FLAVOR timeout=$TIMEOUT push_to_hub=$PUSH_TO_HUB"
43
+
44
+ # Build CLI tail conditionally (--no-enable-thinking when ENABLE_THINKING=false,
45
+ # --sft-adapter only when set).
46
+ THINKING_FLAG="--enable-thinking"
47
+ if [[ "${ENABLE_THINKING}" == "false" || "${ENABLE_THINKING}" == "False" ]]; then
48
+ THINKING_FLAG="--no-enable-thinking"
49
+ fi
50
+ SFT_FLAG=""
51
+ if [[ -n "${SFT_ADAPTER}" ]]; then
52
+ SFT_FLAG="--sft-adapter ${SFT_ADAPTER}"
53
+ fi
54
+
55
+ read -r -d '' JOB_CMD <<EOF || true
56
+ set -euo pipefail
57
+
58
+ apt-get update -qq
59
+ apt-get install -y -qq git curl build-essential
60
+
61
+ pip install --upgrade -q uv
62
+
63
+ uv pip install --system -q \\
64
+ "torch>=2.8.0" "torchvision>=0.25.0" "triton>=3.4.0" bitsandbytes \\
65
+ "transformers==4.56.2" \\
66
+ "unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo" \\
67
+ "unsloth[base] @ git+https://github.com/unslothai/unsloth"
68
+
69
+ uv pip install --system --upgrade --no-deps -q \\
70
+ "transformers==4.56.2" tokenizers "trl==0.22.2" unsloth unsloth_zoo
71
+
72
+ git clone --depth 1 --branch ${REPO_REF} ${REPO_URL} /app
73
+ cd /app
74
+ pip install -q --no-deps -e .
75
+
76
+ pip install -q 'openenv-core[core]>=0.2.2' \\
77
+ 'peft>=0.13.0' 'datasets>=3.0.0' 'accelerate>=0.34.0' \\
78
+ 'huggingface_hub>=0.26.0' 'safetensors>=0.4.0' matplotlib
79
+
80
+ python -c "import torch; print('torch:', torch.__version__, '| cuda:', torch.cuda.is_available())"
81
+
82
+ python -u training/train_grpo_multistep.py \\
83
+ --agent-model ${AGENT_MODEL} \\
84
+ --target-model ${TARGET_MODEL} \\
85
+ --judge-model ${JUDGE_MODEL} \\
86
+ --turn-limit ${TURN_LIMIT} \\
87
+ ${THINKING_FLAG} \\
88
+ --max-steps ${MAX_STEPS} \\
89
+ --num-gens ${NUM_GENS} \\
90
+ --batch-size ${BATCH_SIZE} \\
91
+ --lr ${LR} \\
92
+ --beta ${BETA} \\
93
+ --output-dir /app/outputs/grpo_multistep \\
94
+ ${SFT_FLAG} \\
95
+ ${PUSH_TO_HUB:+--push-to-hub ${PUSH_TO_HUB}}
96
+ echo "[hf-jobs] done."
97
+ EOF
98
+
99
+ hf jobs run \
100
+ --flavor "${FLAVOR}" \
101
+ --timeout "${TIMEOUT}" \
102
+ --detach \
103
+ --secrets HF_TOKEN \
104
+ --env HF_HUB_ENABLE_HF_TRANSFER=1 \
105
+ --env TRANSFORMERS_VERBOSITY=warning \
106
+ "${IMAGE}" \
107
+ -- bash -c "${JOB_CMD}"
training/profile_baseline.py CHANGED
@@ -34,7 +34,7 @@ sys.path.insert(0, str(_REPO_ROOT))
34
 
35
  def parse_args() -> argparse.Namespace:
36
  p = argparse.ArgumentParser(description="Per-task target-capability profiler")
37
- p.add_argument("--target-model", default="Qwen/Qwen3-1.7B")
38
  p.add_argument("--target-backend", default="hf",
39
  help="hf | mock (mock for local dev only)")
40
  p.add_argument("--tasks", default="all",
 
34
 
35
  def parse_args() -> argparse.Namespace:
36
  p = argparse.ArgumentParser(description="Per-task target-capability profiler")
37
+ p.add_argument("--target-model", default="meta-llama/Llama-3.2-3B-Instruct")
38
  p.add_argument("--target-backend", default="hf",
39
  help="hf | mock (mock for local dev only)")
40
  p.add_argument("--tasks", default="all",
training/train_grpo.py CHANGED
@@ -72,6 +72,33 @@ def build_agent_user_message(obs) -> str:
72
  f"- input: {ex.get('input','')!r} expected: {ex.get('expected','')!r}"
73
  for ex in (obs.train_examples or [])
74
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  return textwrap.dedent(
76
  f"""
77
  TASK: {obs.task_id} (category: {obs.task_category})
@@ -81,20 +108,23 @@ def build_agent_user_message(obs) -> str:
81
  BASELINE (empty prompt) SCORE: {obs.baseline_zero_shot_score:.2f}
82
 
83
  Visible train examples (do not copy verbatim):
84
- {examples_block}
85
 
86
  Write your prompt inside <prompt>...</prompt>.
87
  """
88
  ).strip()
89
 
90
 
91
- def build_chat_prompt(tokenizer, obs) -> str:
92
  """Apply chat template → single string the agent's tokenizer will see.
93
 
94
  Passes enable_thinking=False for Qwen3 models so the agent emits its
95
  prompt directly instead of a <think>...</think> reasoning trace
96
- followed by output. Thinking-mode output would also blow past our
97
- max_completion_length budget.
 
 
 
98
  """
99
  messages = [
100
  {"role": "system", "content": SYSTEM_PROMPT},
@@ -105,7 +135,7 @@ def build_chat_prompt(tokenizer, obs) -> str:
105
  # Qwen3 / Qwen3.5 support this kwarg; other models ignore it.
106
  return tokenizer.apply_chat_template(
107
  messages, tokenize=False, add_generation_prompt=True,
108
- enable_thinking=False,
109
  )
110
  except TypeError:
111
  return tokenizer.apply_chat_template(
@@ -114,17 +144,20 @@ def build_chat_prompt(tokenizer, obs) -> str:
114
  return f"{SYSTEM_PROMPT}\n\n{build_agent_user_message(obs)}\n\nAssistant:"
115
 
116
 
117
- def build_prompt_dataset(env, tokenizer, task_ids: List[str], seeds_per_task: int):
 
 
 
118
  """Build a HF Dataset where each row is (chat-formatted prompt, task_id, seed)."""
119
  from datasets import Dataset
120
 
121
  rows: List[Dict] = []
122
  for task_id in task_ids:
123
  for seed in range(seeds_per_task):
124
- obs = env.reset(task=task_id, seed=seed)
125
  rows.append(
126
  {
127
- "prompt": build_chat_prompt(tokenizer, obs),
128
  "task_id": task_id,
129
  "seed": seed,
130
  }
@@ -259,7 +292,7 @@ def make_callback(log_state: Dict, output_dir: Path):
259
  def parse_args() -> argparse.Namespace:
260
  p = argparse.ArgumentParser(description="GRPO training for Prompt Golf")
261
  p.add_argument("--agent-model", default="Qwen/Qwen3-1.7B")
262
- p.add_argument("--target-model", default="Qwen/Qwen3-1.7B")
263
  p.add_argument("--output-dir", default="outputs/grpo")
264
 
265
  # Task split — held out spans v1 AND v2 for honest generalization eval
@@ -282,7 +315,23 @@ def parse_args() -> argparse.Namespace:
282
  p.add_argument("--gradient-accumulation-steps", type=int, default=4)
283
  p.add_argument("--learning-rate", type=float, default=5e-6)
284
  p.add_argument("--beta", type=float, default=0.04, help="KL penalty")
285
- p.add_argument("--max-completion-length", type=int, default=256)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
  p.add_argument("--max-prompt-length", type=int, default=1024)
287
 
288
  # Rollout sampling — explicit so we don't silently inherit Qwen3's
@@ -350,8 +399,14 @@ def main() -> None:
350
  print(f"[setup] tasks total={len(all_tasks)} train={len(train_tasks)} held_out={len(held_out)}", flush=True)
351
 
352
  # ----- dataset -----
353
- train_ds = build_prompt_dataset(env, tokenizer, train_tasks, args.seeds_per_task)
354
- eval_ds = build_prompt_dataset(env, tokenizer, sorted(held_out), seeds_per_task=2) if held_out else None
 
 
 
 
 
 
355
  print(f"[setup] train rows={len(train_ds)} eval rows={len(eval_ds) if eval_ds else 0}", flush=True)
356
 
357
  # ----- reward + callback -----
 
72
  f"- input: {ex.get('input','')!r} expected: {ex.get('expected','')!r}"
73
  for ex in (obs.train_examples or [])
74
  )
75
+
76
+ # When the env runs in multi-turn mode and a prior attempt has been
77
+ # scored, fold the per-attempt feedback into the user message so the
78
+ # agent can see what its earlier prompts produced and refine.
79
+ prior = list(getattr(obs, "prior_attempts", None) or [])
80
+ prior_block = ""
81
+ if prior:
82
+ chunks = []
83
+ for att in prior:
84
+ sg = att.get("sample_generations") or []
85
+ sg_lines = "\n".join(
86
+ f" input: {g.get('input','')!r} "
87
+ f"target_said: {g.get('target_output','')!r} "
88
+ f"expected: {g.get('expected','')!r}"
89
+ for g in sg[:2]
90
+ )
91
+ chunks.append(
92
+ f" Turn {att.get('turn','?')}: prompt={att.get('prompt','')!r} "
93
+ f"(tokens={att.get('tokens','?')}, score={att.get('feedback_score',0):.2f})"
94
+ + (f"\n{sg_lines}" if sg_lines else "")
95
+ )
96
+ prior_block = (
97
+ "\n\nPRIOR ATTEMPTS (refine your prompt to score higher on the "
98
+ "scoring slice — note where the target's wording missed the "
99
+ "expected format):\n" + "\n".join(chunks)
100
+ )
101
+
102
  return textwrap.dedent(
103
  f"""
104
  TASK: {obs.task_id} (category: {obs.task_category})
 
108
  BASELINE (empty prompt) SCORE: {obs.baseline_zero_shot_score:.2f}
109
 
110
  Visible train examples (do not copy verbatim):
111
+ {examples_block}{prior_block}
112
 
113
  Write your prompt inside <prompt>...</prompt>.
114
  """
115
  ).strip()
116
 
117
 
118
+ def build_chat_prompt(tokenizer, obs, enable_thinking: bool = True) -> str:
119
  """Apply chat template → single string the agent's tokenizer will see.
120
 
121
  Passes enable_thinking=False for Qwen3 models so the agent emits its
122
  prompt directly instead of a <think>...</think> reasoning trace
123
+ followed by output. With thinking ON the agent gets reasoning scratch
124
+ space "for free" — only the final extracted prompt is counted in the
125
+ length-budget rubric, so think tokens don't hurt reward. The cost is
126
+ longer generations, addressed by raising --max-completion-length.
127
+ extract_prompt() already strips <think>...</think> blocks defensively.
128
  """
129
  messages = [
130
  {"role": "system", "content": SYSTEM_PROMPT},
 
135
  # Qwen3 / Qwen3.5 support this kwarg; other models ignore it.
136
  return tokenizer.apply_chat_template(
137
  messages, tokenize=False, add_generation_prompt=True,
138
+ enable_thinking=enable_thinking,
139
  )
140
  except TypeError:
141
  return tokenizer.apply_chat_template(
 
144
  return f"{SYSTEM_PROMPT}\n\n{build_agent_user_message(obs)}\n\nAssistant:"
145
 
146
 
147
+ def build_prompt_dataset(
148
+ env, tokenizer, task_ids: List[str], seeds_per_task: int,
149
+ enable_thinking: bool = True,
150
+ ):
151
  """Build a HF Dataset where each row is (chat-formatted prompt, task_id, seed)."""
152
  from datasets import Dataset
153
 
154
  rows: List[Dict] = []
155
  for task_id in task_ids:
156
  for seed in range(seeds_per_task):
157
+ obs = env.reset(task=task_id, seed=seed) # turn_limit=1 (training fixed single-turn)
158
  rows.append(
159
  {
160
+ "prompt": build_chat_prompt(tokenizer, obs, enable_thinking=enable_thinking),
161
  "task_id": task_id,
162
  "seed": seed,
163
  }
 
292
  def parse_args() -> argparse.Namespace:
293
  p = argparse.ArgumentParser(description="GRPO training for Prompt Golf")
294
  p.add_argument("--agent-model", default="Qwen/Qwen3-1.7B")
295
+ p.add_argument("--target-model", default="meta-llama/Llama-3.2-3B-Instruct")
296
  p.add_argument("--output-dir", default="outputs/grpo")
297
 
298
  # Task split — held out spans v1 AND v2 for honest generalization eval
 
315
  p.add_argument("--gradient-accumulation-steps", type=int, default=4)
316
  p.add_argument("--learning-rate", type=float, default=5e-6)
317
  p.add_argument("--beta", type=float, default=0.04, help="KL penalty")
318
+ p.add_argument("--max-completion-length", type=int, default=768,
319
+ help="With enable_thinking=True (Qwen3), generations "
320
+ "include a <think>...</think> reasoning block "
321
+ "before the final prompt — typically 200-600 "
322
+ "tokens. 768 leaves room for both. Drop to 256 "
323
+ "if running thinking=OFF.")
324
+ p.add_argument("--enable-thinking", action="store_true", default=True,
325
+ help="Apply Qwen3 chat template with thinking ON. "
326
+ "Default. Use --no-enable-thinking to train a "
327
+ "thinking=False adapter (matches v2 behavior).")
328
+ p.add_argument("--no-enable-thinking", dest="enable_thinking",
329
+ action="store_false")
330
+ # NOTE: training is fixed at turn_limit=1 because GRPO is a
331
+ # single-decision algorithm (one prompt -> one reward). Multi-turn
332
+ # at training time would require PPO/A2C — deferred to v3.
333
+ # Multi-turn IS supported at inference / eval time (see
334
+ # eval_before_after.py --turn-limit).
335
  p.add_argument("--max-prompt-length", type=int, default=1024)
336
 
337
  # Rollout sampling — explicit so we don't silently inherit Qwen3's
 
399
  print(f"[setup] tasks total={len(all_tasks)} train={len(train_tasks)} held_out={len(held_out)}", flush=True)
400
 
401
  # ----- dataset -----
402
+ train_ds = build_prompt_dataset(
403
+ env, tokenizer, train_tasks, args.seeds_per_task,
404
+ enable_thinking=args.enable_thinking, turn_limit=args.turn_limit,
405
+ )
406
+ eval_ds = build_prompt_dataset(
407
+ env, tokenizer, sorted(held_out), seeds_per_task=2,
408
+ enable_thinking=args.enable_thinking, turn_limit=args.turn_limit,
409
+ ) if held_out else None
410
  print(f"[setup] train rows={len(train_ds)} eval rows={len(eval_ds) if eval_ds else 0}", flush=True)
411
 
412
  # ----- reward + callback -----
training/train_grpo_multistep.py ADDED
@@ -0,0 +1,585 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Multi-step GRPO for Prompt Golf — model in the env loop at every turn.
3
+
4
+ Adapted from spaces_pipeline_env/local_training/grpo_multistep.py (the
5
+ proven trajectory-level GRPO recipe used in the Spaces env). Differences
6
+ for Prompt Golf:
7
+
8
+ - Action is a free-form prompt string (not a JSON action).
9
+ - Trajectory length = `turn_limit` (typically 2 or 3).
10
+ - Trajectory grade = final-turn reward (`obs.reward` after step where
11
+ `obs.done == True`). Intermediate turns are unrewarded; the agent
12
+ only sees feedback in the next observation's `prior_attempts`.
13
+
14
+ Why this exists: TRL's GRPOTrainer treats one prompt -> one completion.
15
+ For multi-turn we need the model to generate at every env step, observe
16
+ the resulting feedback, and refine. This script runs a custom
17
+ trajectory-level GRPO loop (REINFORCE + KL vs frozen LoRA snapshot).
18
+
19
+ Memory cost: trainable LoRA + a snapshot dict of those LoRA weights as
20
+ the reference. Both fit easily on L40S (48 GB) alongside Qwen3-1.7B
21
+ target + Qwen3-8B 8-bit judge.
22
+
23
+ Usage:
24
+ python -u training/train_grpo_multistep.py \
25
+ --max-steps 200 --num-gens 4 --batch-size 2 \
26
+ --turn-limit 3 \
27
+ --enable-thinking \
28
+ --push-to-hub rishabh16196/prompt-golf-grpo-multistep
29
+ """
30
+
31
+ from __future__ import annotations
32
+
33
+ import argparse
34
+ import json
35
+ import os
36
+ import random
37
+ import sys
38
+ import time
39
+ from dataclasses import dataclass
40
+ from pathlib import Path
41
+ from typing import Any, Dict, List, Optional, Tuple
42
+
43
+ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
44
+
45
+ import torch
46
+ import torch.nn.functional as F
47
+
48
+ _HERE = Path(__file__).resolve().parent
49
+ _REPO_ROOT = _HERE.parent
50
+ sys.path.insert(0, str(_REPO_ROOT))
51
+
52
+ # Reuse the prompt format + extract_prompt from the single-step trainer
53
+ # so the multi-step rollouts match the agent's training distribution
54
+ # bit-for-bit (same SYSTEM_PROMPT, same chat template, same parsing).
55
+ from training.train_grpo import ( # noqa: E402
56
+ SYSTEM_PROMPT,
57
+ build_agent_user_message,
58
+ build_chat_prompt,
59
+ extract_prompt,
60
+ )
61
+
62
+
63
+ # ---------------------------------------------------------------------------
64
+ # Trajectory containers
65
+ # ---------------------------------------------------------------------------
66
+
67
+ @dataclass
68
+ class StepRecord:
69
+ prompt_ids: torch.Tensor # [seq_len] — chat-templated prompt
70
+ action_ids: torch.Tensor # [act_len] — generated tokens
71
+ action_text: str # extracted prompt (post-extract_prompt)
72
+
73
+
74
+ @dataclass
75
+ class Trajectory:
76
+ task_id: str
77
+ seed: int
78
+ steps: List[StepRecord]
79
+ grade: float # final-turn reward
80
+ raw_task_score: float # final-turn raw_task_score (accuracy)
81
+ submitted_tokens: int # final-turn prompt token count
82
+ turns_taken: int
83
+
84
+
85
+ # ---------------------------------------------------------------------------
86
+ # Rollout: model in the loop at every env step
87
+ # ---------------------------------------------------------------------------
88
+
89
+ def rollout_episode(
90
+ env, model, tokenizer, task_id: str, seed: int, *,
91
+ turn_limit: int,
92
+ max_new_tokens: int,
93
+ temperature: float,
94
+ enable_thinking: bool,
95
+ device: str,
96
+ max_prompt_tokens: int = 4096,
97
+ ) -> Trajectory:
98
+ """Run one episode. Model generates at every turn until env.done.
99
+
100
+ Returns a Trajectory with per-turn (prompt_ids, action_ids) pairs
101
+ used by the policy-gradient update.
102
+ """
103
+ from prompt_golf_env.models import GolfAction
104
+
105
+ obs = env.reset(task=task_id, seed=seed, turn_limit=turn_limit)
106
+ steps: List[StepRecord] = []
107
+ grade: float = 0.0
108
+ raw_task_score: float = 0.0
109
+ submitted_tokens: int = 0
110
+
111
+ model.eval()
112
+ while not obs.done:
113
+ # Build chat prompt — multi-turn obs carries prior_attempts which
114
+ # build_agent_user_message folds into the user message.
115
+ chat_str = build_chat_prompt(tokenizer, obs, enable_thinking=enable_thinking)
116
+ prompt_ids = tokenizer(chat_str, return_tensors="pt").input_ids[0]
117
+ if prompt_ids.shape[0] > max_prompt_tokens:
118
+ # Left-truncate (preserve the tail with the "write your prompt" hint)
119
+ prompt_ids = prompt_ids[-max_prompt_tokens:]
120
+ prompt_ids = prompt_ids.to(device)
121
+
122
+ with torch.no_grad():
123
+ out = model.generate(
124
+ prompt_ids.unsqueeze(0),
125
+ max_new_tokens=max_new_tokens,
126
+ do_sample=True,
127
+ temperature=temperature,
128
+ top_p=1.0,
129
+ pad_token_id=tokenizer.pad_token_id,
130
+ )
131
+ gen_ids = out[0][prompt_ids.shape[0]:]
132
+ gen_text = tokenizer.decode(gen_ids, skip_special_tokens=True)
133
+ action_text = extract_prompt(gen_text)
134
+
135
+ steps.append(StepRecord(
136
+ prompt_ids=prompt_ids.detach().cpu(),
137
+ action_ids=gen_ids.detach().cpu(),
138
+ action_text=action_text,
139
+ ))
140
+
141
+ obs = env.step(GolfAction(prompt=action_text))
142
+ if obs.done:
143
+ grade = float(obs.reward or 0.0)
144
+ raw_task_score = float(obs.raw_task_score or 0.0)
145
+ submitted_tokens = int(obs.submitted_prompt_tokens or 0)
146
+
147
+ return Trajectory(
148
+ task_id=task_id,
149
+ seed=seed,
150
+ steps=steps,
151
+ grade=grade,
152
+ raw_task_score=raw_task_score,
153
+ submitted_tokens=submitted_tokens,
154
+ turns_taken=len(steps),
155
+ )
156
+
157
+
158
+ # ---------------------------------------------------------------------------
159
+ # Log-prob computation (batched left-padding for memory efficiency)
160
+ # ---------------------------------------------------------------------------
161
+
162
+ def compute_logprobs_batched(
163
+ model, records: List[Tuple[torch.Tensor, torch.Tensor]],
164
+ device: str, pad_id: int,
165
+ ) -> List[torch.Tensor]:
166
+ """Per-record action-token logprobs in one batched forward pass.
167
+
168
+ Records are list of (prompt_ids, action_ids). We left-pad each
169
+ [prompt_ids | action_ids] sequence to the max length, then read the
170
+ a_len logits that predict each action token.
171
+ """
172
+ if not records:
173
+ return []
174
+ prompt_lens = [p.shape[0] for p, _ in records]
175
+ action_lens = [a.shape[0] for _, a in records]
176
+ seq_lens = [pl + al for pl, al in zip(prompt_lens, action_lens)]
177
+ max_len = max(seq_lens)
178
+ K = len(records)
179
+
180
+ input_ids = torch.full((K, max_len), pad_id, dtype=torch.long, device=device)
181
+ attn_mask = torch.zeros((K, max_len), dtype=torch.long, device=device)
182
+ for i, (p, a) in enumerate(records):
183
+ full = torch.cat([p.to(device), a.to(device)], dim=0)
184
+ input_ids[i, max_len - full.shape[0]:] = full
185
+ attn_mask[i, max_len - full.shape[0]:] = 1
186
+
187
+ out = model(input_ids=input_ids, attention_mask=attn_mask)
188
+ logits = out.logits # [K, T, V]
189
+
190
+ results: List[torch.Tensor] = []
191
+ for i, (p, a) in enumerate(records):
192
+ p_len, a_len = prompt_lens[i], action_lens[i]
193
+ pad_prefix = max_len - (p_len + a_len)
194
+ start = pad_prefix + p_len - 1
195
+ action_logits = logits[i, start : start + a_len] # [a_len, V]
196
+ logprobs = F.log_softmax(action_logits.float(), dim=-1)
197
+ action_ids_dev = a.to(device)
198
+ token_logp = logprobs.gather(1, action_ids_dev.unsqueeze(-1)).squeeze(-1)
199
+ results.append(token_logp)
200
+ return results
201
+
202
+
203
+ # ---------------------------------------------------------------------------
204
+ # Main training loop
205
+ # ---------------------------------------------------------------------------
206
+
207
+ def parse_args() -> argparse.Namespace:
208
+ p = argparse.ArgumentParser(description="Multi-step GRPO for Prompt Golf")
209
+ p.add_argument("--agent-model", default="Qwen/Qwen3-1.7B")
210
+ p.add_argument("--target-model", default="meta-llama/Llama-3.2-3B-Instruct")
211
+ p.add_argument("--judge-model", default="Qwen/Qwen3-8B")
212
+ p.add_argument("--sft-adapter", default=None,
213
+ help="Optional LoRA adapter to warm-start from "
214
+ "(e.g. baseline single-turn adapter). Strongly "
215
+ "recommended — RL on a freshly initialized "
216
+ "policy diverges easily.")
217
+ p.add_argument("--output-dir", default="outputs/grpo_multistep")
218
+ p.add_argument("--push-to-hub", default=None,
219
+ help="HF model repo id; pushes adapter + metrics here.")
220
+
221
+ # Trajectory shape
222
+ p.add_argument("--turn-limit", type=int, default=3,
223
+ help="Turns per episode. >1 enables multi-turn.")
224
+ p.add_argument("--enable-thinking", action="store_true", default=True)
225
+ p.add_argument("--no-enable-thinking", dest="enable_thinking",
226
+ action="store_false")
227
+
228
+ # GRPO knobs
229
+ p.add_argument("--max-steps", type=int, default=200)
230
+ p.add_argument("--num-gens", type=int, default=4,
231
+ help="Trajectories per task per GRPO step.")
232
+ p.add_argument("--batch-size", type=int, default=2,
233
+ help="Tasks sampled per GRPO step.")
234
+ p.add_argument("--lr", type=float, default=3e-6)
235
+ p.add_argument("--beta", type=float, default=0.04,
236
+ help="KL penalty vs frozen LoRA snapshot.")
237
+ p.add_argument("--temperature", type=float, default=0.9)
238
+ p.add_argument("--max-new-tokens", type=int, default=768)
239
+ p.add_argument("--max-prompt-tokens", type=int, default=4096)
240
+ p.add_argument("--max-grad-norm", type=float, default=0.5)
241
+ p.add_argument("--update-micro-batch", type=int, default=4,
242
+ help="Records per batched forward pass.")
243
+ p.add_argument("--save-every", type=int, default=50)
244
+
245
+ # LoRA (used when --sft-adapter is not given — fresh LoRA init)
246
+ p.add_argument("--lora-r", type=int, default=16)
247
+ p.add_argument("--lora-alpha", type=int, default=32)
248
+ p.add_argument("--lora-dropout", type=float, default=0.05)
249
+
250
+ # Task selection
251
+ p.add_argument("--held-out-tasks", default="",
252
+ help="Comma-separated task ids to exclude from training.")
253
+
254
+ p.add_argument("--seed", type=int, default=42)
255
+ p.add_argument("--dry-run", action="store_true",
256
+ help="Run one rollout and print, then exit.")
257
+ return p.parse_args()
258
+
259
+
260
+ def main() -> None:
261
+ args = parse_args()
262
+
263
+ random.seed(args.seed)
264
+ torch.manual_seed(args.seed)
265
+
266
+ # Env vars consumed by the env's lazy backends
267
+ os.environ.setdefault("PROMPT_GOLF_TARGET_MODEL", args.target_model)
268
+ os.environ.setdefault("PROMPT_GOLF_TARGET_BACKEND", "hf")
269
+ os.environ.setdefault("PROMPT_GOLF_JUDGE_MODEL", args.judge_model)
270
+
271
+ device = "cuda" if torch.cuda.is_available() else "cpu"
272
+ out_dir = Path(args.output_dir)
273
+ out_dir.mkdir(parents=True, exist_ok=True)
274
+
275
+ print("=== Multi-step GRPO (Prompt Golf, trajectory-level) ===", flush=True)
276
+ print(f" device: {device}", flush=True)
277
+ print(f" agent: {args.agent_model}", flush=True)
278
+ print(f" target: {args.target_model}", flush=True)
279
+ print(f" judge: {args.judge_model}", flush=True)
280
+ print(f" warmstart: {args.sft_adapter or '(fresh LoRA init)'}", flush=True)
281
+ print(f" turn_limit: {args.turn_limit}", flush=True)
282
+ print(f" enable_thinking: {args.enable_thinking}", flush=True)
283
+ print(f" max_steps: {args.max_steps}", flush=True)
284
+ print(f" tasks/step (B): {args.batch_size}", flush=True)
285
+ print(f" gens/task (G): {args.num_gens}", flush=True)
286
+ print(f" trajectories/step:{args.batch_size * args.num_gens}", flush=True)
287
+ print(f" lr / beta: {args.lr} / {args.beta}", flush=True)
288
+
289
+ # ---- Env (lazy-loads target on first use) ----
290
+ from prompt_golf_env.server.prompt_golf_environment import (
291
+ PromptGolfEnvironment,
292
+ _ALL_TASKS,
293
+ )
294
+ env = PromptGolfEnvironment()
295
+
296
+ # ---- Tokenizer ----
297
+ from transformers import AutoTokenizer
298
+ tokenizer = AutoTokenizer.from_pretrained(args.agent_model)
299
+ tokenizer.padding_side = "left"
300
+ if tokenizer.pad_token is None:
301
+ tokenizer.pad_token = tokenizer.eos_token
302
+
303
+ # ---- Base model + LoRA ----
304
+ print("\nLoading agent base model (bf16)...", flush=True)
305
+ t0 = time.time()
306
+ from transformers import AutoModelForCausalLM
307
+ base = AutoModelForCausalLM.from_pretrained(
308
+ args.agent_model,
309
+ torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
310
+ device_map="auto" if torch.cuda.is_available() else None,
311
+ )
312
+ print(f" base loaded in {time.time()-t0:.1f}s", flush=True)
313
+
314
+ if args.sft_adapter:
315
+ print(f"Loading adapter from {args.sft_adapter} (trainable)...", flush=True)
316
+ from peft import PeftModel
317
+ model = PeftModel.from_pretrained(base, args.sft_adapter, is_trainable=True)
318
+ else:
319
+ print("Initializing fresh LoRA adapter (no warmstart)...", flush=True)
320
+ from peft import LoraConfig, get_peft_model
321
+ lora_cfg = LoraConfig(
322
+ r=args.lora_r,
323
+ lora_alpha=args.lora_alpha,
324
+ lora_dropout=args.lora_dropout,
325
+ bias="none",
326
+ task_type="CAUSAL_LM",
327
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
328
+ )
329
+ model = get_peft_model(base, lora_cfg)
330
+ model = model.to(device) if not torch.cuda.is_available() else model
331
+ n_tr = sum(p.numel() for p in model.parameters() if p.requires_grad)
332
+ print(f" trainable params: {n_tr:,}", flush=True)
333
+
334
+ # ---- Snapshot trainable weights as the KL reference ----
335
+ print("Snapshotting trainable weights as KL reference...", flush=True)
336
+ ref_state: Dict[str, torch.Tensor] = {
337
+ k: v.detach().clone()
338
+ for k, v in model.named_parameters() if v.requires_grad
339
+ }
340
+
341
+ # ---- Training task pool ----
342
+ held_out = {t.strip() for t in args.held_out_tasks.split(",") if t.strip()}
343
+ train_task_ids = [tid for tid in _ALL_TASKS.keys() if tid not in held_out]
344
+ print(f" task pool: {len(train_task_ids)} tasks "
345
+ f"(held out: {len(held_out)})", flush=True)
346
+
347
+ # ---- Optimizer ----
348
+ optim = torch.optim.AdamW(
349
+ [p for p in model.parameters() if p.requires_grad],
350
+ lr=args.lr, betas=(0.9, 0.95), eps=1e-8,
351
+ )
352
+
353
+ if args.dry_run:
354
+ print("\n[DRY-RUN] one rollout...", flush=True)
355
+ task = train_task_ids[0]
356
+ traj = rollout_episode(
357
+ env, model, tokenizer, task_id=task, seed=args.seed,
358
+ turn_limit=args.turn_limit,
359
+ max_new_tokens=args.max_new_tokens,
360
+ temperature=args.temperature,
361
+ enable_thinking=args.enable_thinking,
362
+ device=device,
363
+ max_prompt_tokens=args.max_prompt_tokens,
364
+ )
365
+ print(f" task={traj.task_id} turns={traj.turns_taken} "
366
+ f"grade={traj.grade:.3f} raw={traj.raw_task_score:.2f} "
367
+ f"tokens={traj.submitted_tokens}", flush=True)
368
+ for i, sr in enumerate(traj.steps):
369
+ print(f" turn {i+1}: action_text='{sr.action_text[:80]}' "
370
+ f"({sr.action_ids.shape[0]} action tokens)", flush=True)
371
+ print("[DRY-RUN] done — no training.", flush=True)
372
+ return
373
+
374
+ # ---- Training loop ----
375
+ print("\n=== starting multi-step GRPO ===\n", flush=True)
376
+ t_train = time.time()
377
+ metrics: List[Dict[str, Any]] = []
378
+ STD_FLOOR = 0.1
379
+ ADV_CLAMP = 3.0
380
+
381
+ def swap_weights(target_state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
382
+ """Copy target_state into trainable params; return prior snapshot."""
383
+ old: Dict[str, torch.Tensor] = {}
384
+ for k, v in model.named_parameters():
385
+ if v.requires_grad and k in target_state:
386
+ old[k] = v.detach().clone()
387
+ with torch.no_grad():
388
+ v.copy_(target_state[k])
389
+ return old
390
+
391
+ for step in range(args.max_steps):
392
+ step_t0 = time.time()
393
+ tasks_this_step = random.sample(
394
+ train_task_ids, min(args.batch_size, len(train_task_ids))
395
+ )
396
+ seed_base = args.seed + step * 1000
397
+
398
+ # ---- Phase 1: rollouts (no grad) ----
399
+ all_groups: List[List[Trajectory]] = []
400
+ for ti, task in enumerate(tasks_this_step):
401
+ group: List[Trajectory] = []
402
+ for g in range(args.num_gens):
403
+ traj = rollout_episode(
404
+ env, model, tokenizer,
405
+ task_id=task, seed=seed_base + ti * 100 + g,
406
+ turn_limit=args.turn_limit,
407
+ max_new_tokens=args.max_new_tokens,
408
+ temperature=args.temperature,
409
+ enable_thinking=args.enable_thinking,
410
+ device=device,
411
+ max_prompt_tokens=args.max_prompt_tokens,
412
+ )
413
+ group.append(traj)
414
+ all_groups.append(group)
415
+
416
+ # ---- Group-relative advantages with std floor + clamp ----
417
+ flat_records: List[Tuple[StepRecord, float]] = []
418
+ group_stats = []
419
+ n_groups_skipped = 0
420
+ for group in all_groups:
421
+ rewards = torch.tensor([t.grade for t in group], dtype=torch.float32)
422
+ mean_r = rewards.mean().item()
423
+ raw_std = rewards.std(unbiased=False).item()
424
+ if raw_std < 0.02: # all trajectories scored equal -> no signal
425
+ n_groups_skipped += 1
426
+ group_stats.append((rewards.tolist(), mean_r, 0.0))
427
+ continue
428
+ std_r = max(raw_std, STD_FLOOR)
429
+ group_stats.append((rewards.tolist(), mean_r, std_r))
430
+ for traj in group:
431
+ adv = (traj.grade - mean_r) / std_r
432
+ adv = max(-ADV_CLAMP, min(ADV_CLAMP, adv))
433
+ for sr in traj.steps:
434
+ flat_records.append((sr, adv))
435
+
436
+ if not flat_records:
437
+ print(f"step {step+1:3d}/{args.max_steps} all groups collapsed "
438
+ f"(equal rewards) — skipping update", flush=True)
439
+ continue
440
+
441
+ # ---- Phase 2: batched policy-gradient update ----
442
+ model.train()
443
+ optim.zero_grad()
444
+
445
+ total_loss_val = 0.0
446
+ total_kl_val = 0.0
447
+ n_records = len(flat_records)
448
+ MICRO = args.update_micro_batch
449
+
450
+ for start in range(0, n_records, MICRO):
451
+ batch = flat_records[start : start + MICRO]
452
+ batch_records = [(sr.prompt_ids, sr.action_ids) for sr, _ in batch]
453
+ batch_advs = [adv for _, adv in batch]
454
+
455
+ # Reference logp (no grad)
456
+ if args.beta > 0:
457
+ saved = swap_weights(ref_state)
458
+ with torch.no_grad():
459
+ ref_logps = compute_logprobs_batched(
460
+ model, batch_records, device, tokenizer.pad_token_id,
461
+ )
462
+ swap_weights(saved)
463
+ ref_logps = [r.detach() for r in ref_logps]
464
+ else:
465
+ ref_logps = [None] * len(batch)
466
+
467
+ # New logp (with grad)
468
+ new_logps = compute_logprobs_batched(
469
+ model, batch_records, device, tokenizer.pad_token_id,
470
+ )
471
+
472
+ # REINFORCE + KL loss
473
+ batch_loss_terms = []
474
+ for new_lp, ref_lp, adv in zip(new_logps, ref_logps, batch_advs):
475
+ if ref_lp is None:
476
+ ref_lp = new_lp.detach()
477
+ kl_per_tok = new_lp - ref_lp
478
+ pg_per_tok = -adv * new_lp
479
+ loss_per_tok = pg_per_tok + args.beta * kl_per_tok
480
+ batch_loss_terms.append(loss_per_tok.mean())
481
+ total_kl_val += kl_per_tok.mean().item()
482
+
483
+ micro_loss = torch.stack(batch_loss_terms).mean()
484
+ scale = len(batch) / n_records
485
+ (micro_loss * scale).backward()
486
+ total_loss_val += micro_loss.item() * len(batch)
487
+
488
+ total_loss_val = total_loss_val / max(1, n_records)
489
+
490
+ torch.nn.utils.clip_grad_norm_(
491
+ [p for p in model.parameters() if p.requires_grad],
492
+ args.max_grad_norm,
493
+ )
494
+ optim.step()
495
+
496
+ # ---- Log ----
497
+ all_rewards = [r for g in group_stats for r in g[0]]
498
+ avg_r = sum(all_rewards) / max(1, len(all_rewards))
499
+ max_r = max(all_rewards)
500
+ min_r = min(all_rewards)
501
+ avg_loss = total_loss_val
502
+ avg_kl = total_kl_val / max(1, n_records)
503
+ n_traj = sum(len(g) for g in all_groups)
504
+ n_steps_in_traj = sum(len(t.steps) for g in all_groups for t in g)
505
+ avg_tokens = (
506
+ sum(t.submitted_tokens for g in all_groups for t in g)
507
+ / max(1, n_traj)
508
+ )
509
+ avg_raw = (
510
+ sum(t.raw_task_score for g in all_groups for t in g)
511
+ / max(1, n_traj)
512
+ )
513
+ elapsed = time.time() - step_t0
514
+ print(
515
+ f"step {step+1:3d}/{args.max_steps} "
516
+ f"avg_r={avg_r:+.3f} [{min_r:+.2f}..{max_r:+.2f}] "
517
+ f"raw={avg_raw:.2f} tokens={avg_tokens:.1f} "
518
+ f"n_traj={n_traj} n_turns={n_steps_in_traj} "
519
+ f"grp_skip={n_groups_skipped} "
520
+ f"loss={avg_loss:+.4f} kl={avg_kl:+.4f} "
521
+ f"{elapsed:.1f}s",
522
+ flush=True,
523
+ )
524
+ metrics.append({
525
+ "step": step + 1,
526
+ "avg_reward": avg_r,
527
+ "min_reward": min_r,
528
+ "max_reward": max_r,
529
+ "avg_raw_task_score": avg_raw,
530
+ "avg_submitted_tokens": avg_tokens,
531
+ "loss": avg_loss,
532
+ "kl": avg_kl,
533
+ "n_trajectories": n_traj,
534
+ "n_turns_total": n_steps_in_traj,
535
+ "n_groups_skipped": n_groups_skipped,
536
+ "elapsed_s": elapsed,
537
+ })
538
+
539
+ if args.save_every > 0 and (step + 1) % args.save_every == 0 \
540
+ and (step + 1) < args.max_steps:
541
+ ckpt = out_dir / f"checkpoint-{step+1}"
542
+ ckpt.mkdir(parents=True, exist_ok=True)
543
+ model.save_pretrained(str(ckpt))
544
+ (out_dir / "train_metrics.json").write_text(json.dumps(metrics, indent=2))
545
+ print(f" ckpt -> {ckpt.name}", flush=True)
546
+
547
+ train_elapsed = time.time() - t_train
548
+ print(f"\n=== training done in {train_elapsed/60:.1f} min ===", flush=True)
549
+
550
+ # ---- Save adapter + metrics ----
551
+ final_dir = out_dir / "adapter_final"
552
+ final_dir.mkdir(parents=True, exist_ok=True)
553
+ model.save_pretrained(str(final_dir))
554
+ tokenizer.save_pretrained(str(final_dir))
555
+ (out_dir / "train_metrics.json").write_text(json.dumps(metrics, indent=2))
556
+ print(f" adapter -> {final_dir}", flush=True)
557
+ print(f" metrics -> {out_dir / 'train_metrics.json'}", flush=True)
558
+
559
+ # ---- Push to hub ----
560
+ if args.push_to_hub:
561
+ from huggingface_hub import HfApi
562
+ api = HfApi()
563
+ api.create_repo(args.push_to_hub, exist_ok=True, repo_type="model")
564
+ api.upload_folder(
565
+ folder_path=str(final_dir),
566
+ repo_id=args.push_to_hub,
567
+ repo_type="model",
568
+ path_in_repo="adapter_final",
569
+ commit_message=f"multi-step GRPO adapter ({args.max_steps} steps, "
570
+ f"turn_limit={args.turn_limit}, "
571
+ f"thinking={args.enable_thinking})",
572
+ )
573
+ api.upload_file(
574
+ path_or_fileobj=str(out_dir / "train_metrics.json"),
575
+ path_in_repo="metrics/train_metrics_multistep.json",
576
+ repo_id=args.push_to_hub,
577
+ repo_type="model",
578
+ commit_message="multi-step GRPO metrics",
579
+ )
580
+ print(f"[push] uploaded to https://huggingface.co/{args.push_to_hub}",
581
+ flush=True)
582
+
583
+
584
+ if __name__ == "__main__":
585
+ main()