av-codes commited on
Commit
399f1d7
Β·
verified Β·
1 Parent(s): c23cb27

add DistilBERT eval/fine-tune script for bordair comparison

Browse files
Files changed (1) hide show
  1. eval_distilbert_bordair.py +202 -0
eval_distilbert_bordair.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Evaluate and fine-tune DistilBERT on Bordair multimodal dataset.
3
+
4
+ Two tests:
5
+ 1. Zero-shot: existing av-codes/prompt-injection-detector-v2 on bordair eval split
6
+ 2. Fine-tune: distilbert-base-uncased trained on bordair train split, 1 epoch
7
+
8
+ Uses the same train/eval split as HRM-Text (seed=42, stratified 90/10).
9
+ """
10
+ import json
11
+ import glob
12
+ import time
13
+ import numpy as np
14
+
15
+ import datasets as hf_datasets
16
+ import evaluate
17
+ import torch
18
+ from datasets import Dataset
19
+ from huggingface_hub import snapshot_download
20
+ from transformers import (
21
+ AutoModelForSequenceClassification,
22
+ AutoTokenizer,
23
+ Trainer,
24
+ TrainingArguments,
25
+ pipeline,
26
+ )
27
+
28
+
29
+ def load_bordair_multimodal():
30
+ print("πŸ“¦ Downloading Bordair/bordair-multimodal...")
31
+ path = snapshot_download(repo_id="Bordair/bordair-multimodal", repo_type="dataset")
32
+ print(f" Downloaded to: {path}")
33
+
34
+ all_samples = []
35
+ patterns = [
36
+ "benign/*.json",
37
+ "payloads/*/*.json",
38
+ "payloads_v5/*.json",
39
+ "payloads_v5_external/*/*.json",
40
+ ]
41
+ for pattern in patterns:
42
+ files = sorted(glob.glob(f"{path}/{pattern}"))
43
+ for f in files:
44
+ fname = f.split("/")[-1]
45
+ if fname in ("summary.json", "_pool.json", "summary_old.json"):
46
+ continue
47
+ try:
48
+ with open(f, "r") as fh:
49
+ data = json.load(fh)
50
+ except (json.JSONDecodeError, UnicodeDecodeError):
51
+ continue
52
+ if isinstance(data, list):
53
+ for item in data:
54
+ if isinstance(item, dict) and item.get("expected_detection") is not None:
55
+ text_parts = [item.get("text", "")]
56
+ for k in ("image_content", "document_content", "audio_content"):
57
+ if item.get(k):
58
+ text_parts.append(item[k])
59
+ all_samples.append({
60
+ "text": "\n".join(text_parts),
61
+ "label": 1 if item["expected_detection"] else 0,
62
+ })
63
+ print(f" {pattern}: {len(all_samples)} cumulative")
64
+
65
+ ds = Dataset.from_list(all_samples)
66
+ print(f"\nβœ… Total: {len(ds)} samples ({sum(1 for s in all_samples if s['label']==1)} injection, {sum(1 for s in all_samples if s['label']==0)} safe)")
67
+ return ds
68
+
69
+
70
+ def compute_metrics(eval_pred):
71
+ accuracy = evaluate.load("accuracy")
72
+ precision_m = evaluate.load("precision")
73
+ recall_m = evaluate.load("recall")
74
+ f1_m = evaluate.load("f1")
75
+
76
+ logits, labels = eval_pred
77
+ preds = np.argmax(logits, axis=-1)
78
+ return {
79
+ "accuracy": accuracy.compute(predictions=preds, references=labels)["accuracy"],
80
+ "precision": precision_m.compute(predictions=preds, references=labels)["precision"],
81
+ "recall": recall_m.compute(predictions=preds, references=labels)["recall"],
82
+ "f1": f1_m.compute(predictions=preds, references=labels)["f1"],
83
+ }
84
+
85
+
86
+ def main():
87
+ merged = load_bordair_multimodal()
88
+ merged = merged.cast_column("label", hf_datasets.ClassLabel(names=["safe", "injection"]))
89
+ split = merged.train_test_split(test_size=0.1, seed=42, stratify_by_column="label")
90
+ train_dataset = split["train"]
91
+ eval_dataset = split["test"]
92
+ print(f" Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
93
+
94
+ # ── Test 1: Zero-shot eval of existing model ─────────────────────────
95
+ print("\n" + "="*60)
96
+ print("TEST 1: Zero-shot eval of av-codes/prompt-injection-detector-v2")
97
+ print("="*60)
98
+
99
+ zs_model_id = "av-codes/prompt-injection-detector-v2"
100
+ zs_tokenizer = AutoTokenizer.from_pretrained(zs_model_id)
101
+ zs_model = AutoModelForSequenceClassification.from_pretrained(zs_model_id)
102
+
103
+ zs_args = TrainingArguments(
104
+ output_dir="/tmp/zs_eval",
105
+ per_device_eval_batch_size=64,
106
+ fp16=torch.cuda.is_available(),
107
+ report_to="none",
108
+ disable_tqdm=True,
109
+ use_cpu=not torch.cuda.is_available(),
110
+ remove_unused_columns=False,
111
+ )
112
+
113
+ def zs_tokenize(batch):
114
+ return zs_tokenizer(batch["text"], truncation=True, padding="max_length", max_length=512)
115
+
116
+ eval_tok = eval_dataset.map(zs_tokenize, batched=True, batch_size=1000)
117
+
118
+ zs_trainer = Trainer(
119
+ model=zs_model,
120
+ args=zs_args,
121
+ eval_dataset=eval_tok,
122
+ compute_metrics=compute_metrics,
123
+ )
124
+
125
+ t0 = time.time()
126
+ zs_results = zs_trainer.evaluate()
127
+ t1 = time.time()
128
+ print(f"\nπŸ“Š Zero-shot results ({t1-t0:.0f}s):")
129
+ for k, v in zs_results.items():
130
+ print(f" {k}: {v}")
131
+
132
+ del zs_model, zs_trainer
133
+ torch.cuda.empty_cache()
134
+
135
+ # ── Test 2: Fine-tune DistilBERT on bordair ──────────────────────────
136
+ print("\n" + "="*60)
137
+ print("TEST 2: Fine-tune distilbert-base-uncased on bordair (1 epoch)")
138
+ print("="*60)
139
+
140
+ ft_model_id = "distilbert-base-uncased"
141
+ ft_tokenizer = AutoTokenizer.from_pretrained(ft_model_id)
142
+ ft_model = AutoModelForSequenceClassification.from_pretrained(ft_model_id, num_labels=2)
143
+
144
+ def ft_tokenize(batch):
145
+ return ft_tokenizer(batch["text"], truncation=True, padding="max_length", max_length=512)
146
+
147
+ train_tok = train_dataset.map(ft_tokenize, batched=True, batch_size=1000)
148
+ eval_tok2 = eval_dataset.map(ft_tokenize, batched=True, batch_size=1000)
149
+
150
+ ft_args = TrainingArguments(
151
+ output_dir="/tmp/ft_distilbert",
152
+ learning_rate=2e-5,
153
+ per_device_train_batch_size=32,
154
+ per_device_eval_batch_size=64,
155
+ num_train_epochs=1,
156
+ weight_decay=0.01,
157
+ warmup_steps=500,
158
+ lr_scheduler_type="cosine",
159
+ eval_strategy="epoch",
160
+ save_strategy="epoch",
161
+ load_best_model_at_end=False,
162
+ logging_strategy="steps",
163
+ logging_steps=100,
164
+ logging_first_step=True,
165
+ disable_tqdm=True,
166
+ fp16=torch.cuda.is_available(),
167
+ report_to="none",
168
+ use_cpu=not torch.cuda.is_available(),
169
+ dataloader_num_workers=4,
170
+ seed=42,
171
+ remove_unused_columns=False,
172
+ )
173
+
174
+ ft_trainer = Trainer(
175
+ model=ft_model,
176
+ args=ft_args,
177
+ train_dataset=train_tok,
178
+ eval_dataset=eval_tok2,
179
+ compute_metrics=compute_metrics,
180
+ )
181
+
182
+ t0 = time.time()
183
+ ft_trainer.train()
184
+ t1 = time.time()
185
+ print(f"\n⏱️ Training time: {t1-t0:.0f}s ({(t1-t0)/3600:.1f}h)")
186
+
187
+ ft_results = ft_trainer.evaluate()
188
+ print(f"\nπŸ“Š Fine-tuned DistilBERT results:")
189
+ for k, v in ft_results.items():
190
+ print(f" {k}: {v}")
191
+
192
+ # ── Summary ──────────────────────────────────────────────────────────
193
+ print("\n" + "="*60)
194
+ print("SUMMARY β€” Bordair multimodal eval set (47,644 samples)")
195
+ print("="*60)
196
+ print(f" Zero-shot DistilBERT v2 (61K data): F1={zs_results.get('eval_f1', '?')}")
197
+ print(f" Fine-tuned DistilBERT (bordair 1ep): F1={ft_results.get('eval_f1', '?')}")
198
+ print(f" HRM-Text (bordair, in progress): F1=TBD (check HF Jobs)")
199
+
200
+
201
+ if __name__ == "__main__":
202
+ main()