privacy-filter-enhanced / train_sidecar.py
narcolepticchicken's picture
Switch to bert-base-uncased, larger batch, fp16
0899f8d verified
#!/usr/bin/env python3
"""Train a BERT sidecar NER model for 3 new PII categories."""
import json, random, argparse, ast, sys
import numpy as np
import torch
from datasets import load_dataset, Dataset
from transformers import (
AutoTokenizer, AutoModelForTokenClassification,
TrainingArguments, Trainer, DataCollatorForTokenClassification,
EarlyStoppingCallback
)
import evaluate
# The three PII categories the sidecar model tags, and the BIO tag set built
# from them: "O" first (id 0), then a B-/I- pair per category in order.
CATEGORIES = ["fax_number", "credit_card_last4", "company_contact_block"]
LABELS = ["O"] + [f"{tag}-{cat}" for cat in CATEGORIES for tag in ("B", "I")]
label2id = {name: idx for idx, name in enumerate(LABELS)}
id2label = {idx: name for name, idx in label2id.items()}
NUM_LABELS = len(LABELS)
# seqeval metric, loaded once at import time; compute_metrics reads its
# overall_precision / overall_recall / overall_f1 / overall_accuracy scores.
seqeval = evaluate.load("seqeval")
def compute_metrics(p):
    """Score token-classification predictions with seqeval.

    Args:
        p: the ``(predictions, label_ids)`` pair handed over by the Trainer.
            An argmax over axis 2 selects the label id per token, so this
            presumably expects ``(batch, seq, num_labels)`` scores - confirm.

    Returns:
        dict with seqeval's overall ``precision``, ``recall``, ``f1`` and
        ``accuracy``. Positions whose gold id is -100 are dropped from both the
        predicted and the reference tag sequences before scoring.
    """
    logits, gold_ids = p
    pred_ids = np.argmax(logits, axis=2)

    pred_tags = []
    gold_tags = []
    for pred_seq, gold_seq in zip(pred_ids, gold_ids):
        # Keep only positions that carry a real gold label.
        kept = [(pr, gd) for pr, gd in zip(pred_seq, gold_seq) if gd != -100]
        pred_tags.append([id2label[pr] for pr, _ in kept])
        gold_tags.append([id2label[gd] for _, gd in kept])

    scores = seqeval.compute(predictions=pred_tags, references=gold_tags)
    return {
        key: scores[f"overall_{key}"]
        for key in ("precision", "recall", "f1", "accuracy")
    }
from faker import Faker
# Shared Faker instance used to fill the synthetic templates;
# generate_synthetic_examples calls fake.seed_instance(seed) before use so its
# output is reproducible for a given seed.
fake = Faker()
def generate_synthetic_examples(n=5000, seed=42, faker_instance=None):
    """Build *n* templated examples, each carrying exactly one labelled span.

    Categories are drawn with probabilities ~0.33 / ~0.33 / ~0.34 for
    ``fax_number``, ``credit_card_last4`` and ``company_contact_block``.
    Every text is assembled as ``prefix + value + suffix``, so the span
    offsets are exact by construction. Locating the value with ``str.find``
    picks the wrong occurrence when the value also appears earlier in the
    text (e.g. a fax equal to the hard-coded phone number).

    Args:
        n: number of examples to generate.
        seed: seeds a private ``random.Random`` and the Faker instance, so the
            same ``(n, seed)`` always yields the same examples.
        faker_instance: Faker-compatible object providing ``seed_instance``,
            ``numerify``, ``company``, ``street_address``, ``city``,
            ``state_abbr``, ``zipcode`` and ``company_email``. Defaults to the
            module-level ``fake``.

    Returns:
        list of ``{"text": str, "spans": [(start, end, label)]}`` dicts
        (``end`` exclusive).
    """
    # Private RNG: yields the same stream random.seed(seed) would, without
    # resetting the process-wide random state as a side effect.
    rng = random.Random(seed)
    gen = fake if faker_instance is None else faker_instance
    gen.seed_instance(seed)
    examples = []

    def add(prefix, value, suffix, label):
        # Span starts right after the prefix and covers exactly the value.
        start = len(prefix)
        examples.append({
            "text": prefix + value + suffix,
            "spans": [(start, start + len(value), label)],
        })

    for _ in range(n):
        r = rng.random()
        if r < 0.33:
            fax = gen.numerify(text="(###) ###-####")
            prefix, suffix = rng.choice([
                ("Please fax documents to ", "."),
                ("Fax: ", "\nAttn: Legal"),
                ("Secure fax line: ", ""),
                ("You can reach us at phone (555) 123-4567 or fax ", "."),
            ])
            add(prefix, fax, suffix, "fax_number")
        elif r < 0.66:
            last4 = gen.numerify(text="####")
            prefix, suffix = rng.choice([
                ("Card ending in ", " charged."),
                ("Visa ****-****-****-", ""),
                ("Last 4 digits: ", ""),
                ("Card on file ...", ""),
            ])
            add(prefix, last4, suffix, "credit_card_last4")
        else:
            company = gen.company()
            addr = (
                gen.street_address() + ", " + gen.city() + ", "
                + gen.state_abbr() + " " + gen.zipcode()
            )
            phone = gen.numerify(text="(###) ###-####")
            email = gen.company_email()
            # The span runs from the company name through the e-mail, and the
            # e-mail always ends the text, so only the prefix varies.
            prefix, block = rng.choice([
                ("", f"{company}\n{addr}\nPhone: {phone}\nEmail: {email}"),
                ("Contact:\n", f"{company}\n{addr}\nTel: {phone}\n{email}"),
                ("", f"{company} HQ\n{addr}\nMain: {phone}\nInquiries: {email}"),
            ])
            add(prefix, block, "", "company_contact_block")
    return examples
# Nemotron-PII span label -> sidecar category. load_nemotron_split drops spans
# whose label is not a key here, and drops rows left with no spans.
NEMOTRON_MAP = dict(company_name="company_contact_block")
def load_nemotron_split(split, max_examples=5000):
    """Collect up to *max_examples* Nemotron-PII rows that contain a mapped span.

    Each row's ``spans`` field is used as-is when already a list. When it is a
    string, it is parsed as JSON first and, failing that, as a Python literal.
    Only spans whose ``label`` has a truthy entry in ``NEMOTRON_MAP`` are kept,
    relabelled to the sidecar category. Rows with no surviving span are skipped.

    Args:
        split: dataset split name passed to ``load_dataset``.
        max_examples: stop once this many qualifying rows have been collected.

    Returns:
        list of ``{"text": str, "spans": [(start, end, label)]}`` dicts.
    """
    rows = load_dataset("nvidia/Nemotron-PII", split=split)
    kept = []
    for row in rows:
        if len(kept) >= max_examples:
            break
        row_text = row["text"]
        raw = row["spans"]
        if isinstance(raw, str):
            try:
                raw = json.loads(raw)
            except json.JSONDecodeError:
                # literal_eval only evaluates Python literals, not arbitrary code.
                raw = ast.literal_eval(raw)
        mapped = [
            (item["start"], item["end"], NEMOTRON_MAP[item["label"]])
            for item in raw
            if NEMOTRON_MAP.get(item["label"])
        ]
        if mapped:
            kept.append({"text": row_text, "spans": mapped})
    return kept
def tokenize_and_align(examples, tokenizer, label_map=None):
    """Tokenize examples and project their character spans onto BIO token labels.

    Args:
        examples: list of dicts with a ``"text"`` string and a ``"spans"`` list
            of ``(start, end, label)`` character offsets (``end`` exclusive).
        tokenizer: a fast tokenizer (``return_offsets_mapping`` is requested).
        label_map: optional BIO-tag -> id mapping; defaults to the module-level
            ``label2id``.

    Returns:
        The tokenizer's encoding with ``offset_mapping`` removed and a
        ``labels`` list added. Within each span, the first overlapping token gets
        ``B-<label>`` and the rest ``I-<label>``. Special tokens, detected by a
        zero-length or missing offset, get ``-100`` so the loss and metrics
        ignore them. Tags missing from the map fall back to id 0.
    """
    if label_map is None:
        label_map = label2id
    enc = tokenizer(
        [ex["text"] for ex in examples],
        truncation=True,
        max_length=512,
        padding=False,
        return_offsets_mapping=True,
    )
    all_labels = []
    for i, ex in enumerate(examples):
        offsets = enc["offset_mapping"][i]
        # Fast tokenizers report [CLS]/[SEP] as (0, 0), not None, so a
        # zero-length offset is the marker for "not part of the text".
        special = [
            off is None or off[0] is None or off[1] is None or off[0] == off[1]
            for off in offsets
        ]
        tags = ["O"] * len(offsets)
        for start, end, lab in ex["spans"]:
            covered = [
                j for j, off in enumerate(offsets)
                if not special[j] and off[0] < end and off[1] > start
            ]
            if not covered:
                # e.g. the span lies beyond the 512-token truncation point.
                continue
            tags[covered[0]] = f"B-{lab}"
            for j in covered[1:]:
                tags[j] = f"I-{lab}"
        all_labels.append([
            -100 if special[j] else label_map.get(tags[j], 0)
            for j in range(len(offsets))
        ])
    enc["labels"] = all_labels
    enc.pop("offset_mapping")
    return enc
def _parse_args():
    """Parse the command-line options for one sidecar training run."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_model", default="bert-base-uncased")
    parser.add_argument("--output_model", default="narcolepticchicken/privacy-filter-sidecar-bert")
    parser.add_argument("--output_dir", default="/app/sidecar-checkpoints",
                        help="Local directory for Trainer checkpoints.")
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--grad_accum", type=int, default=1)
    parser.add_argument("--lr", type=float, default=5e-5)
    parser.add_argument("--max_synthetic", type=int, default=5000)
    parser.add_argument("--max_nemotron_train", type=int, default=5000)
    parser.add_argument("--max_nemotron_eval", type=int, default=1000)
    parser.add_argument("--seed", type=int, default=42)
    return parser.parse_args()


def _check_alignment(tokenizer, seed):
    """Tokenize one synthetic example; exit(1) unless it has a non-O token label."""
    print("\n=== Sanity check: tokenizing one example ===")
    sample = tokenize_and_align(generate_synthetic_examples(1, seed), tokenizer)
    sample_labels = sample["labels"][0]
    non_o = sum(1 for lid in sample_labels if lid != -100 and lid != 0)
    special = sum(1 for lid in sample_labels if lid == -100)
    print(f" Tokens: {len(sample_labels)}, Special (-100): {special}, Non-O labels: {non_o}")
    if non_o == 0:
        print(" ERROR: No non-O labels found! Exiting.")
        sys.exit(1)
    print(" OK - labels are aligned.\n")


def _report_label_distribution(label_rows):
    """Print per-category entity-token counts; exit(1) if every real label is O."""
    print("\n=== Label distribution check ===")
    flat = [lid for row in label_rows for lid in row if lid != -100]
    for cat in CATEGORIES:
        ids = (label2id[f"B-{cat}"], label2id[f"I-{cat}"])
        count = sum(1 for lid in flat if lid in ids)
        print(f" {cat}: {count} tokens")
    if not any(lid != 0 for lid in flat):
        print(" ERROR: All labels are O! Exiting.")
        sys.exit(1)


def main():
    """Train the sidecar token classifier and push it to the Hugging Face Hub.

    Steps: seed RNGs; load tokenizer and model; run an alignment sanity check;
    build training data (synthetic + Nemotron-PII "train") and evaluation data
    (Nemotron-PII "test"); train with per-epoch evaluation, best-F1 checkpoint
    selection and early stopping (patience 2); push to ``--output_model``.
    Exits with status 1 if the sanity check fails, a split comes back empty,
    or every training label is O.
    """
    args = _parse_args()
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    print(f"Loading tokenizer: {args.base_model}")
    tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
    print(f"Loading model: {args.base_model}")
    model = AutoModelForTokenClassification.from_pretrained(
        args.base_model,
        num_labels=NUM_LABELS,
        id2label=id2label,
        label2id=label2id,
    )

    _check_alignment(tokenizer, args.seed)

    print("Generating synthetic data...")
    synth = generate_synthetic_examples(args.max_synthetic, args.seed)
    print(f" Synthetic: {len(synth)}")
    print("Loading Nemotron-PII (filtered to company_name only)...")
    nemotron_train = load_nemotron_split("train", args.max_nemotron_train)
    nemotron_eval = load_nemotron_split("test", args.max_nemotron_eval)
    print(f" Nemotron train: {len(nemotron_train)}, eval: {len(nemotron_eval)}")
    train_examples = synth + nemotron_train
    eval_examples = nemotron_eval
    # Fail fast with a clear message rather than letting tokenization or the
    # per-epoch evaluation trip over an empty split mid-run.
    if not train_examples or not eval_examples:
        print(" ERROR: Empty train or eval set! Exiting.")
        sys.exit(1)

    print("Tokenizing train...")
    train_tok = tokenize_and_align(train_examples, tokenizer)
    print("Tokenizing eval...")
    eval_tok = tokenize_and_align(eval_examples, tokenizer)
    train_ds = Dataset.from_dict(train_tok)
    eval_ds = Dataset.from_dict(eval_tok)

    _report_label_distribution(train_tok["labels"])

    # fp16 mixed precision is not supported on a CPU-only host
    # (TrainingArguments rejects it there), so only enable it with CUDA.
    use_fp16 = torch.cuda.is_available()
    if not use_fp16:
        print(" NOTE: CUDA not available; training without fp16.")

    data_collator = DataCollatorForTokenClassification(tokenizer)
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        learning_rate=args.lr,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        num_train_epochs=args.epochs,
        weight_decay=0.01,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        logging_strategy="steps",
        logging_steps=10,
        logging_first_step=True,
        disable_tqdm=True,
        push_to_hub=True,
        hub_model_id=args.output_model,
        report_to="trackio",
        run_name=f"sidecar-bert-lr{args.lr}-bs{args.batch_size}",
        project="privacy-filter-enhanced",
        seed=args.seed,
        bf16=False,
        fp16=use_fp16,
        gradient_accumulation_steps=args.grad_accum,
        dataloader_num_workers=2,
        warmup_ratio=0.1,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        processing_class=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
    )
    print("\n=== Starting training ===")
    trainer.train()
    print("\n=== Pushing to hub ===")
    trainer.push_to_hub(commit_message="Sidecar NER: fax + cc_last4 + contact_block")
    print("\nDone!")
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()