narcolepticchicken committed (verified)
Commit dda6dd9 · 1 Parent(s): 4fe03d1

Add sidecar training script with DeBERTa-v3

Files changed (1)
  1. train_sidecar.py +295 -0
train_sidecar.py ADDED
@@ -0,0 +1,295 @@
#!/usr/bin/env python3
"""Train a DeBERTa-v3 sidecar NER model for 3 new PII categories."""
import json, random, argparse, ast, sys
import numpy as np
import torch
from datasets import load_dataset, Dataset
from transformers import (
    AutoTokenizer, AutoModelForTokenClassification,
    TrainingArguments, Trainer, DataCollatorForTokenClassification,
    EarlyStoppingCallback,
)
import evaluate
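# Assumed runtime deps (not pinned in this commit): a transformers build
# recent enough to accept Trainer(processing_class=...) and eval_strategy=...,
# plus datasets, evaluate + seqeval, faker, and sentencepiece for the
# DeBERTa-v3 tokenizer.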

CATEGORIES = ["fax_number", "credit_card_last4", "company_contact_block"]
LABELS = ["O"]
for cat in CATEGORIES:
    for p in ("B", "I"):
        LABELS.append(f"{p}-{cat}")
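# The loops above expand to a 7-label BIO scheme:
#   O,
#   B-fax_number / I-fax_number,
#   B-credit_card_last4 / I-credit_card_last4,
#   B-company_contact_block / I-company_contact_block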

label2id = {l: i for i, l in enumerate(LABELS)}
id2label = {i: l for l, i in label2id.items()}
NUM_LABELS = len(LABELS)

seqeval = evaluate.load("seqeval")


def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)
    true_preds = [
        [id2label[pred] for pred, lab in zip(pred_row, lab_row) if lab != -100]
        for pred_row, lab_row in zip(predictions, labels)
    ]
    true_labs = [
        [id2label[lab] for pred, lab in zip(pred_row, lab_row) if lab != -100]
        for pred_row, lab_row in zip(predictions, labels)
    ]
    results = seqeval.compute(predictions=true_preds, references=true_labs)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }
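# seqeval scores entities, not tokens: a predicted span counts as correct
# only when both its boundaries and its type match the reference. Only
# "accuracy" above is a token-level number.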


from faker import Faker
fake = Faker()


def generate_synthetic_examples(n=5000, seed=42):
    random.seed(seed)
    fake.seed_instance(seed)
    examples = []

    def add(text, spans):
        examples.append({"text": text, "spans": spans})

    for _ in range(n):
        r = random.random()
        if r < 0.33:
            fax = fake.numerify(text="(###) ###-####")
            tmpl = random.choice([
                f"Please fax documents to {fax}.",
                f"Fax: {fax}\nAttn: Legal",
                f"Secure fax line: {fax}",
                f"You can reach us at phone (555) 123-4567 or fax {fax}.",
            ])
            s = tmpl.find(fax)
            add(tmpl, [(s, s + len(fax), "fax_number")])

        elif r < 0.66:
            last4 = fake.numerify(text="####")
            tmpl = random.choice([
                f"Card ending in {last4} charged.",
                f"Visa ****-****-****-{last4}",
                f"Last 4 digits: {last4}",
                f"Card on file ...{last4}",
            ])
            s = tmpl.find(last4)
            add(tmpl, [(s, s + len(last4), "credit_card_last4")])

        else:
            company = fake.company()
            addr = (
                fake.street_address() + ", " + fake.city() + ", "
                + fake.state_abbr() + " " + fake.zipcode()
            )
            phone = fake.numerify(text="(###) ###-####")
            email = fake.company_email()
            tmpl = random.choice([
                f"{company}\n{addr}\nPhone: {phone}\nEmail: {email}",
                f"Contact:\n{company}\n{addr}\nTel: {phone}\n{email}",
                f"{company} HQ\n{addr}\nMain: {phone}\nInquiries: {email}",
            ])
            s = tmpl.find(company)
            e = tmpl.find(email) + len(email)
            add(tmpl, [(s, e, "company_contact_block")])
    return examples
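# Each example pairs raw text with character-level spans, e.g. (illustrative;
# the exact strings depend on the Faker seed):
#   {"text": "Fax: (555) 014-2367\nAttn: Legal",
#    "spans": [(5, 19, "fax_number")]}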


NEMOTRON_MAP = {
    "company_name": "company_contact_block",
}
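# Only Nemotron spans labeled company_name survive (remapped to
# company_contact_block); every other PII label is dropped below, and
# examples left with no mapped span are skipped entirely.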


def load_nemotron_split(split, max_examples=5000):
    ds = load_dataset("nvidia/Nemotron-PII", split=split)
    examples = []
    for ex in ds:
        if len(examples) >= max_examples:
            break
        text = ex["text"]
        spans_raw = ex["spans"]
        if isinstance(spans_raw, str):
            try:
                spans_raw = json.loads(spans_raw)
            except json.JSONDecodeError:
                spans_raw = ast.literal_eval(spans_raw)
        spans = []
        for sp in spans_raw:
            lab = NEMOTRON_MAP.get(sp["label"])
            if lab:
                spans.append((sp["start"], sp["end"], lab))
        if spans:
            examples.append({"text": text, "spans": spans})
    return examples


def tokenize_and_align(examples, tokenizer):
    texts = [ex["text"] for ex in examples]
    enc = tokenizer(
        texts,
        truncation=True,
        max_length=512,
        padding=False,
        return_offsets_mapping=True,
    )

    all_labels = []
    for i, ex in enumerate(examples):
        offsets = enc["offset_mapping"][i]
        labels = ["O"] * len(offsets)

        for start, end, lab in ex["spans"]:
            covered = []
            for j, (ts, te) in enumerate(offsets):
                # Fast tokenizers give special tokens a zero-width (0, 0)
                # offset rather than None; skip zero-width tokens here.
                if ts == te:
                    continue
                if ts >= end or te <= start:
                    continue
                covered.append(j)

            if not covered:
                continue

            labels[covered[0]] = f"B-{lab}"
            for idx in covered[1:]:
                labels[idx] = f"I-{lab}"

        label_ids = []
        for j, (ts, te) in enumerate(offsets):
            if ts == te:
                # Special (zero-width) token: exclude from the loss.
                label_ids.append(-100)
            else:
                label_ids.append(label2id.get(labels[j], 0))
        all_labels.append(label_ids)

    enc["labels"] = all_labels
    enc.pop("offset_mapping")
    return enc
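# A token [ts, te) overlaps a span [start, end) iff ts < end and te > start,
# i.e. exactly the negation of the skip condition above. Illustrative
# alignment (actual pieces depend on the SentencePiece vocab):
#   text:   "Last 4 digits: 4821"
#   pieces:  Last | 4 | digits | : | 48 | 21
#   labels:  O      O   O        O   B-credit_card_last4   I-credit_card_last4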


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_model", default="microsoft/deberta-v3-base")
    parser.add_argument("--output_model", default="narcolepticchicken/privacy-filter-sidecar-v3")
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--grad_accum", type=int, default=2)
    parser.add_argument("--lr", type=float, default=3e-5)
    parser.add_argument("--max_synthetic", type=int, default=5000)
    parser.add_argument("--max_nemotron_train", type=int, default=5000)
    parser.add_argument("--max_nemotron_eval", type=int, default=1000)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    print(f"Loading tokenizer: {args.base_model}")
    tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)

    print(f"Loading model: {args.base_model}")
    model = AutoModelForTokenClassification.from_pretrained(
        args.base_model,
        num_labels=NUM_LABELS,
        id2label=id2label,
        label2id=label2id,
    )

    print("\n=== Sanity check: tokenizing one example ===")
    test_ex = generate_synthetic_examples(1, args.seed)
    test_tok = tokenize_and_align(test_ex, tokenizer)
    test_labels = test_tok["labels"][0]
    non_o = sum(1 for lid in test_labels if lid != -100 and lid != 0)
    special = sum(1 for lid in test_labels if lid == -100)
    print(f"  Tokens: {len(test_labels)}, Special (-100): {special}, Non-O labels: {non_o}")
    if non_o == 0:
        print("  ERROR: No non-O labels found! Exiting.")
        sys.exit(1)
    print("  OK - labels are aligned.\n")

    print("Generating synthetic data...")
    synth = generate_synthetic_examples(args.max_synthetic, args.seed)
    print(f"  Synthetic: {len(synth)}")

    print("Loading Nemotron-PII (filtered to company_name only)...")
    nemotron_train = load_nemotron_split("train", args.max_nemotron_train)
    nemotron_eval = load_nemotron_split("test", args.max_nemotron_eval)
    print(f"  Nemotron train: {len(nemotron_train)}, eval: {len(nemotron_eval)}")

    train_examples = synth + nemotron_train
    eval_examples = nemotron_eval

    print("Tokenizing train...")
    train_tok = tokenize_and_align(train_examples, tokenizer)
    print("Tokenizing eval...")
    eval_tok = tokenize_and_align(eval_examples, tokenizer)

    train_ds = Dataset.from_dict(train_tok)
    eval_ds = Dataset.from_dict(eval_tok)

    print("\n=== Label distribution check ===")
    all_train_labels = [lid for row in train_tok["labels"] for lid in row if lid != -100]
    for cat in CATEGORIES:
        b_id = label2id[f"B-{cat}"]
        i_id = label2id[f"I-{cat}"]
        count = sum(1 for lid in all_train_labels if lid in (b_id, i_id))
        print(f"  {cat}: {count} tokens")
    if sum(1 for lid in all_train_labels if lid != 0) == 0:
        print("  ERROR: All labels are O! Exiting.")
        sys.exit(1)

    data_collator = DataCollatorForTokenClassification(tokenizer)
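    # DataCollatorForTokenClassification pads input_ids with the tokenizer's
    # pad token and pads the labels column with -100, so padding never
    # contributes to the loss.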

    training_args = TrainingArguments(
        output_dir="/app/sidecar-checkpoints",
        learning_rate=args.lr,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        num_train_epochs=args.epochs,
        weight_decay=0.01,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        logging_strategy="steps",
        logging_steps=10,
        logging_first_step=True,
        disable_tqdm=True,
        push_to_hub=True,
        hub_model_id=args.output_model,
        report_to="trackio",
        run_name=f"sidecar-{args.base_model.split('/')[-1]}-lr{args.lr}-bs{args.batch_size}",
        project="privacy-filter-enhanced",
        seed=args.seed,
        bf16=True,
        gradient_accumulation_steps=args.grad_accum,
        dataloader_num_workers=2,
        warmup_ratio=0.1,
    )
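    # Effective batch size per device: per_device_train_batch_size (16)
    # x gradient_accumulation_steps (2) = 32 with the defaults. Note that
    # report_to="trackio" and the project= field assume a transformers
    # build with trackio integration.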

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        processing_class=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
    )
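    # EarlyStoppingCallback halts training once eval F1 fails to improve for
    # two consecutive epochs; it relies on the load_best_model_at_end /
    # metric_for_best_model settings configured above.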

    print("\n=== Starting training ===")
    trainer.train()
    print("\n=== Pushing to hub ===")
    trainer.push_to_hub(commit_message="Sidecar NER: fax + cc_last4 + contact_block")
    print("\nDone!")


if __name__ == "__main__":
    main()
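
A minimal inference sketch for the trained sidecar, assuming the push above
succeeded and using the script's default --output_model id:

    from transformers import pipeline

    ner = pipeline(
        "token-classification",
        model="narcolepticchicken/privacy-filter-sidecar-v3",
        aggregation_strategy="simple",  # merge B-/I- pieces into full spans
    )
    print(ner("Please fax the signed form to (555) 012-3456."))
    # e.g. [{'entity_group': 'fax_number', 'word': '(555) 012-3456', ...}]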