aamrinder commited on
Commit
bb95b22
·
verified ·
1 Parent(s): 42cdedd

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. train/train_grpo.py +84 -66
train/train_grpo.py CHANGED
@@ -405,7 +405,18 @@ def main():
405
  processing_class=tokenizer,
406
  )
407
  trainer.train()
 
 
 
 
408
  trainer.save_model(args.output_dir)
 
 
 
 
 
 
 
409
  print(f"[done] checkpoint saved to {args.output_dir}")
410
 
411
  # ----- HELD-OUT GENERALIZATION EVAL -----
@@ -415,6 +426,7 @@ def main():
415
  print(f"\n[held-out-eval] running trained model on {min(args.n_eval_clips, len(eval_ids))} held-out clips")
416
  eval_clip_ids = sorted(eval_ids)[: args.n_eval_clips]
417
  held_out_results = []
 
418
  model.eval()
419
  if hasattr(model, "gradient_checkpointing_disable"):
420
  try: model.gradient_checkpointing_disable()
@@ -422,46 +434,50 @@ def main():
422
  n_eval_correct = 0
423
  n_eval_well_formed = 0
424
  eval_rewards = []
425
- for i, cid in enumerate(eval_clip_ids):
426
- sc = scenarios[cid]
427
- gold = "sarcastic" if sc["sarcasm"] else "sincere"
428
- messages = [
429
- {"role": "system", "content": SYSTEM_PROMPT},
430
- {"role": "user", "content": build_full_observation(cid, scenarios)},
431
- ]
432
- encoded = tokenizer.apply_chat_template(
433
- messages, return_tensors="pt", add_generation_prompt=True,
434
- )
435
- input_ids = encoded.input_ids if hasattr(encoded, "input_ids") else encoded
436
- input_ids = input_ids.to(model.device)
437
- prompt_len = input_ids.shape[1]
438
- with _t.no_grad():
439
- out = model.generate(
440
- input_ids=input_ids,
441
- max_new_tokens=args.max_completion_length,
442
- do_sample=False,
443
- pad_token_id=tokenizer.eos_token_id,
444
- use_cache=True,
445
  )
446
- text = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
447
- decomp = reward_decomposition(text, gold)
448
- held_out_results.append({
449
- "clip_id": cid,
450
- "gold": gold,
451
- "is_pivot": bool(sc.get("is_pivot")),
452
- "predicted": decomp["_predicted"],
453
- "confidence": decomp["_confidence"],
454
- "correct": decomp["_correct"],
455
- "well_formed": decomp["_well_formed"],
456
- "reward_total": decomp["_total"],
457
- "completion_text": text[:1500],
458
- })
459
- eval_rewards.append(decomp["_total"])
460
- if decomp["_correct"]: n_eval_correct += 1
461
- if decomp["_well_formed"]: n_eval_well_formed += 1
462
- if (i + 1) % 20 == 0:
463
- print(f" [{i+1}/{len(eval_clip_ids)}] running mean reward = {sum(eval_rewards)/len(eval_rewards):.3f}, "
464
- f"correct so far = {n_eval_correct}/{i+1}", flush=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
465
  eval_summary = {
466
  "n_eval_clips": len(eval_clip_ids),
467
  "mean_reward": sum(eval_rewards) / max(1, len(eval_rewards)),
@@ -498,36 +514,38 @@ def main():
498
  except Exception as e:
499
  print(f"[error] push_to_hub failed: {e}")
500
 
501
- # Push trainer_state.json (with per-step rewards/loss) to HF Space for plotting
 
 
502
  if args.save_trainer_state_to_hub_space:
503
- try:
504
- from huggingface_hub import HfApi
505
- from pathlib import Path as _P
506
- state_path = _P(args.output_dir) / "trainer_state.json"
507
- if state_path.exists():
508
- HfApi().upload_file(
509
- path_or_fileobj=str(state_path),
510
- path_in_repo="data/trainer_state_run1.json",
511
- repo_id=args.save_trainer_state_to_hub_space,
 
 
 
 
 
 
 
 
 
 
512
  repo_type="space",
513
- commit_message=f"GRPO Run #1 trainer_state ({args.max_steps} steps)",
514
  )
515
- print(f"[done] trainer_state.json pushed to {args.save_trainer_state_to_hub_space}/data/trainer_state_run1.json")
516
- else:
517
- print(f"[warn] trainer_state.json not found at {state_path}")
518
- # Also push held_out_eval.json (the proof of generalization)
519
- held_out_path = _P(args.output_dir) / "held_out_eval.json"
520
- if held_out_path.exists():
521
- HfApi().upload_file(
522
- path_or_fileobj=str(held_out_path),
523
- path_in_repo="data/held_out_eval_run1.json",
524
- repo_id=args.save_trainer_state_to_hub_space,
525
- repo_type="space",
526
- commit_message=f"GRPO Run #1 held-out eval ({args.n_eval_clips} clips)",
527
- )
528
- print(f"[done] held_out_eval.json pushed")
529
- except Exception as e:
530
- print(f"[error] save_trainer_state_to_hub_space failed: {e}")
531
 
532
 
533
  if __name__ == "__main__":
 
405
  processing_class=tokenizer,
406
  )
407
  trainer.train()
408
+ # CRITICAL: save_state() writes trainer_state.json to output_dir.
409
+ # save_model() alone only saves the adapter weights, NOT the per-step log.
410
+ # In Run #2, we missed save_state() and lost the reward history that drives the plot.
411
+ trainer.save_state()
412
  trainer.save_model(args.output_dir)
413
+ # Also explicitly write the log_history to a JSON we know we can find.
414
+ try:
415
+ log_path = Path(args.output_dir) / "log_history.json"
416
+ log_path.write_text(json.dumps(trainer.state.log_history, indent=2))
417
+ print(f"[done] log_history saved to {log_path} ({len(trainer.state.log_history)} entries)")
418
+ except Exception as e:
419
+ print(f"[warn] couldn't write log_history.json: {e}")
420
  print(f"[done] checkpoint saved to {args.output_dir}")
421
 
422
  # ----- HELD-OUT GENERALIZATION EVAL -----
 
426
  print(f"\n[held-out-eval] running trained model on {min(args.n_eval_clips, len(eval_ids))} held-out clips")
427
  eval_clip_ids = sorted(eval_ids)[: args.n_eval_clips]
428
  held_out_results = []
429
+ eval_failed = False
430
  model.eval()
431
  if hasattr(model, "gradient_checkpointing_disable"):
432
  try: model.gradient_checkpointing_disable()
 
434
  n_eval_correct = 0
435
  n_eval_well_formed = 0
436
  eval_rewards = []
437
+ try:
438
+ for i, cid in enumerate(eval_clip_ids):
439
+ sc = scenarios[cid]
440
+ gold = "sarcastic" if sc["sarcasm"] else "sincere"
441
+ messages = [
442
+ {"role": "system", "content": SYSTEM_PROMPT},
443
+ {"role": "user", "content": build_full_observation(cid, scenarios)},
444
+ ]
445
+ encoded = tokenizer.apply_chat_template(
446
+ messages, return_tensors="pt", add_generation_prompt=True,
 
 
 
 
 
 
 
 
 
 
447
  )
448
+ input_ids = encoded.input_ids if hasattr(encoded, "input_ids") else encoded
449
+ input_ids = input_ids.to(model.device)
450
+ prompt_len = input_ids.shape[1]
451
+ with _t.no_grad():
452
+ out = model.generate(
453
+ input_ids=input_ids,
454
+ max_new_tokens=args.max_completion_length,
455
+ do_sample=False,
456
+ pad_token_id=tokenizer.eos_token_id,
457
+ use_cache=True,
458
+ )
459
+ text = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
460
+ decomp = reward_decomposition(text, gold)
461
+ held_out_results.append({
462
+ "clip_id": cid,
463
+ "gold": gold,
464
+ "is_pivot": bool(sc.get("is_pivot")),
465
+ "predicted": decomp["_predicted"],
466
+ "confidence": decomp["_confidence"],
467
+ "correct": decomp["_correct"],
468
+ "well_formed": decomp["_well_formed"],
469
+ "reward_total": decomp["_total"],
470
+ "completion_text": text[:1500],
471
+ })
472
+ eval_rewards.append(decomp["_total"])
473
+ if decomp["_correct"]: n_eval_correct += 1
474
+ if decomp["_well_formed"]: n_eval_well_formed += 1
475
+ if (i + 1) % 20 == 0:
476
+ print(f" [{i+1}/{len(eval_clip_ids)}] running mean reward = {sum(eval_rewards)/len(eval_rewards):.3f}, "
477
+ f"correct so far = {n_eval_correct}/{i+1}", flush=True)
478
+ except Exception as e:
479
+ print(f"[error] held-out eval crashed at clip {i}: {e}")
480
+ eval_failed = True
481
  eval_summary = {
482
  "n_eval_clips": len(eval_clip_ids),
483
  "mean_reward": sum(eval_rewards) / max(1, len(eval_rewards)),
 
514
  except Exception as e:
515
  print(f"[error] push_to_hub failed: {e}")
516
 
517
+ # Push trainer_state.json + log_history.json + held_out_eval.json to HF Space.
518
+ # Each upload is wrapped individually so a partial network failure doesn't
519
+ # kill the whole script. We need at least the held_out_eval JSON to land.
520
  if args.save_trainer_state_to_hub_space:
521
+ from huggingface_hub import HfApi
522
+ from pathlib import Path as _P
523
+ repo_id = args.save_trainer_state_to_hub_space
524
+ run_tag = _P(args.output_dir).name # e.g. "run3"
525
+ api = HfApi()
526
+ for local_name, hub_name, label in [
527
+ ("trainer_state.json", f"data/trainer_state_{run_tag}.json", "trainer_state"),
528
+ ("log_history.json", f"data/log_history_{run_tag}.json", "log_history"),
529
+ ("held_out_eval.json", f"data/held_out_eval_{run_tag}.json", "held_out_eval"),
530
+ ]:
531
+ path = _P(args.output_dir) / local_name
532
+ if not path.exists():
533
+ print(f"[warn] {local_name} not found at {path}, skipping upload")
534
+ continue
535
+ try:
536
+ api.upload_file(
537
+ path_or_fileobj=str(path),
538
+ path_in_repo=hub_name,
539
+ repo_id=repo_id,
540
  repo_type="space",
541
+ commit_message=f"GRPO {run_tag} {label} ({args.max_steps} steps)",
542
  )
543
+ print(f"[done] {label} pushed to {repo_id}/{hub_name}")
544
+ except Exception as e:
545
+ print(f"[error] upload {label} failed: {e}")
546
+
547
+ print(f"\n[main] subtext-arena GRPO run finished cleanly.")
548
+ sys.exit(0)
 
 
 
 
 
 
 
 
 
 
549
 
550
 
551
  if __name__ == "__main__":