akhiilll committed · Commit eed849b · Parent(s): 43372d5

add HF Jobs GRPO training script

Files changed (1):
  training/train_grpo_hf_job.py  (+464 lines, new file)

training/train_grpo_hf_job.py:
"""GRPO fine-tune Qwen2.5-1.5B on the ClaimSense gym - designed for HF Jobs.

Drop-in replacement for the notebook's training loop, but configured to run
inside a `huggingface_hub.HfApi.run_uv_job` invocation on `a10g-largex4`
hardware. We:

1. ``git clone`` the ClaimSense Space so the gym runs in-process
   (deterministic per ``scenario_index``).
2. Load ``Qwen/Qwen2.5-1.5B-Instruct`` in bf16 on cuda:0 (no Unsloth -- the
   default ``uv`` image lacks the CUDA dev libs Unsloth's kernels need; on
   A10G we have enough memory to run without it).
3. Wrap with PEFT LoRA (r=16, alpha=32, target_modules=q/k/v/o/gate/up/down).
4. Build the prompt dataset and reward functions (format + env-replay).
5. Run ``trl.GRPOTrainer.train()`` for ``NUM_GRPO_STEPS`` steps.
6. Plot reward / KL / completion-length curves to ``grpo_training.png``.
7. Upload artifacts to ``runs/grpo-<timestamp>/`` on the Space repo so they
   show up in the README plots.

Configuration (all env vars):
* ``HF_TOKEN``           - mandatory, used for hub access + artifact upload
* ``MODEL_ID``           - default ``Qwen/Qwen2.5-1.5B-Instruct``
* ``NUM_GRPO_STEPS``     - default ``80``
* ``NUM_GENERATIONS``    - default ``4``
* ``BATCH_SIZE``         - default ``2`` (per-device)
* ``GRAD_ACCUM``         - default ``2``
* ``LEARNING_RATE``      - default ``5e-6``
* ``CASE_REPEATS``       - default ``8`` (each of 8 cases x N -> dataset rows)
* ``ARTIFACT_REPO``      - default ``akhiilll/claims-env``
* ``ARTIFACT_REPO_TYPE`` - default ``space``
* ``CLAIMS_ENV_REPO``    - default ``akhiilll/claims-env`` (gym source)
"""
from __future__ import annotations

import datetime
import json
import os
import re
import statistics
import subprocess
import sys
import traceback
from pathlib import Path

# ---------------------------------------------------------------------------
# 1. Clone the gym repo so we can import AdjudicationGym in-process
# ---------------------------------------------------------------------------

CLAIMS_ENV_REPO = os.environ.get("CLAIMS_ENV_REPO", "akhiilll/claims-env")
CLONE_DIR = Path("/tmp/claims-env-repo")

if not CLONE_DIR.exists():
    print(f"[setup] cloning https://huggingface.co/spaces/{CLAIMS_ENV_REPO} -> {CLONE_DIR}")
    subprocess.check_call(
        [
            "git",
            "clone",
            "--depth",
            "1",
            f"https://huggingface.co/spaces/{CLAIMS_ENV_REPO}",
            str(CLONE_DIR),
        ]
    )

sys.path.insert(0, str(CLONE_DIR))
sys.path.insert(0, str(CLONE_DIR / "server"))

from server.claims_environment import ACTION_VOCABULARY, AdjudicationGym  # type: ignore  # noqa: E402
from server.mock_systems import CASE_LIBRARY  # type: ignore  # noqa: E402
from models import AdjudicatorAction  # type: ignore  # noqa: E402

print(f"[setup] gym imported. {len(ACTION_VOCABULARY)} verbs, {len(CASE_LIBRARY)} cases.")

# ---------------------------------------------------------------------------
# 2. Heavy ML imports (after gym imports so import errors above are visible)
# ---------------------------------------------------------------------------

import matplotlib  # noqa: E402

matplotlib.use("Agg")
import matplotlib.pyplot as plt  # noqa: E402
import torch  # noqa: E402
from datasets import Dataset  # noqa: E402
from peft import LoraConfig, get_peft_model  # noqa: E402
from transformers import AutoModelForCausalLM, AutoTokenizer  # noqa: E402
from trl import GRPOConfig, GRPOTrainer  # noqa: E402

# ---------------------------------------------------------------------------
# 3. Configuration
# ---------------------------------------------------------------------------

MODEL_ID = os.environ.get("MODEL_ID", "Qwen/Qwen2.5-1.5B-Instruct")
NUM_GRPO_STEPS = int(os.environ.get("NUM_GRPO_STEPS", "80"))
NUM_GENERATIONS = int(os.environ.get("NUM_GENERATIONS", "4"))
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "2"))
GRAD_ACCUM = int(os.environ.get("GRAD_ACCUM", "2"))
LEARNING_RATE = float(os.environ.get("LEARNING_RATE", "5e-6"))
CASE_REPEATS = int(os.environ.get("CASE_REPEATS", "8"))
MAX_PROMPT_LEN = int(os.environ.get("MAX_PROMPT_LEN", "512"))
MAX_COMPLETION_LEN = int(os.environ.get("MAX_COMPLETION_LEN", "256"))

ARTIFACT_REPO = os.environ.get("ARTIFACT_REPO", "akhiilll/claims-env")
ARTIFACT_REPO_TYPE = os.environ.get("ARTIFACT_REPO_TYPE", "space")
RUN_ID = datetime.datetime.utcnow().strftime("grpo-%Y%m%d-%H%M%S")

print(f"[config] model={MODEL_ID}")
print(f"[config] steps={NUM_GRPO_STEPS} gens={NUM_GENERATIONS} bsz={BATCH_SIZE} grad_accum={GRAD_ACCUM}")
print(f"[config] lr={LEARNING_RATE} run_id={RUN_ID}")
print(f"[config] cuda available: {torch.cuda.is_available()} | n_gpus: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        print(f"  gpu[{i}]: {torch.cuda.get_device_name(i)}")

OUT_DIR = Path("/tmp/grpo-claims")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# ---------------------------------------------------------------------------
# 4. Tokenizer + model + LoRA
# ---------------------------------------------------------------------------

token = os.environ.get("HF_TOKEN")
print(f"[setup] loading tokenizer {MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=token)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"[setup] loading model {MODEL_ID} in bfloat16")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=token,
    dtype=torch.bfloat16,
    device_map={"": 0},  # single device; GRPOTrainer handles rollouts
    attn_implementation="eager",  # safest across versions
)
model.config.pad_token_id = tokenizer.pad_token_id
model.gradient_checkpointing_enable()

print("[setup] applying LoRA r=16, alpha=32")
lora = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.0,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
)
model = get_peft_model(model, lora)
model.print_trainable_parameters()

# ---------------------------------------------------------------------------
# 5. Build prompt dataset
# ---------------------------------------------------------------------------

SYSTEM_PROMPT = (
    "You are an expert insurance claims adjuster.\n\n"
    "Available actions (one per line, lowercase, in this order of execution):\n"
    "  query_policy\n"
    "  query_claim_history\n"
    "  check_fraud\n"
    "  request_documents\n"
    "  verify_coverage\n"
    "  verify_purchase\n"
    "  calculate_payout\n"
    "  approve <amount>   (terminal)\n"
    "  deny <reason>      (terminal)\n"
    "  escalate <reason>  (terminal)\n\n"
    "Information actions cost a small fee; correct terminal verdicts pay big.\n"
    "Catching fraud via deny pays even more. Output up to 6 actions, one per\n"
    "line, ending with a terminal action. Do not write anything else."
)
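# Under this contract a well-formed completion looks like the following
# (illustrative; it is also the trace the sanity check below replays):
#
#   query_policy
#   verify_coverage
#   approve 3500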

def claim_to_user_msg(scenario_index: int) -> str:
    env = AdjudicationGym(scenario_index=scenario_index)
    obs = env.reset()
    return (
        f"New claim arrived:\n"
        f"  claim_id     : {obs.claim_id}\n"
        f"  type         : {obs.claim_type}\n"
        f"  amount       : ${obs.claim_amount_requested:,.2f}\n"
        f"  claimant     : {obs.claimant_name}\n"
        f"  incident_date: {obs.incident_date}\n"
        f"  description  : {obs.description}\n\n"
        f"What is your action plan?"
    )


def make_prompt(scenario_index: int) -> str:
    msgs = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": claim_to_user_msg(scenario_index)},
    ]
    return tokenizer.apply_chat_template(
        msgs, tokenize=False, add_generation_prompt=True
    )


print(f"[setup] building dataset (case_repeats={CASE_REPEATS})")
rows = []
for _ in range(CASE_REPEATS):
    for sidx in range(len(CASE_LIBRARY)):
        rows.append({"prompt": make_prompt(sidx), "scenario_index": sidx})
train_ds = Dataset.from_list(rows).shuffle(seed=42)
print(f"[setup] dataset rows: {len(train_ds)}")
# ---------------------------------------------------------------------------
# 6. Reward functions
# ---------------------------------------------------------------------------

ACTIONS_SET = set(ACTION_VOCABULARY)
TERMINALS = {"approve", "deny", "escalate"}


def _coerce(c) -> str:
    """Normalise a completion to plain text (TRL may hand back chat dicts)."""
    if isinstance(c, list):
        if not c:
            return ""
        return c[0].get("content", "") if isinstance(c[0], dict) else str(c[0])
    return str(c)


def parse_actions(completion: str) -> list[AdjudicatorAction]:
    actions: list[AdjudicatorAction] = []
    for raw in completion.strip().splitlines():
        # Drop list markers ("-", "*", "1.") and normalise case.
        line = raw.strip().lstrip("-*0123456789. ").lower().strip()
        if not line:
            continue
        parts = line.split(maxsplit=1)
        verb = parts[0]
        if verb not in ACTIONS_SET:
            continue
        params: dict = {}
        rest = parts[1] if len(parts) > 1 else ""
        if verb == "approve":
            m = re.search(r"\d[\d,\.]*", rest)
            if m:
                try:
                    params["amount"] = float(m.group().replace(",", ""))
                except ValueError:
                    pass
        elif verb == "deny":
            params["reason"] = (rest or "policy_violation")[:80]
        elif verb == "escalate":
            params["reason"] = (rest or "manager_review")[:80]
        actions.append(AdjudicatorAction(action_type=verb, parameters=params))
        if verb in TERMINALS:
            break
    return actions
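# Worked example (hypothetical completion text):
#   parse_actions("1. query_policy\n2. approve $3,500\n3. deny late")
#   -> [AdjudicatorAction(action_type="query_policy", parameters={}),
#       AdjudicatorAction(action_type="approve", parameters={"amount": 3500.0})]
# Parsing stops at the first terminal verb, so the trailing "deny" is dropped.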

def replay(actions, sidx, max_steps=8):
    """Re-run a parsed action plan in a fresh gym and return the total reward."""
    env = AdjudicationGym(scenario_index=int(sidx))
    env.reset()
    total = 0.0
    for act in actions[:max_steps]:
        obs = env.step(act)
        total += float(obs.reward)
        if obs.done:
            break
    return total


def format_reward_fn(prompts, completions, **_):
    """+0.5 if the plan ends in a terminal verb, -0.25 if not, -1.0 if empty."""
    rewards = []
    for c in completions:
        actions = parse_actions(_coerce(c))
        if not actions:
            rewards.append(-1.0)
            continue
        rewards.append(0.5 if actions[-1].action_type in TERMINALS else -0.25)
    return rewards


def env_reward_fn(prompts, completions, scenario_index, **_):
    """Replay each completion against its own deterministic scenario."""
    return [
        replay(parse_actions(_coerce(c)), s)
        for c, s in zip(completions, scenario_index)
    ]
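# GRPOTrainer sums the per-completion scores from the two reward functions
# above (equal weights unless `reward_weights` is set in GRPOConfig), so the
# total reward is format shaping plus the replayed environment return.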

# Sanity check (so a broken reward fn fails fast, before the trainer starts)
sane_text = "query_policy\nverify_coverage\napprove 3500"
sane_r = replay(parse_actions(sane_text), 0)
print(f"[sanity] optimal trace on case 0 -> env reward {sane_r:+.2f}")
assert sane_r > 0, f"reward fn broken (expected >0 on case 0, got {sane_r})"

# ---------------------------------------------------------------------------
# 7. GRPO training
# ---------------------------------------------------------------------------

training_args = GRPOConfig(
    output_dir=str(OUT_DIR),
    learning_rate=LEARNING_RATE,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="adamw_torch",
    logging_steps=1,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    num_generations=NUM_GENERATIONS,
    max_prompt_length=MAX_PROMPT_LEN,
    max_completion_length=MAX_COMPLETION_LEN,
    max_steps=NUM_GRPO_STEPS,
    save_steps=999_999,  # effectively disable mid-run checkpoints
    report_to="none",
    bf16=True,
    temperature=0.9,
    top_p=0.95,
    epsilon=0.2,
    beta=0.04,
)
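# Note: `epsilon` is the PPO-style clipping range on the importance ratio and
# `beta` weights the KL penalty against the frozen reference policy; the
# values above mirror TRL's documented defaults (beta's default has changed
# across TRL releases) rather than anything tuned for this task.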

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[format_reward_fn, env_reward_fn],
    args=training_args,
    train_dataset=train_ds,
)

print("[train] launching GRPOTrainer.train()")
try:
    trainer.train()
    print("[train] done")
except Exception:
    traceback.print_exc()
    raise
# ---------------------------------------------------------------------------
# 8. Plot training curves
# ---------------------------------------------------------------------------

log = trainer.state.log_history


def series(key: str):
    xs, ys = [], []
    for entry in log:
        if key in entry and "step" in entry:
            xs.append(entry["step"])
            ys.append(entry[key])
    return xs, ys


fig, axes = plt.subplots(2, 2, figsize=(13, 8))

xs, ys = series("reward")
axes[0, 0].plot(xs, ys, color="#1f77b4")
axes[0, 0].set_title("mean group reward")
axes[0, 0].set_xlabel("training step")
axes[0, 0].set_ylabel("reward")
axes[0, 0].grid(alpha=0.3)

# Per-reward-function curves; newer TRL versions log these under ".../mean".
fmt_xs, fmt_ys = series("rewards/format_reward_fn")
env_xs, env_ys = series("rewards/env_reward_fn")
if not fmt_ys:
    fmt_xs, fmt_ys = series("rewards/format_reward_fn/mean")
    env_xs, env_ys = series("rewards/env_reward_fn/mean")
axes[0, 1].plot(fmt_xs, fmt_ys, label="format reward", color="#2ca02c")
axes[0, 1].plot(env_xs, env_ys, label="env reward", color="#d62728")
axes[0, 1].set_title("per-reward-function score")
axes[0, 1].set_xlabel("training step")
axes[0, 1].set_ylabel("reward")
axes[0, 1].legend()
axes[0, 1].grid(alpha=0.3)

xs, ys = series("kl")
axes[1, 0].plot(xs, ys, color="#9467bd")
axes[1, 0].set_title("KL(model || reference)")
axes[1, 0].set_xlabel("training step")
axes[1, 0].set_ylabel("kl")
axes[1, 0].grid(alpha=0.3)

xs, ys = series("completion_length")
if not ys:
    xs, ys = series("completions/mean_length")
axes[1, 1].plot(xs, ys, color="#ff7f0e")
axes[1, 1].set_title("mean completion length (tokens)")
axes[1, 1].set_xlabel("training step")
axes[1, 1].set_ylabel("tokens")
axes[1, 1].grid(alpha=0.3)

fig.tight_layout()
png_path = OUT_DIR / "grpo_training.png"
fig.savefig(png_path, dpi=120)
print(f"[plot] saved {png_path}")

log_path = OUT_DIR / "training_log.json"
with log_path.open("w") as fh:
    json.dump(log, fh, indent=2, default=str)
print(f"[plot] saved {log_path}")

summary = {
    "run_id": RUN_ID,
    "base_model": MODEL_ID,
    "trainer": "trl.GRPOTrainer",
    "num_steps": NUM_GRPO_STEPS,
    "num_generations": NUM_GENERATIONS,
    "batch_size": BATCH_SIZE,
    "grad_accum": GRAD_ACCUM,
    "learning_rate": LEARNING_RATE,
    "case_repeats": CASE_REPEATS,
    "dataset_rows": len(train_ds),
    "reward_functions": ["format_reward_fn", "env_reward_fn"],
    "env": "ClaimSense (https://huggingface.co/spaces/akhiilll/claims-env)",
}
xs2, ys2 = series("reward")
if ys2:
    summary["first_reward"] = ys2[0]
    summary["last_reward"] = ys2[-1]
    summary["best_reward"] = max(ys2)
    summary["worst_reward"] = min(ys2)
    summary["mean_reward"] = statistics.mean(ys2)

summary_path = OUT_DIR / "run_summary.json"
with summary_path.open("w") as fh:
    json.dump(summary, fh, indent=2)
print(json.dumps(summary, indent=2))

# ---------------------------------------------------------------------------
# 9. Save adapter + upload artifacts back to the Space
# ---------------------------------------------------------------------------

adapter_dir = OUT_DIR / "lora-adapter"
trainer.model.save_pretrained(str(adapter_dir))
tokenizer.save_pretrained(str(adapter_dir))
print(f"[save] LoRA adapter -> {adapter_dir}")

try:
    from huggingface_hub import HfApi

    api = HfApi(token=token)
    target_dir = f"runs/{RUN_ID}"
    uploads = [
        (png_path, f"{target_dir}/grpo_training.png"),
        (log_path, f"{target_dir}/training_log.json"),
        (summary_path, f"{target_dir}/run_summary.json"),
    ]
    for src, dst in uploads:
        api.upload_file(
            path_or_fileobj=str(src),
            path_in_repo=dst,
            repo_id=ARTIFACT_REPO,
            repo_type=ARTIFACT_REPO_TYPE,
            commit_message=f"GRPO run: {RUN_ID}",
        )
        print(f"[upload] {dst}")

    api.upload_folder(
        folder_path=str(adapter_dir),
        path_in_repo=f"{target_dir}/lora-adapter",
        repo_id=ARTIFACT_REPO,
        repo_type=ARTIFACT_REPO_TYPE,
        commit_message=f"GRPO LoRA adapter: {RUN_ID}",
    )
    print(f"[upload] {target_dir}/lora-adapter (folder)")
    print(f"[done] artifacts at https://huggingface.co/spaces/{ARTIFACT_REPO}/tree/main/{target_dir}")
except Exception as exc:
    print(f"[upload] skipped: {type(exc).__name__}: {exc}")
    traceback.print_exc()