#!/usr/bin/env python3
"""Train a BERT sidecar NER model for 3 new PII categories."""
import argparse
import ast
import json
import random
import sys

import numpy as np
import torch
from datasets import Dataset, load_dataset
from faker import Faker
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
)
import evaluate

# BIO label scheme: "O" plus B-/I- tags for each new category.
CATEGORIES = ["fax_number", "credit_card_last4", "company_contact_block"]
LABELS = ["O"]
for cat in CATEGORIES:
    for p in ("B", "I"):
        LABELS.append(f"{p}-{cat}")
label2id = {l: i for i, l in enumerate(LABELS)}
id2label = {i: l for l, i in label2id.items()}
NUM_LABELS = len(LABELS)

seqeval = evaluate.load("seqeval")


def compute_metrics(p):
    """Entity-level precision/recall/F1 via seqeval, skipping -100 positions."""
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)
    true_preds = [
        [id2label[pred] for pred, lab in zip(pred_row, lab_row) if lab != -100]
        for pred_row, lab_row in zip(predictions, labels)
    ]
    true_labs = [
        [id2label[lab] for pred, lab in zip(pred_row, lab_row) if lab != -100]
        for pred_row, lab_row in zip(predictions, labels)
    ]
    results = seqeval.compute(predictions=true_preds, references=true_labs)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }


fake = Faker()


def generate_synthetic_examples(n=5000, seed=42):
    """Generate n labeled examples with Faker, roughly a third per category."""
    random.seed(seed)
    fake.seed_instance(seed)
    examples = []

    def add(text, spans):
        examples.append({"text": text, "spans": spans})

    for _ in range(n):
        r = random.random()
        if r < 0.33:
            fax = fake.numerify(text="(###) ###-####")
            tmpl = random.choice([
                f"Please fax documents to {fax}.",
                f"Fax: {fax}\nAttn: Legal",
                f"Secure fax line: {fax}",
                f"You can reach us at phone (555) 123-4567 or fax {fax}.",
            ])
            s = tmpl.find(fax)
            add(tmpl, [(s, s + len(fax), "fax_number")])
        elif r < 0.66:
            last4 = fake.numerify(text="####")
            tmpl = random.choice([
                f"Card ending in {last4} charged.",
                f"Visa ****-****-****-{last4}",
                f"Last 4 digits: {last4}",
                f"Card on file ...{last4}",
            ])
            s = tmpl.find(last4)
            add(tmpl, [(s, s + len(last4), "credit_card_last4")])
        else:
            company = fake.company()
            addr = (
                fake.street_address() + ", " + fake.city() + ", "
                + fake.state_abbr() + " " + fake.zipcode()
            )
            phone = fake.numerify(text="(###) ###-####")
            email = fake.company_email()
            tmpl = random.choice([
                f"{company}\n{addr}\nPhone: {phone}\nEmail: {email}",
                f"Contact:\n{company}\n{addr}\nTel: {phone}\n{email}",
                f"{company} HQ\n{addr}\nMain: {phone}\nInquiries: {email}",
            ])
            # The block span runs from the company name through the email.
            s = tmpl.find(company)
            e = tmpl.find(email) + len(email)
            add(tmpl, [(s, e, "company_contact_block")])
    return examples


# Map Nemotron-PII labels onto our categories; everything else is dropped.
NEMOTRON_MAP = {
    "company_name": "company_contact_block",
}


def load_nemotron_split(split, max_examples=5000):
    """Load a Nemotron-PII split, keeping only examples with mapped spans."""
    ds = load_dataset("nvidia/Nemotron-PII", split=split)
    examples = []
    for ex in ds:
        if len(examples) >= max_examples:
            break
        text = ex["text"]
        spans_raw = ex["spans"]
        # Spans may arrive serialized as JSON or as a Python literal string.
        if isinstance(spans_raw, str):
            try:
                spans_raw = json.loads(spans_raw)
            except json.JSONDecodeError:
                spans_raw = ast.literal_eval(spans_raw)
        spans = []
        for sp in spans_raw:
            lab = NEMOTRON_MAP.get(sp["label"])
            if lab:
                spans.append((sp["start"], sp["end"], lab))
        if spans:
            examples.append({"text": text, "spans": spans})
    return examples


def tokenize_and_align(examples, tokenizer):
    """Tokenize texts and project character spans onto BIO token labels."""
    texts = [ex["text"] for ex in examples]
    enc = tokenizer(
        texts,
        truncation=True,
        max_length=512,
        padding=False,
        return_offsets_mapping=True,
    )
    all_labels = []
    for i, ex in enumerate(examples):
        offsets = enc["offset_mapping"][i]
        labels = ["O"] * len(offsets)
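        # A token counts as covered by a character span when the half-open
        # intervals [ts, te) and [start, end) overlap. The first covered token
        # gets the B- tag and the rest get I-; if spans overlap, later spans
        # overwrite earlier tags.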
        for start, end, lab in ex["spans"]:
            covered = []
            for j, (ts, te) in enumerate(offsets):
                # Fast tokenizers report special tokens as (0, 0) offsets, not
                # None; skip zero-width offsets so [CLS]/[SEP] never land
                # inside a span.
                if ts == te:
                    continue
                if ts >= end or te <= start:
                    continue
                covered.append(j)
            if not covered:
                continue
            labels[covered[0]] = f"B-{lab}"
            for idx in covered[1:]:
                labels[idx] = f"I-{lab}"
        label_ids = []
        for j, (ts, te) in enumerate(offsets):
            # Mask special tokens with -100 so the loss and metrics skip them.
            if ts == te:
                label_ids.append(-100)
            else:
                label_ids.append(label2id.get(labels[j], 0))
        all_labels.append(label_ids)
    enc["labels"] = all_labels
    enc.pop("offset_mapping")
    return enc


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_model", default="bert-base-uncased")
    parser.add_argument("--output_model", default="narcolepticchicken/privacy-filter-sidecar-bert")
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--grad_accum", type=int, default=1)
    parser.add_argument("--lr", type=float, default=5e-5)
    parser.add_argument("--max_synthetic", type=int, default=5000)
    parser.add_argument("--max_nemotron_train", type=int, default=5000)
    parser.add_argument("--max_nemotron_eval", type=int, default=1000)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    print(f"Loading tokenizer: {args.base_model}")
    tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
    print(f"Loading model: {args.base_model}")
    model = AutoModelForTokenClassification.from_pretrained(
        args.base_model,
        num_labels=NUM_LABELS,
        id2label=id2label,
        label2id=label2id,
    )

    # Fail fast if span-to-token alignment is broken, before a long run.
    print("\n=== Sanity check: tokenizing one example ===")
    test_ex = generate_synthetic_examples(1, args.seed)
    test_tok = tokenize_and_align(test_ex, tokenizer)
    test_labels = test_tok["labels"][0]
    non_o = sum(1 for lid in test_labels if lid != -100 and lid != 0)
    special = sum(1 for lid in test_labels if lid == -100)
    print(f"  Tokens: {len(test_labels)}, Special (-100): {special}, Non-O labels: {non_o}")
    if non_o == 0:
        print("  ERROR: No non-O labels found! Exiting.")
        sys.exit(1)
    print("  OK - labels are aligned.\n")

    print("Generating synthetic data...")
    synth = generate_synthetic_examples(args.max_synthetic, args.seed)
    print(f"  Synthetic: {len(synth)}")
    print("Loading Nemotron-PII (filtered to company_name only)...")
    nemotron_train = load_nemotron_split("train", args.max_nemotron_train)
    nemotron_eval = load_nemotron_split("test", args.max_nemotron_eval)
    print(f"  Nemotron train: {len(nemotron_train)}, eval: {len(nemotron_eval)}")

    train_examples = synth + nemotron_train
    eval_examples = nemotron_eval

    print("Tokenizing train...")
    train_tok = tokenize_and_align(train_examples, tokenizer)
    print("Tokenizing eval...")
    eval_tok = tokenize_and_align(eval_examples, tokenizer)
    train_ds = Dataset.from_dict(train_tok)
    eval_ds = Dataset.from_dict(eval_tok)

    print("\n=== Label distribution check ===")
    all_train_labels = [lid for row in train_tok["labels"] for lid in row if lid != -100]
    for cat in CATEGORIES:
        b_id = label2id[f"B-{cat}"]
        i_id = label2id[f"I-{cat}"]
        count = sum(1 for lid in all_train_labels if lid in (b_id, i_id))
        print(f"  {cat}: {count} tokens")
    if sum(1 for lid in all_train_labels if lid != 0) == 0:
        print("  ERROR: All labels are O! Exiting.")
        sys.exit(1)
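
    # DataCollatorForTokenClassification pads input_ids and attention_mask to
    # the longest row in each batch and pads the "labels" column with -100
    # (its default label_pad_token_id), so padding positions never contribute
    # to the loss or to the seqeval metrics above.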
Exiting.") sys.exit(1) data_collator = DataCollatorForTokenClassification(tokenizer) training_args = TrainingArguments( output_dir="/app/sidecar-checkpoints", learning_rate=args.lr, per_device_train_batch_size=args.batch_size, per_device_eval_batch_size=args.batch_size, num_train_epochs=args.epochs, weight_decay=0.01, eval_strategy="epoch", save_strategy="epoch", load_best_model_at_end=True, metric_for_best_model="f1", greater_is_better=True, logging_strategy="steps", logging_steps=10, logging_first_step=True, disable_tqdm=True, push_to_hub=True, hub_model_id=args.output_model, report_to="trackio", run_name=f"sidecar-bert-lr{args.lr}-bs{args.batch_size}", project="privacy-filter-enhanced", seed=args.seed, bf16=False, fp16=True, gradient_accumulation_steps=args.grad_accum, dataloader_num_workers=2, warmup_ratio=0.1, ) trainer = Trainer( model=model, args=training_args, train_dataset=train_ds, eval_dataset=eval_ds, processing_class=tokenizer, data_collator=data_collator, compute_metrics=compute_metrics, callbacks=[EarlyStoppingCallback(early_stopping_patience=2)], ) print("\n=== Starting training ===") trainer.train() print("\n=== Pushing to hub ===") trainer.push_to_hub(commit_message="Sidecar NER: fax + cc_last4 + contact_block") print("\nDone!") if __name__ == "__main__": main()