"""Train a BERT sidecar NER model for 3 new PII categories."""
import json, random, argparse, ast, sys

import numpy as np
import torch
from datasets import load_dataset, Dataset
from faker import Faker
from transformers import (
    AutoTokenizer, AutoModelForTokenClassification,
    TrainingArguments, Trainer, DataCollatorForTokenClassification,
    EarlyStoppingCallback,
)
import evaluate


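# BIO tagging: each category gets a B- (begin) and I- (inside) label plus the
# shared "O" (outside) label, so NUM_LABELS == 1 + 2 * len(CATEGORIES).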
CATEGORIES = ["fax_number", "credit_card_last4", "company_contact_block"]
LABELS = ["O"]
for cat in CATEGORIES:
    for p in ("B", "I"):
        LABELS.append(f"{p}-{cat}")

label2id = {l: i for i, l in enumerate(LABELS)}
id2label = {i: l for l, i in label2id.items()}
NUM_LABELS = len(LABELS)

seqeval = evaluate.load("seqeval")


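# seqeval scores at the entity level (a predicted span counts only if both its
# boundaries and its type match), which is stricter than per-token accuracy.
# This assumes the seqeval package is installed alongside evaluate.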
def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)
    # Drop positions labeled -100 (special tokens / padding) before scoring.
    true_preds = [
        [id2label[pred] for pred, lab in zip(pred_row, lab_row) if lab != -100]
        for pred_row, lab_row in zip(predictions, labels)
    ]
    true_labs = [
        [id2label[lab] for pred, lab in zip(pred_row, lab_row) if lab != -100]
        for pred_row, lab_row in zip(predictions, labels)
    ]
    results = seqeval.compute(predictions=true_preds, references=true_labs)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }


fake = Faker()


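# Synthetic training data: short templates filled with Faker-generated values,
# where each example carries character-level (start, end, label) spans.
# str.find locates the generated value, so spans stay correct per template.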
def generate_synthetic_examples(n=5000, seed=42):
    random.seed(seed)
    fake.seed_instance(seed)
    examples = []

    def add(text, spans):
        examples.append({"text": text, "spans": spans})

    for _ in range(n):
        r = random.random()
        if r < 0.33:
            fax = fake.numerify(text="(###) ###-####")
            tmpl = random.choice([
                f"Please fax documents to {fax}.",
                f"Fax: {fax}\nAttn: Legal",
                f"Secure fax line: {fax}",
                f"You can reach us at phone (555) 123-4567 or fax {fax}.",
            ])
            s = tmpl.find(fax)
            add(tmpl, [(s, s + len(fax), "fax_number")])

        elif r < 0.66:
            last4 = fake.numerify(text="####")
            tmpl = random.choice([
                f"Card ending in {last4} charged.",
                f"Visa ****-****-****-{last4}",
                f"Last 4 digits: {last4}",
                f"Card on file ...{last4}",
            ])
            s = tmpl.find(last4)
            add(tmpl, [(s, s + len(last4), "credit_card_last4")])

        else:
            # A contact block spans from the company name through the email.
            company = fake.company()
            addr = (
                fake.street_address() + ", " + fake.city() + ", "
                + fake.state_abbr() + " " + fake.zipcode()
            )
            phone = fake.numerify(text="(###) ###-####")
            email = fake.company_email()
            tmpl = random.choice([
                f"{company}\n{addr}\nPhone: {phone}\nEmail: {email}",
                f"Contact:\n{company}\n{addr}\nTel: {phone}\n{email}",
                f"{company} HQ\n{addr}\nMain: {phone}\nInquiries: {email}",
            ])
            s = tmpl.find(company)
            e = tmpl.find(email) + len(email)
            add(tmpl, [(s, e, "company_contact_block")])
    return examples


# Map Nemotron-PII labels onto this model's categories; spans with any other
# label are dropped on load.
NEMOTRON_MAP = {
    "company_name": "company_contact_block",
}


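# Note: this loader assumes each Nemotron-PII record exposes "text" plus a
# "spans" field of {"label", "start", "end"} dicts (possibly serialized as a
# string); adjust the field names if the dataset schema differs.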
def load_nemotron_split(split, max_examples=5000):
    ds = load_dataset("nvidia/Nemotron-PII", split=split)
    examples = []
    for ex in ds:
        if len(examples) >= max_examples:
            break
        text = ex["text"]
        spans_raw = ex["spans"]
        if isinstance(spans_raw, str):
            try:
                spans_raw = json.loads(spans_raw)
            except json.JSONDecodeError:
                # Fall back for Python-repr strings (single quotes).
                spans_raw = ast.literal_eval(spans_raw)
        spans = []
        for sp in spans_raw:
            lab = NEMOTRON_MAP.get(sp["label"])
            if lab:
                spans.append((sp["start"], sp["end"], lab))
        if spans:
            examples.append({"text": text, "spans": spans})
    return examples


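# Character spans are aligned to wordpiece tokens via the fast tokenizer's
# offset mapping. Fast tokenizers report (0, 0) offsets for special tokens
# such as [CLS] and [SEP], so a zero-width offset is treated as "special" and
# excluded from the loss with -100.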
def tokenize_and_align(examples, tokenizer):
    texts = [ex["text"] for ex in examples]
    enc = tokenizer(
        texts,
        truncation=True,
        max_length=512,
        padding=False,
        return_offsets_mapping=True,
    )

    all_labels = []
    for i, ex in enumerate(examples):
        offsets = enc["offset_mapping"][i]
        labels = ["O"] * len(offsets)

        for start, end, lab in ex["spans"]:
            covered = []
            for j, (ts, te) in enumerate(offsets):
                if ts == te:  # special token
                    continue
                if ts >= end or te <= start:  # no overlap with this span
                    continue
                covered.append(j)

            if not covered:
                continue

            # First covered token gets B-, the rest I-.
            labels[covered[0]] = f"B-{lab}"
            for idx in covered[1:]:
                labels[idx] = f"I-{lab}"

        label_ids = []
        for j, (ts, te) in enumerate(offsets):
            if ts == te:
                label_ids.append(-100)
            else:
                label_ids.append(label2id.get(labels[j], 0))
        all_labels.append(label_ids)

    enc["labels"] = all_labels
    enc.pop("offset_mapping")
    return enc


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_model", default="bert-base-uncased")
    parser.add_argument("--output_model", default="narcolepticchicken/privacy-filter-sidecar-bert")
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--grad_accum", type=int, default=1)
    parser.add_argument("--lr", type=float, default=5e-5)
    parser.add_argument("--max_synthetic", type=int, default=5000)
    parser.add_argument("--max_nemotron_train", type=int, default=5000)
    parser.add_argument("--max_nemotron_eval", type=int, default=1000)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    print(f"Loading tokenizer: {args.base_model}")
    tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)

    print(f"Loading model: {args.base_model}")
    model = AutoModelForTokenClassification.from_pretrained(
        args.base_model,
        num_labels=NUM_LABELS,
        id2label=id2label,
        label2id=label2id,
    )

| print("\n=== Sanity check: tokenizing one example ===") |
| test_ex = generate_synthetic_examples(1, args.seed) |
| test_tok = tokenize_and_align(test_ex, tokenizer) |
| test_labels = test_tok["labels"][0] |
| non_o = sum(1 for lid in test_labels if lid != -100 and lid != 0) |
| special = sum(1 for lid in test_labels if lid == -100) |
| print(f" Tokens: {len(test_labels)}, Special (-100): {special}, Non-O labels: {non_o}") |
| if non_o == 0: |
| print(" ERROR: No non-O labels found! Exiting.") |
| sys.exit(1) |
| print(" OK - labels are aligned.\n") |
|
|
| print("Generating synthetic data...") |
| synth = generate_synthetic_examples(args.max_synthetic, args.seed) |
| print(f" Synthetic: {len(synth)}") |
|
|
| print("Loading Nemotron-PII (filtered to company_name only)...") |
| nemotron_train = load_nemotron_split("train", args.max_nemotron_train) |
| nemotron_eval = load_nemotron_split("test", args.max_nemotron_eval) |
| print(f" Nemotron train: {len(nemotron_train)}, eval: {len(nemotron_eval)}") |
|
|
| train_examples = synth + nemotron_train |
| eval_examples = nemotron_eval |
    print("Tokenizing train...")
    train_tok = tokenize_and_align(train_examples, tokenizer)
    print("Tokenizing eval...")
    eval_tok = tokenize_and_align(eval_examples, tokenizer)

    train_ds = Dataset.from_dict(train_tok)
    eval_ds = Dataset.from_dict(eval_tok)

    print("\n=== Label distribution check ===")
    all_train_labels = [lid for row in train_tok["labels"] for lid in row if lid != -100]
    for cat in CATEGORIES:
        b_id = label2id[f"B-{cat}"]
        i_id = label2id[f"I-{cat}"]
        count = sum(1 for lid in all_train_labels if lid in (b_id, i_id))
        print(f" {cat}: {count} tokens")
    if sum(1 for lid in all_train_labels if lid != 0) == 0:
        print(" ERROR: All labels are O! Exiting.")
        sys.exit(1)

    # Pads inputs and labels dynamically per batch (labels padded with -100).
    data_collator = DataCollatorForTokenClassification(tokenizer)

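    # Reporting assumes the trackio integration shipped in recent transformers
    # releases (report_to="trackio" with a project name); if trackio is not
    # installed, switch report_to to "none" or another supported backend.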
    training_args = TrainingArguments(
        output_dir="/app/sidecar-checkpoints",
        learning_rate=args.lr,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        num_train_epochs=args.epochs,
        weight_decay=0.01,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        logging_strategy="steps",
        logging_steps=10,
        logging_first_step=True,
        disable_tqdm=True,
        push_to_hub=True,
        hub_model_id=args.output_model,
        report_to="trackio",
        run_name=f"sidecar-bert-lr{args.lr}-bs{args.batch_size}",
        project="privacy-filter-enhanced",
        seed=args.seed,
        bf16=False,
        fp16=torch.cuda.is_available(),  # fp16=True raises on CPU-only runs
        gradient_accumulation_steps=args.grad_accum,
        dataloader_num_workers=2,
        warmup_ratio=0.1,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        processing_class=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
    )

    print("\n=== Starting training ===")
    trainer.train()
    print("\n=== Pushing to hub ===")
    trainer.push_to_hub(commit_message="Sidecar NER: fax + cc_last4 + contact_block")
    print("\nDone!")


if __name__ == "__main__":
    main()
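

# Example usage after training (a sketch, assuming the hub push above
# succeeded and the repo is readable; the sample text is illustrative):
#
#   from transformers import pipeline
#
#   ner = pipeline(
#       "token-classification",
#       model="narcolepticchicken/privacy-filter-sidecar-bert",
#       aggregation_strategy="simple",  # merge B-/I- wordpieces into spans
#   )
#   print(ner("Please fax the signed form to (212) 555-0187."))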