anugrahhu committed
Commit a8d4d87 · verified · 1 Parent(s): a7acc5f

sft+reward-fix: training/sft_warmstart.py

Files changed (1): training/sft_warmstart.py +402 -0
training/sft_warmstart.py ADDED
@@ -0,0 +1,402 @@
"""SFT (Supervised Fine-Tuning) warm-start for CERNenv.

Generates ``--num_episodes`` oracle trajectories against the local
``CERNCollisionEnvironment`` (no HTTP), turns each successful step into
a (prompt, completion) pair using the *same* chat templating the GRPO
loop uses (see ``training.training_script.build_dataset`` /
``training.llm_agent.build_chat``), and runs ``trl.SFTTrainer`` with
LoRA so the resulting checkpoint can be used as the starting weights
for GRPO.

This addresses the v1 reward hack head-on. v1
(``anugrahhu/cernenv-grpo-smollm2-360m``) never saw a positive-reward
trajectory during early training because SmolLM2-360M-Instruct cannot
solve the LHC discovery pipeline zero-shot, so GRPO had no positive
gradient to follow and the policy collapsed to "spam request_systematics
forever". A short SFT on oracle traces gives the policy a non-zero
prior over the *correct* action sequence, which RL can then refine.

Usage
-----
python -m training.sft_warmstart \\
    --out_dir runs/sft-warmstart \\
    --num_episodes 200 --max_steps 8 --epochs 1 --lr 1e-5 \\
    --base_model HuggingFaceTB/SmolLM2-360M-Instruct \\
    --difficulty mixed --evidence_dir evidence

Smoke test:
python -m training.sft_warmstart --num_episodes 4 --max_steps 4 \\
    --epochs 1 --base_model sshleifer/tiny-gpt2 --out_dir /tmp/sft_smoke
"""

from __future__ import annotations

import argparse
import json
import logging
import time
from pathlib import Path
from typing import Any, Dict, List, Tuple

from models import ExperimentAction
from server.environment import CERNCollisionEnvironment
from training.llm_agent import LLMAgentConfig, build_chat


logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)


# Mirrors the LoRA hyper-parameters in training/training_script.py so the
# SFT warm-start and the downstream GRPO train the *same* adapter shape
# (and, via build_chat, the same prompt templating). If you change one,
# change both; divergence here is the most insidious source of
# warm-start ineffectiveness.
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05
LORA_TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj"]


# ── Oracle trajectory collection ─────────────────────────────────────────


_DIFFICULTY_CYCLE = ("easy", "medium", "hard")


def _serialise_action(action: ExperimentAction) -> str:
    """Return a JSON string the GRPO parser would accept.

    Builds the payload field-by-field (mirroring what
    ``models.ExperimentAction.model_dump`` would emit) so enum values are
    converted to their string representation and parameters survive a
    round-trip through ``training.llm_agent.parse_action``.
    """
    payload: Dict[str, Any] = {
        "action_type": action.action_type.value,
        "parameters": dict(action.parameters or {}),
    }
    if action.method:
        payload["method"] = action.method
    if action.justification:
        payload["justification"] = action.justification
    if action.confidence is not None:
        payload["confidence"] = float(action.confidence)
    return json.dumps(payload, ensure_ascii=False)
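
# For illustration only (the parameter values here are hypothetical; the
# JSON *shape* is what ``parse_action`` expects):
#   {"action_type": "request_systematics", "parameters": {}, "confidence": 0.9}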


def _difficulty_for(idx: int, difficulty: str) -> str:
    if difficulty == "mixed":
        return _DIFFICULTY_CYCLE[idx % len(_DIFFICULTY_CYCLE)]
    return difficulty
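
# e.g. with --difficulty mixed, episodes 0, 1, 2, 3, ... run easy, medium,
# hard, easy, ...; combined with ``seed_base + i`` below, every episode's
# (difficulty, seed) pair is reproducible.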


def _collect_oracle_trajectories(
    *,
    tokenizer,
    num_episodes: int,
    max_steps: int,
    difficulty: str,
    seed_base: int = 4242,
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
    """Run the OracleAgent against the env and harvest (prompt, completion) pairs.

    Returns (filtered_pairs, stats). ``stats`` reports oracle success
    rate even when no successful trajectories were obtained, so callers
    can log it into ``evidence/sft_summary.json``.
    """
    from scripts.baseline_agents import OracleAgent

    config = LLMAgentConfig()
    pairs_all: List[Dict[str, Any]] = []
    pairs_successful: List[Dict[str, Any]] = []
    successes = 0
    rewards: List[float] = []

    for i in range(num_episodes):
        env = CERNCollisionEnvironment(max_steps=max_steps)
        difficulty_i = _difficulty_for(i, difficulty)
        obs = env.reset(seed=seed_base + i, difficulty=difficulty_i)
        truth = env.hidden_truth()
        agent = OracleAgent(truth=truth)
        agent.reset()

        episode_pairs: List[Dict[str, Any]] = []
        cumulative = 0.0
        while not obs.done and len(episode_pairs) < max_steps:
            chat = build_chat(obs, config)
            try:
                prompt = tokenizer.apply_chat_template(
                    chat, add_generation_prompt=True, tokenize=False,
                )
            except Exception as exc:  # pragma: no cover - tiny-gpt2 etc.
                logger.warning(
                    "tokenizer has no chat template (%s); installing a "
                    "minimal fallback so SFT can proceed", exc,
                )
                tokenizer.chat_template = (
                    "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n"
                    "{% endfor %}"
                    "{% if add_generation_prompt %}assistant: {% endif %}"
                )
                prompt = tokenizer.apply_chat_template(
                    chat, add_generation_prompt=True, tokenize=False,
                )
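                # Only stub tokenizers (e.g. sshleifer/tiny-gpt2 in the smoke
                # test) hit this fallback; real SmolLM2 checkpoints ship their
                # own chat template, which is what GRPO will see.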

            action = agent.act(obs)
            completion = _serialise_action(action)
            episode_pairs.append({
                "prompt": prompt,
                "completion": completion,
                "step": obs.step_index,
                "difficulty": difficulty_i,
                "seed": seed_base + i,
            })

            obs = env.step(action)
            cumulative += float(obs.reward or 0.0)
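
        # ``cumulative`` only feeds the summary stats below; episode filtering
        # keys off the env's correct_mass/correct_channel flags, so reward
        # shaping tweaks cannot silently change which transitions reach SFT.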
        rewards.append(cumulative)
        st = env.state
        ok = bool(st.correct_mass) and bool(st.correct_channel)
        if ok:
            successes += 1
            pairs_successful.extend(episode_pairs)
        pairs_all.extend(episode_pairs)

    success_rate = successes / max(num_episodes, 1)
    mean_reward = sum(rewards) / len(rewards) if rewards else 0.0
    stats = {
        "num_episodes": num_episodes,
        "num_successful_episodes": successes,
        "oracle_success_rate": round(success_rate, 4),
        "mean_oracle_reward": round(mean_reward, 4),
        "num_transitions_total": len(pairs_all),
        "num_transitions_successful": len(pairs_successful),
    }

    # If we filtered out *everything* (e.g. smoke test with max_steps too
    # small for the oracle to ever finish), fall back to the unfiltered
    # set with a warning. Better to teach the model the prefix of the
    # correct pipeline than to give up entirely.
    if pairs_successful:
        return pairs_successful, stats
    logger.warning(
        "no fully-successful oracle trajectories (max_steps=%d may be "
        "too small); using %d unfiltered transitions instead",
        max_steps, len(pairs_all),
    )
    stats["fallback_used"] = True
    return pairs_all, stats


# ── Dataset assembly ─────────────────────────────────────────────────────


def _build_dataset(pairs: List[Dict[str, Any]], tokenizer):
    """Build a HF Dataset whose ``text`` column is ``prompt + completion``.

    SFTTrainer will train next-token prediction on the *whole* text, so
    we append an ``eos_token`` at the end of the completion to terminate
    generation cleanly. ``dataset_text_field='text'`` on SFTConfig then
    consumes this column directly.
    """
    from datasets import Dataset

    eos = tokenizer.eos_token or ""
    rows = []
    for p in pairs:
        rows.append({
            "text": p["prompt"] + p["completion"] + eos,
            "prompt": p["prompt"],
            "completion": p["completion"],
        })
    return Dataset.from_list(rows)
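
# Note: loss is taken over prompt *and* completion tokens. If completion-only
# loss ever proves necessary, TRL's DataCollatorForCompletionOnlyLM (keyed on
# a response-template marker) is the usual switch; full-text loss keeps this
# warm-start script simple.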


# ── LoRA helper ──────────────────────────────────────────────────────────


def _build_peft_config(model):
    """Build a LoRA config matching the GRPO setup.

    For tiny stub models like ``sshleifer/tiny-gpt2`` the q_proj/k_proj/
    v_proj/o_proj target modules don't exist (GPT-2 uses a fused
    ``c_attn``); fall back to ``all-linear`` so the smoke test still
    exercises the LoRA path.
    """
    from peft import LoraConfig

    target_modules: Any = LORA_TARGET_MODULES
    available = {n for n, _ in model.named_modules()}
    has_target = any(t in mod_name for t in LORA_TARGET_MODULES for mod_name in available)
    if not has_target:
        logger.warning(
            "model %s exposes none of %s; falling back to target_modules='all-linear'",
            type(model).__name__, LORA_TARGET_MODULES,
        )
        target_modules = "all-linear"
    return LoraConfig(
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        target_modules=target_modules,
        task_type="CAUSAL_LM",
        bias="none",
    )
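
# With r=16 and alpha=32 the adapter update is scaled by alpha/r = 2.0;
# keeping these identical to the GRPO run (see the LORA_* constants above)
# means the warm-start and RL phases train adapters of the same shape and
# scale.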


# ── Main ─────────────────────────────────────────────────────────────────


def main() -> None:  # pragma: no cover - heavy ML path
    parser = argparse.ArgumentParser(description="SFT warm-start for CERNenv GRPO")
    parser.add_argument("--out_dir", default="runs/sft-warmstart")
    parser.add_argument("--num_episodes", type=int, default=200)
    parser.add_argument("--max_steps", type=int, default=8)
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--base_model", default="HuggingFaceTB/SmolLM2-360M-Instruct")
    parser.add_argument(
        "--difficulty",
        default="mixed",
        choices=["easy", "medium", "hard", "mixed"],
        help="'mixed' cycles easy/medium/hard across episodes.",
    )
    parser.add_argument("--evidence_dir", default="evidence")
    parser.add_argument("--per_device_batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--max_seq_length", type=int, default=1280)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--no_lora",
        action="store_true",
        help="Train the full model without LoRA (used by some smoke tests).",
    )
    args = parser.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    evidence_dir = Path(args.evidence_dir)
    evidence_dir.mkdir(parents=True, exist_ok=True)

    t_start = time.time()

    try:
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from trl import SFTConfig, SFTTrainer
    except ImportError as exc:  # pragma: no cover
        raise SystemExit(
            "Heavy ML deps missing; install -r space/training/requirements.txt"
        ) from exc

    logger.info("Loading tokenizer + base model: %s", args.base_model)
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    use_bf16 = torch.cuda.is_available()
    dtype = torch.bfloat16 if use_bf16 else torch.float32
    model = AutoModelForCausalLM.from_pretrained(args.base_model, torch_dtype=dtype)
    if model.config.pad_token_id is None:
        model.config.pad_token_id = tokenizer.pad_token_id
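    # Some causal-LM tokenizers ship without a pad token; reusing eos for
    # padding and syncing model.config.pad_token_id is the standard
    # workaround and keeps the trainer's data collator happy.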

    logger.info(
        "Collecting %d oracle trajectories (max_steps=%d, difficulty=%s)",
        args.num_episodes, args.max_steps, args.difficulty,
    )
    pairs, stats = _collect_oracle_trajectories(
        tokenizer=tokenizer,
        num_episodes=args.num_episodes,
        max_steps=args.max_steps,
        difficulty=args.difficulty,
    )
    logger.info(
        "oracle stats: success_rate=%.2f total_transitions=%d kept=%d",
        stats["oracle_success_rate"],
        stats["num_transitions_total"],
        len(pairs),
    )
    if not pairs:
        raise SystemExit(
            "No transitions collected; check OracleAgent / env wiring."
        )

    dataset = _build_dataset(pairs, tokenizer)
    logger.info("Built SFT dataset: %d rows", len(dataset))

    peft_config = None if args.no_lora else _build_peft_config(model)

    sft_cfg = SFTConfig(
        output_dir=str(out_dir),
        per_device_train_batch_size=args.per_device_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        num_train_epochs=float(args.epochs),
        learning_rate=args.lr,
        logging_steps=5,
        # We save manually at the end so intermediate checkpoints don't
        # double the disk footprint.
        save_strategy="no",
        bf16=use_bf16,
        fp16=False,
        seed=args.seed,
        report_to=[],
        dataset_text_field="text",
        max_seq_length=args.max_seq_length,
        packing=False,
    )
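
    # packing=False keeps one (prompt, completion, eos) transition per
    # training row, so the eos appended in _build_dataset terminates each
    # example instead of being packed mid-sequence with the next one.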

    trainer = SFTTrainer(
        model=model,
        args=sft_cfg,
        train_dataset=dataset,
        processing_class=tokenizer,
        peft_config=peft_config,
    )
    logger.info("Starting SFT training (epochs=%d, lr=%.1e)", args.epochs, args.lr)
    train_result = trainer.train()
    final_loss = float(train_result.training_loss)
    logger.info("SFT done; final training_loss=%.4f", final_loss)

    # If we're using LoRA, merge the adapters back into the base model
    # and save a *full* causal-LM checkpoint to ``out_dir``. GRPO
    # downstream just calls ``AutoModelForCausalLM.from_pretrained(out_dir)``,
    # which is much simpler than asking it to load a base model + adapters
    # separately and keeps the warm-start path one ``--base_model`` flag.
    if peft_config is not None:
        try:
            merged = trainer.model.merge_and_unload()
            merged.save_pretrained(str(out_dir))
            logger.info("Merged LoRA adapters into base model and saved to %s", out_dir)
        except Exception as exc:
            logger.warning(
                "merge_and_unload failed (%s); saving adapters alongside the "
                "base model and pointing GRPO at this directory will still "
                "work via PEFT.",
                exc,
            )
            trainer.save_model(str(out_dir))
    else:
        trainer.save_model(str(out_dir))
    tokenizer.save_pretrained(str(out_dir))
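    # Either way, out_dir now holds tokenizer + weights, so the GRPO run can
    # consume the warm start with a single --base_model runs/sft-warmstart.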

    duration_s = round(time.time() - t_start, 2)
    summary = dict(stats)
    summary.update({
        "final_loss": round(final_loss, 6),
        "duration_s": duration_s,
        "epochs": args.epochs,
        "learning_rate": args.lr,
        "base_model": args.base_model,
        "out_dir": str(out_dir),
        "lora": peft_config is not None,
        "num_train_rows": len(dataset),
    })
    summary_path = evidence_dir / "sft_summary.json"
    summary_path.write_text(json.dumps(summary, indent=2))
    logger.info(
        "Wrote SFT summary to %s (final_loss=%.4f, duration=%.1fs)",
        summary_path, final_loss, duration_s,
    )


if __name__ == "__main__":  # pragma: no cover
    main()