muskan singh and Claude Opus 4.7 committed
Commit 2ab0fe0 · 1 Parent(s): 7a0b2ce

fix: pin trl<=0.24, multi-step reward, lower LR, reduce NUM_GEN

- requirements.txt: pin trl>=0.18.2,<=0.24.0 (trl 1.x breaks Unsloth patches → silent crash at step 21)
- train.py: multi-step reward fn (REWARD_STEPS=2): return cumulative episode score, not single-step reward
- train.py: NUM_GEN 4→2 to halve VRAM pressure from G×reward_steps inference calls (see the back-of-envelope sketch below)
- train.py: LR 5e-5→8e-6 (5e-5 was unstable, caused reward oscillation)
- train.py: switch to max_steps=150 training (more reliable than epoch-based)
- train.py: model.config.max_length=None to silence max_new_tokens warning
- train.py: reward_funcs=[orgos_reward_fn] as list (required by TRL)
- train.py: BATCH_SIZE 4→1 with GRAD_ACCUM=2 (matches memory budget with multi-step reward)
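
Back-of-envelope for the NUM_GEN cut (a minimal sketch; constant names mirror train.py, and the count assumes the worst case where every candidate parses and no episode ends early):

    # Each GRPO step generates NUM_GEN candidates per prompt; the reward fn then
    # runs up to REWARD_STEPS - 1 extra greedy generations per candidate, so
    # generate() calls per prompt scale as NUM_GEN * REWARD_STEPS.
    REWARD_STEPS = 2
    for num_gen in (4, 2):  # before -> after this commit
        print(f"NUM_GEN={num_gen}: {num_gen * REWARD_STEPS} generate() calls per prompt")
    # prints 8 then 4: inference pressure per prompt is halved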

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

Files changed (2)
  1. requirements.txt +1 -1
  2. train.py +101 -51
requirements.txt CHANGED
@@ -11,7 +11,7 @@ aiofiles>=23.0.0
 torch
 transformers
 datasets
-trl
+trl>=0.18.2,<=0.24.0
 unsloth
 matplotlib
 numpy
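
A quick post-install check one can run (a sketch; the packaging module is an assumption here, it is not part of requirements.txt):

    import trl
    from packaging.version import Version
    # Fail fast if the resolver ignored the pin, before a long training run starts
    assert Version("0.18.2") <= Version(trl.__version__) <= Version("0.24.0"), trl.__version__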
train.py CHANGED
@@ -1,6 +1,6 @@
 """
 OrgOS GRPO Training Script
-Equivalent to training/grpo_orgos.ipynb but runs headlessly.
+Runs headlessly on HuggingFace Spaces (A100/T4 GPU).
 
 Outputs:
     training_log.txt — structured training log for submission
@@ -35,20 +35,22 @@ from unsloth import FastLanguageModel
 # Config
 # ------------------------------------------------------------------
 
-MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-3B-Instruct")
+MODEL_NAME = os.environ.get("MODEL_NAME", "unsloth/Qwen2.5-3B-Instruct-bnb-4bit")
 ENV_URL = "http://localhost:8000"
 LOG_FILE = "training_log.txt"
 N_PROMPTS_PER_WORKFLOW = 20
 N_EVAL = 10
-NUM_EPOCHS = 3
-BATCH_SIZE = 4
+MAX_TRAIN_STEPS = 150  # step-based training (more reliable than epoch-based on Spaces)
+BATCH_SIZE = 1
 GRAD_ACCUM = 2
-LR = 5e-5
-NUM_GEN = 4
-TEMPERATURE = 0.8
+LR = 8e-6  # stable LR — 5e-5 was too high
+NUM_GEN = 2  # candidates per prompt — keep low to save VRAM
+TEMPERATURE = 0.9
 BETA = 0.04
 LORA_R = 16
 MAX_SEQ_LEN = 4096
+MAX_COMPLETION_LENGTH = 256
+REWARD_STEPS = 2  # multi-step rollout depth in reward fn
 
 # ------------------------------------------------------------------
 # Logger
@@ -77,7 +79,6 @@ def start_env_server():
         stdout=None,
         stderr=None,
     )
-    # Wait until healthy
    for _ in range(20):
         time.sleep(2)
         try:
@@ -112,6 +113,9 @@ def load_model():
         use_gradient_checkpointing = "unsloth",
         random_state = 42,
     )
+    # Clear max_length to avoid max_new_tokens vs max_length warnings during generate()
+    model.config.max_length = None
+
     trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
     tlog(f"[TRAIN_CONFIG] model={MODEL_NAME} lora_r={LORA_R} "
          f"max_seq_len={MAX_SEQ_LEN} trainable_params={trainable:,} quantization=4bit")
@@ -160,6 +164,12 @@ CRITICAL RULES:
 6. Stop when pending_steps is empty or done=true.
 """
 
+WORKFLOW_APPS = {
+    "A": {"jira", "zendesk", "salesforce", "workday"},
+    "B": {"zendesk", "salesforce", "workday"},
+    "C": {"jira", "zendesk", "salesforce"},
+}
+
 
 def obs_to_text(obs: dict) -> str:
     hints = obs.get("schema_hints", {})
@@ -186,25 +196,15 @@ def obs_to_text(obs: dict) -> str:
         "",
         "=== APP STATES ===",
     ]
-    # workflow-relevant apps only — skip apps the workflow doesn't touch
-    WORKFLOW_APPS = {
-        "A": {"jira", "zendesk", "salesforce", "workday"},
-        "B": {"zendesk", "salesforce", "workday"},
-        "C": {"jira", "zendesk", "salesforce"},
-    }
-    relevant = WORKFLOW_APPS.get(
-        obs.get("workflow_id", "A"),
-        {"jira", "zendesk", "salesforce", "workday"},
-    )
+    relevant = WORKFLOW_APPS.get(obs.get("workflow_id", "A"),
+                                 {"jira", "zendesk", "salesforce", "workday"})
     for app_name, view in obs.get("app_states", {}).items():
         if app_name not in relevant:
             continue
-        lines.append(f"  [{app_name.upper()}]")
         view_str = str(view)
         if len(view_str) > 600:
             view_str = view_str[:600] + "...[truncated]"
-        lines.append(f"  {view_str}")
-        lines.append("")
+        lines += [f"  [{app_name.upper()}]", f"  {view_str}", ""]
     return "\n".join(lines)
 
 
@@ -244,36 +244,83 @@ def build_prompt_dataset(tokenizer) -> Dataset:
         rows.append({
             "prompt": build_prompt(obs_text, tokenizer),
             "workflow_id": wf,
-            "obs_text": obs_text,
         })
     tlog(f"[TRAIN_CONFIG] algorithm=GRPO prompts={len(rows)} "
          f"workflows=A,B,C prompts_per_workflow={N_PROMPTS_PER_WORKFLOW}")
+    sample_tokens = None  # set below after tokenizer is available
     return Dataset.from_list(rows)
 
 
 # ------------------------------------------------------------------
-# Reward function
+# Reward function — multi-step live environment rollout
 # ------------------------------------------------------------------
-
-def orgos_reward_fn(completions: List[str], prompts: List[str], **kwargs) -> List[float]:
+# The model reference is set in main() before training starts.
+_reward_model = None
+_reward_tokenizer = None
+
+
+def orgos_reward_fn(completions: List[str], prompts: List[str] = None, **kwargs) -> List[float]:
+    """
+    For each GRPO candidate:
+      1. Parse as JSON action.
+      2. Reset env and apply the action (step 1).
+      3. Continue REWARD_STEPS-1 more greedy steps with the current model.
+      4. Return cumulative episode score — not just single-step reward.
+
+    Multi-step signal prevents the model from collapsing to always outputting
+    list_tickets (which gives a small single-step reward but never advances the workflow).
+    """
     workflow_ids = kwargs.get("workflow_id", ["A"] * len(completions))
     rewards = []
+
     for completion, wf_id in zip(completions, workflow_ids):
         action = parse_action(completion)
         if action is None:
             rewards.append(-0.1)
             continue
         try:
-            httpx.post(f"{ENV_URL}/reset", json={"workflow_id": wf_id}, timeout=10)
+            # Reset env and apply the GRPO-generated action (step 1)
+            obs = httpx.post(f"{ENV_URL}/reset",
+                             json={"workflow_id": wf_id}, timeout=10).json()["observation"]
             result = httpx.post(f"{ENV_URL}/step", json=action, timeout=10).json()
-            rewards.append(float(result["reward"]))
+            obs = result["observation"]
+
+            # Continue REWARD_STEPS-1 more steps with current model (greedy)
+            if _reward_model is not None:
+                for _ in range(REWARD_STEPS - 1):
+                    if obs.get("done"):
+                        break
+                    prompt_text = build_prompt(obs_to_text(obs), _reward_tokenizer)
+                    inputs = _reward_tokenizer(
+                        prompt_text, return_tensors="pt"
+                    ).to(_reward_model.device)
+                    with torch.no_grad():
+                        out = _reward_model.generate(
+                            **inputs,
+                            max_new_tokens = 128,
+                            do_sample = False,
+                            pad_token_id = _reward_tokenizer.eos_token_id,
+                        )
+                    cont_str = _reward_tokenizer.decode(
+                        out[0][inputs["input_ids"].shape[1]:],
+                        skip_special_tokens=True,
+                    ).strip()
+                    cont_action = parse_action(cont_str)
+                    if cont_action is None:
+                        break
+                    result = httpx.post(f"{ENV_URL}/step",
+                                        json=cont_action, timeout=10).json()
+                    obs = result["observation"]
+
+            rewards.append(float(obs.get("current_score", 0.001)))
         except Exception:
             rewards.append(-0.1)
+
     return rewards
 
 
 # ------------------------------------------------------------------
-# Episode evaluation
+# Episode evaluation (stateless — each step is a fresh single-turn prompt)
 # ------------------------------------------------------------------
 
 def run_episode_with_model(model, tokenizer, workflow_id: str, max_steps: int = 15) -> float:
@@ -284,20 +331,14 @@ def run_episode_with_model(model, tokenizer, workflow_id: str, max_steps: int =
         if obs["done"]:
             break
 
-        # Stateless single-turn prompt — matches the GRPO training format.
-        # obs["message"] already carries last-action feedback, so no history needed.
         obs_text = obs_to_text(obs)
-        messages = [{"role": "user",
-                     "content": SYSTEM_PROMPT + "\n\n---\n\n" + obs_text}]
-
-        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        inputs = tokenizer(text, return_tensors="pt").to(model.device)
+        text = build_prompt(obs_text, tokenizer)
+        inputs = tokenizer(text, return_tensors="pt").to(model.device)
 
         with torch.no_grad():
             out = model.generate(
                 **inputs,
                 max_new_tokens = 256,
-                temperature = 0.0,
                 do_sample = False,
                 pad_token_id = tokenizer.eos_token_id,
             )
@@ -314,7 +355,7 @@ def run_episode_with_model(model, tokenizer, workflow_id: str, max_steps: int =
         if obs["done"]:
             break
 
-    return obs.get("current_score", 0.001)
+    return float(obs.get("current_score", 0.001))
 
 
 def evaluate(model, tokenizer, phase: str) -> dict:
@@ -325,10 +366,10 @@ def evaluate(model, tokenizer, phase: str) -> dict:
             score = run_episode_with_model(model, tokenizer, wf)
             scores[wf].append(score)
             tlog(f"[EVAL] phase={phase} workflow={wf} episode={ep+1} score={score:.4f}")
-        wf_mean = np.mean(scores[wf])
+        wf_mean = float(np.mean(scores[wf]))
         tlog(f"[EVAL_WORKFLOW] phase={phase} workflow={wf} "
             f"mean={wf_mean:.4f} min={min(scores[wf]):.4f} max={max(scores[wf]):.4f}")
-    overall = np.mean([s for v in scores.values() for s in v])
+    overall = float(np.mean([s for v in scores.values() for s in v]))
     tlog(f"[EVAL_END] phase={phase} overall_mean={overall:.4f}")
     return scores
 
@@ -426,46 +467,55 @@ class OrgOSLogCallback(TrainerCallback):
 # ------------------------------------------------------------------
 
 def main():
+    global _reward_model, _reward_tokenizer
+
     server_proc = start_env_server()
 
     try:
         model, tokenizer = load_model()
 
+        # Wire up the reward function's model reference (used for multi-step rollouts)
+        _reward_model = model
+        _reward_tokenizer = tokenizer
+
         prompt_dataset = build_prompt_dataset(tokenizer)
+        tok_len = len(tokenizer(prompt_dataset[0]["prompt"]).input_ids)
+        tlog(f"[PROMPT_DEBUG] first_prompt_tokens={tok_len}")
 
         # Sanity-check reward function
         test_r = orgos_reward_fn(
-            completions = ['{"app": "zendesk", "operation": "list_tickets", "args": {"state": "new"}}',
-                           "not json"],
-            prompts = ["", ""],
-            workflow_id = ["A", "A"],
+            completions = ['{"app": "zendesk", "operation": "list_tickets", "args": {}}',
+                           "not json"],
+            prompts = ["", ""],
+            workflow_id = ["A", "A"],
         )
         tlog(f"[REWARD_FN_CHECK] valid_action={test_r[0]:.4f} invalid_action={test_r[1]:.4f}")
 
         # Baseline evaluation
         FastLanguageModel.for_inference(model)
         baseline_scores = evaluate(model, tokenizer, phase="baseline")
-        baseline_mean = np.mean([s for v in baseline_scores.values() for s in v])
+        baseline_mean = float(np.mean([s for v in baseline_scores.values() for s in v]))
 
         # GRPO training
-        model.train()
-        tlog(f"[TRAIN_CONFIG] epochs={NUM_EPOCHS} batch_size={BATCH_SIZE} "
+        FastLanguageModel.for_training(model)
+        tlog(f"[TRAIN_CONFIG] max_steps={MAX_TRAIN_STEPS} batch_size={BATCH_SIZE} "
             f"grad_accum={GRAD_ACCUM} lr={LR} num_generations={NUM_GEN} "
-            f"temperature={TEMPERATURE} beta_kl={BETA}")
+            f"temperature={TEMPERATURE} beta_kl={BETA} reward_steps={REWARD_STEPS}")
 
         grpo_config = GRPOConfig(
             output_dir = "./orgos_grpo_ckpt",
-            num_train_epochs = NUM_EPOCHS,
+            num_train_epochs = 1,
+            max_steps = MAX_TRAIN_STEPS,
             per_device_train_batch_size = BATCH_SIZE,
             gradient_accumulation_steps = GRAD_ACCUM,
             learning_rate = LR,
             warmup_steps = 10,
             logging_steps = 5,
-            save_steps = 100,
             bf16 = torch.cuda.is_bf16_supported(),
             fp16 = not torch.cuda.is_bf16_supported(),
             max_grad_norm = 1.0,
             num_generations = NUM_GEN,
+            max_completion_length = MAX_COMPLETION_LENGTH,
             temperature = TEMPERATURE,
            beta = BETA,
             report_to = "none",
@@ -475,7 +525,7 @@ def main():
         trainer = GRPOTrainer(
             model = model,
             args = grpo_config,
-            reward_funcs = orgos_reward_fn,
+            reward_funcs = [orgos_reward_fn],
             train_dataset = prompt_dataset,
             processing_class = tokenizer,
             callbacks = [OrgOSLogCallback()],
@@ -490,7 +540,7 @@ def main():
         # Post-training evaluation
         FastLanguageModel.for_inference(model)
         post_scores = evaluate(model, tokenizer, phase="post_training")
-        post_mean = np.mean([s for v in post_scores.values() for s in v])
+        post_mean = float(np.mean([s for v in post_scores.values() for s in v]))
         improvement = post_mean - baseline_mean
 
         tlog(