av-codes committed on
Commit
ba04dd9
·
verified ·
1 Parent(s): 5b3ce9d

v2 training script: paper-aligned recipe (bf16, lr=2.2e-4, wd=0.1, constant LR, beta2=0.95)

Browse files
Files changed (1) hide show
  1. train_hrm_text_pi.py +19 -14
train_hrm_text_pi.py CHANGED
@@ -483,7 +483,7 @@ class HrmTextClassifier(nn.Module):
483
 
484
  loss = None
485
  if labels is not None:
486
- loss = F.cross_entropy(logits, labels)
487
 
488
  return {"logits": logits, "loss": loss}
489
 
@@ -754,7 +754,7 @@ def count_params(model):
754
  def main():
755
  parser = argparse.ArgumentParser(description="Train HRM-Text prompt injection detector")
756
  parser.add_argument("--test", action="store_true", help="Smoke test on 64 samples")
757
- parser.add_argument("--lr", type=float, default=5e-4)
758
  parser.add_argument("--epochs", type=int, default=3)
759
  parser.add_argument("--batch_size", type=int, default=32)
760
  parser.add_argument("--output_dir", type=str, default="./pi-hrm-text")
@@ -781,6 +781,8 @@ def main():
781
  help="Cache dir for dataset download")
782
  parser.add_argument("--max_steps", type=int, default=-1,
783
  help="Max training steps (-1 = use epochs)")
 
 
784
  args = parser.parse_args()
785
 
786
  set_seed(args.seed)
@@ -844,7 +846,7 @@ def main():
844
  eval_dataset = merged.select(range(min(32, len(merged))))
845
  else:
846
  split = merged.train_test_split(
847
- test_size=0.1, seed=args.seed, stratify_by_column="label",
848
  )
849
  train_dataset = split["train"]
850
  eval_dataset = split["test"]
@@ -925,13 +927,13 @@ def main():
925
  per_device_eval_batch_size=min(args.batch_size * 2, 16),
926
  num_train_epochs=args.epochs,
927
  max_steps=args.max_steps,
928
- weight_decay=0.01,
929
- warmup_steps=500 if not args.test else 0,
930
- lr_scheduler_type="cosine",
931
  eval_strategy="steps",
932
- eval_steps=4000,
933
  save_strategy="steps",
934
- save_steps=4000,
935
  load_best_model_at_end=True,
936
  metric_for_best_model="f1",
937
  greater_is_better=True,
@@ -940,14 +942,15 @@ def main():
940
  logging_first_step=True,
941
  logging_steps=5 if args.test else 20,
942
  disable_tqdm=False if args.test else True,
943
- fp16=use_cuda,
944
- bf16=False,
945
  push_to_hub=True,
946
  hub_model_id=args.push_to_hub,
947
  hub_strategy="every_save",
948
  use_cpu=not use_cuda,
949
  dataloader_num_workers=4,
950
  seed=args.seed,
 
951
  save_only_model=True,
952
  remove_unused_columns=False,
953
  ddp_find_unused_parameters=True,
@@ -972,7 +975,7 @@ def main():
972
  # ── Train ─────────────────────────────────────────────────────────────
973
  print("\n🚀 Training...")
974
  train_start = time.time()
975
- trainer.train()
976
  train_elapsed = time.time() - train_start
977
  print(f"✅ Training complete! ({train_elapsed:.1f}s)")
978
  print(f" Best checkpoint: {trainer.state.best_model_checkpoint}")
@@ -1035,9 +1038,11 @@ def main():
1035
  "learning_rate": args.lr,
1036
  "epochs": args.epochs,
1037
  "batch_size": args.batch_size,
1038
- "weight_decay": 0.01,
1039
- "scheduler": "cosine",
1040
- "warmup_steps": 500 if not args.test else 0,
 
 
1041
  },
1042
  }
1043
  with open(os.path.join(best_model_path, "config.json"), "w") as f:
 
483
 
484
  loss = None
485
  if labels is not None:
486
+ loss = F.cross_entropy(logits.float(), labels)
487
 
488
  return {"logits": logits, "loss": loss}
489
 
 
754
  def main():
755
  parser = argparse.ArgumentParser(description="Train HRM-Text prompt injection detector")
756
  parser.add_argument("--test", action="store_true", help="Smoke test on 64 samples")
757
+ parser.add_argument("--lr", type=float, default=2.2e-4)
758
  parser.add_argument("--epochs", type=int, default=3)
759
  parser.add_argument("--batch_size", type=int, default=32)
760
  parser.add_argument("--output_dir", type=str, default="./pi-hrm-text")
 
781
  help="Cache dir for dataset download")
782
  parser.add_argument("--max_steps", type=int, default=-1,
783
  help="Max training steps (-1 = use epochs)")
784
+ parser.add_argument("--resume_from_checkpoint", type=str, default=None,
785
+ help="Path to checkpoint dir to resume from")
786
  args = parser.parse_args()
787
 
788
  set_seed(args.seed)
 
846
  eval_dataset = merged.select(range(min(32, len(merged))))
847
  else:
848
  split = merged.train_test_split(
849
+ test_size=0.05, seed=args.seed, stratify_by_column="label",
850
  )
851
  train_dataset = split["train"]
852
  eval_dataset = split["test"]
 
927
  per_device_eval_batch_size=min(args.batch_size * 2, 16),
928
  num_train_epochs=args.epochs,
929
  max_steps=args.max_steps,
930
+ weight_decay=0.1,
931
+ warmup_steps=2000 if not args.test else 0,
932
+ lr_scheduler_type="constant_with_warmup",
933
  eval_strategy="steps",
934
+ eval_steps=1000,
935
  save_strategy="steps",
936
+ save_steps=1000,
937
  load_best_model_at_end=True,
938
  metric_for_best_model="f1",
939
  greater_is_better=True,
 
942
  logging_first_step=True,
943
  logging_steps=5 if args.test else 20,
944
  disable_tqdm=False if args.test else True,
945
+ fp16=False,
946
+ bf16=use_cuda,
947
  push_to_hub=True,
948
  hub_model_id=args.push_to_hub,
949
  hub_strategy="every_save",
950
  use_cpu=not use_cuda,
951
  dataloader_num_workers=4,
952
  seed=args.seed,
953
+ adam_beta2=0.95,
954
  save_only_model=True,
955
  remove_unused_columns=False,
956
  ddp_find_unused_parameters=True,
 
975
  # ── Train ─────────────────────────────────────────────────────────────
976
  print("\n🚀 Training...")
977
  train_start = time.time()
978
+ trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
979
  train_elapsed = time.time() - train_start
980
  print(f"✅ Training complete! ({train_elapsed:.1f}s)")
981
  print(f" Best checkpoint: {trainer.state.best_model_checkpoint}")
 
1038
  "learning_rate": args.lr,
1039
  "epochs": args.epochs,
1040
  "batch_size": args.batch_size,
1041
+ "weight_decay": 0.1,
1042
+ "scheduler": "constant_with_warmup",
1043
+ "warmup_steps": 2000 if not args.test else 0,
1044
+ "adam_beta2": 0.95,
1045
+ "precision": "bf16",
1046
  },
1047
  }
1048
  with open(os.path.join(best_model_path, "config.json"), "w") as f: