v2 training script: paper-aligned recipe (bf16, lr=2.2e-4, wd=0.1, constant LR with warmup, beta2=0.95)
Browse files- train_hrm_text_pi.py +19 -14
train_hrm_text_pi.py
CHANGED
|
@@ -483,7 +483,7 @@ class HrmTextClassifier(nn.Module):
|
|
| 483 |
|
| 484 |
loss = None
|
| 485 |
if labels is not None:
|
| 486 |
-
loss = F.cross_entropy(logits, labels)
|
| 487 |
|
| 488 |
return {"logits": logits, "loss": loss}
|
| 489 |
|
|
@@ -754,7 +754,7 @@ def count_params(model):
|
|
| 754 |
def main():
|
| 755 |
parser = argparse.ArgumentParser(description="Train HRM-Text prompt injection detector")
|
| 756 |
parser.add_argument("--test", action="store_true", help="Smoke test on 64 samples")
|
| 757 |
-
parser.add_argument("--lr", type=float, default=
|
| 758 |
parser.add_argument("--epochs", type=int, default=3)
|
| 759 |
parser.add_argument("--batch_size", type=int, default=32)
|
| 760 |
parser.add_argument("--output_dir", type=str, default="./pi-hrm-text")
|
|
@@ -781,6 +781,8 @@ def main():
|
|
| 781 |
help="Cache dir for dataset download")
|
| 782 |
parser.add_argument("--max_steps", type=int, default=-1,
|
| 783 |
help="Max training steps (-1 = use epochs)")
|
|
|
|
|
|
|
| 784 |
args = parser.parse_args()
|
| 785 |
|
| 786 |
set_seed(args.seed)
|
|
@@ -844,7 +846,7 @@ def main():
|
|
| 844 |
eval_dataset = merged.select(range(min(32, len(merged))))
|
| 845 |
else:
|
| 846 |
split = merged.train_test_split(
|
| 847 |
-
test_size=0.
|
| 848 |
)
|
| 849 |
train_dataset = split["train"]
|
| 850 |
eval_dataset = split["test"]
|
|
@@ -925,13 +927,13 @@ def main():
|
|
| 925 |
per_device_eval_batch_size=min(args.batch_size * 2, 16),
|
| 926 |
num_train_epochs=args.epochs,
|
| 927 |
max_steps=args.max_steps,
|
| 928 |
-
weight_decay=0.
|
| 929 |
-
warmup_steps=
|
| 930 |
-
lr_scheduler_type="
|
| 931 |
eval_strategy="steps",
|
| 932 |
-
eval_steps=
|
| 933 |
save_strategy="steps",
|
| 934 |
-
save_steps=
|
| 935 |
load_best_model_at_end=True,
|
| 936 |
metric_for_best_model="f1",
|
| 937 |
greater_is_better=True,
|
|
@@ -940,14 +942,15 @@ def main():
|
|
| 940 |
logging_first_step=True,
|
| 941 |
logging_steps=5 if args.test else 20,
|
| 942 |
disable_tqdm=False if args.test else True,
|
| 943 |
-
fp16=
|
| 944 |
-
bf16=
|
| 945 |
push_to_hub=True,
|
| 946 |
hub_model_id=args.push_to_hub,
|
| 947 |
hub_strategy="every_save",
|
| 948 |
use_cpu=not use_cuda,
|
| 949 |
dataloader_num_workers=4,
|
| 950 |
seed=args.seed,
|
|
|
|
| 951 |
save_only_model=True,
|
| 952 |
remove_unused_columns=False,
|
| 953 |
ddp_find_unused_parameters=True,
|
|
@@ -972,7 +975,7 @@ def main():
|
|
| 972 |
# ── Train ─────────────────────────────────────────────────────────────
|
| 973 |
print("\n🚀 Training...")
|
| 974 |
train_start = time.time()
|
| 975 |
-
trainer.train()
|
| 976 |
train_elapsed = time.time() - train_start
|
| 977 |
print(f"✅ Training complete! ({train_elapsed:.1f}s)")
|
| 978 |
print(f" Best checkpoint: {trainer.state.best_model_checkpoint}")
|
|
@@ -1035,9 +1038,11 @@ def main():
|
|
| 1035 |
"learning_rate": args.lr,
|
| 1036 |
"epochs": args.epochs,
|
| 1037 |
"batch_size": args.batch_size,
|
| 1038 |
-
"weight_decay": 0.
|
| 1039 |
-
"scheduler": "
|
| 1040 |
-
"warmup_steps":
|
|
|
|
|
|
|
| 1041 |
},
|
| 1042 |
}
|
| 1043 |
with open(os.path.join(best_model_path, "config.json"), "w") as f:
|
|
|
|
| 483 |
|
| 484 |
loss = None
|
| 485 |
if labels is not None:
|
| 486 |
+
loss = F.cross_entropy(logits.float(), labels)
|
| 487 |
|
| 488 |
return {"logits": logits, "loss": loss}
|
| 489 |
|
|
|
|
| 754 |
def main():
|
| 755 |
parser = argparse.ArgumentParser(description="Train HRM-Text prompt injection detector")
|
| 756 |
parser.add_argument("--test", action="store_true", help="Smoke test on 64 samples")
|
| 757 |
+
parser.add_argument("--lr", type=float, default=2.2e-4)
|
| 758 |
parser.add_argument("--epochs", type=int, default=3)
|
| 759 |
parser.add_argument("--batch_size", type=int, default=32)
|
| 760 |
parser.add_argument("--output_dir", type=str, default="./pi-hrm-text")
|
|
|
|
| 781 |
help="Cache dir for dataset download")
|
| 782 |
parser.add_argument("--max_steps", type=int, default=-1,
|
| 783 |
help="Max training steps (-1 = use epochs)")
|
| 784 |
+
parser.add_argument("--resume_from_checkpoint", type=str, default=None,
|
| 785 |
+
help="Path to checkpoint dir to resume from")
|
| 786 |
args = parser.parse_args()
|
| 787 |
|
| 788 |
set_seed(args.seed)
|
|
|
|
| 846 |
eval_dataset = merged.select(range(min(32, len(merged))))
|
| 847 |
else:
|
| 848 |
split = merged.train_test_split(
|
| 849 |
+
test_size=0.05, seed=args.seed, stratify_by_column="label",
|
| 850 |
)
|
| 851 |
train_dataset = split["train"]
|
| 852 |
eval_dataset = split["test"]
|
|
|
|
| 927 |
per_device_eval_batch_size=min(args.batch_size * 2, 16),
|
| 928 |
num_train_epochs=args.epochs,
|
| 929 |
max_steps=args.max_steps,
|
| 930 |
+
weight_decay=0.1,
|
| 931 |
+
warmup_steps=2000 if not args.test else 0,
|
| 932 |
+
lr_scheduler_type="constant_with_warmup",
|
| 933 |
eval_strategy="steps",
|
| 934 |
+
eval_steps=1000,
|
| 935 |
save_strategy="steps",
|
| 936 |
+
save_steps=1000,
|
| 937 |
load_best_model_at_end=True,
|
| 938 |
metric_for_best_model="f1",
|
| 939 |
greater_is_better=True,
|
|
|
|
| 942 |
logging_first_step=True,
|
| 943 |
logging_steps=5 if args.test else 20,
|
| 944 |
disable_tqdm=False if args.test else True,
|
| 945 |
+
fp16=False,
|
| 946 |
+
bf16=use_cuda,
|
| 947 |
push_to_hub=True,
|
| 948 |
hub_model_id=args.push_to_hub,
|
| 949 |
hub_strategy="every_save",
|
| 950 |
use_cpu=not use_cuda,
|
| 951 |
dataloader_num_workers=4,
|
| 952 |
seed=args.seed,
|
| 953 |
+
adam_beta2=0.95,
|
| 954 |
save_only_model=True,
|
| 955 |
remove_unused_columns=False,
|
| 956 |
ddp_find_unused_parameters=True,
|
|
|
|
| 975 |
# ── Train ─────────────────────────────────────────────────────────────
|
| 976 |
print("\n🚀 Training...")
|
| 977 |
train_start = time.time()
|
| 978 |
+
trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
|
| 979 |
train_elapsed = time.time() - train_start
|
| 980 |
print(f"✅ Training complete! ({train_elapsed:.1f}s)")
|
| 981 |
print(f" Best checkpoint: {trainer.state.best_model_checkpoint}")
|
|
|
|
| 1038 |
"learning_rate": args.lr,
|
| 1039 |
"epochs": args.epochs,
|
| 1040 |
"batch_size": args.batch_size,
|
| 1041 |
+
"weight_decay": 0.1,
|
| 1042 |
+
"scheduler": "constant_with_warmup",
|
| 1043 |
+
"warmup_steps": 2000 if not args.test else 0,
|
| 1044 |
+
"adam_beta2": 0.95,
|
| 1045 |
+
"precision": "bf16",
|
| 1046 |
},
|
| 1047 |
}
|
| 1048 |
with open(os.path.join(best_model_path, "config.json"), "w") as f:
|