File size: 9,315 Bytes
#!/usr/bin/env python3
import json, random, argparse, ast
import numpy as np, torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer, DataCollatorForTokenClassification
import evaluate
# PII categories the token classifier is trained to tag.
CATEGORIES = [
    "account_number",
    "private_address",
    "private_date",
    "private_email",
    "private_person",
    "private_phone",
    "private_url",
    "secret",
    "fax_number",
    "credit_card_last4",
    "company_contact_block",
]

# Tag vocabulary: "O" first, then B-/I-/E-/S- tags for each category,
# grouped per category in CATEGORIES order.
LABELS = ["O"] + [
    f"{prefix}-{category}"
    for category in CATEGORIES
    for prefix in ("B", "I", "E", "S")
]

# Lookup tables between tag strings and integer class ids.
id2label = dict(enumerate(LABELS))
label2id = {label: idx for idx, label in id2label.items()}
NUM_LABELS = len(LABELS)
# Load the "seqeval" metric once at import time; compute_metrics() calls its compute().
seqeval = evaluate.load("seqeval")
def compute_metrics(p):
    """Score token-classification predictions with the module-level seqeval metric.

    ``p`` unpacks into ``(predictions, labels)``. Predicted class ids are the
    argmax of ``predictions`` over axis 2. Positions whose gold label is -100
    are dropped from both sequences before ids are mapped to tag strings
    through ``id2label``.

    Returns a dict with the overall ``precision``, ``recall``, ``f1`` and
    ``accuracy`` reported by seqeval.
    """
    scores_per_class, gold_ids = p
    pred_ids = np.argmax(scores_per_class, axis=2)

    pred_tags = []
    gold_tags = []
    for pred_row, gold_row in zip(pred_ids, gold_ids):
        # Keep only the positions that carry a real label.
        kept = [(pred, gold) for pred, gold in zip(pred_row, gold_row) if gold != -100]
        pred_tags.append([id2label[pred] for pred, _ in kept])
        gold_tags.append([id2label[gold] for _, gold in kept])

    scores = seqeval.compute(predictions=pred_tags, references=gold_tags)
    return {
        "precision": scores["overall_precision"],
        "recall": scores["overall_recall"],
        "f1": scores["overall_f1"],
        "accuracy": scores["overall_accuracy"],
    }
from faker import Faker
# Shared Faker instance; generate_synthetic_examples() re-seeds it via seed_instance().
fake = Faker()
def generate_synthetic_examples(n=5000, seed=42):
    """Generate ``n`` synthetic texts with character-level PII spans.

    Each example is drawn from one of four equally likely shapes (chosen by a
    uniform random number): a fax-number sentence, a credit-card-last-4
    sentence, a company contact block, or a mixed record containing a person,
    email, fax, phone, card digits, company and address.

    Args:
        n: Number of examples to generate.
        seed: Seed applied to both the global ``random`` module and the
            module-level ``fake`` instance, so runs are reproducible.

    Returns:
        A list of ``{"text": str, "spans": [(start, end, label), ...]}``
        dicts, where ``start``/``end`` are character offsets into ``text``
        (``end`` exclusive).
    """
    random.seed(seed)
    fake.seed_instance(seed)
    examples = []

    def add(text, spans):
        examples.append({"text": text, "spans": spans})

    for _ in range(n):
        r = random.random()
        if r < 0.25:
            fax = fake.numerify(text="(###) ###-####")
            tmpl = random.choice([
                f"Please fax documents to {fax}.",
                f"Fax: {fax}\nAttn: Legal",
                f"Secure fax line: {fax}",
                f"You can reach us at phone (555) 123-4567 or fax {fax}.",
                f"Facsimile: {fax}",
            ])
            # The fax number is the last number in every template. Searching
            # from the right keeps the fixed "(555) 123-4567" phone from being
            # labelled if the generated fax happens to equal it.
            s = tmpl.rfind(fax)
            add(tmpl, [(s, s + len(fax), "fax_number")])
        elif r < 0.5:
            last4 = fake.numerify(text="####")
            tmpl = random.choice([
                f"Card ending in {last4} charged.",
                f"Visa ****-****-****-{last4}",
                f"Last 4 digits: {last4}",
                f"Card on file ...{last4}",
                f"XXXX-XXXX-XXXX-{last4}",
            ])
            s = tmpl.find(last4)
            add(tmpl, [(s, s + len(last4), "credit_card_last4")])
        elif r < 0.75:
            company = fake.company()
            addr = fake.street_address() + ", " + fake.city() + ", " + fake.state_abbr() + " " + fake.zipcode()
            phone = fake.numerify(text="(###) ###-####")
            email = fake.company_email()
            tmpl = random.choice([
                f"{company}\n{addr}\nPhone: {phone}\nEmail: {email}",
                f"Contact:\n{company}\n{addr}\nTel: {phone}\n{email}",
                f"{company} HQ\n{addr}\nMain: {phone}\nInquiries: {email}",
            ])
            # One span covering everything from the company name to the end
            # of the email address.
            s = tmpl.find(company)
            e = tmpl.find(email) + len(email)
            add(tmpl, [(s, e, "company_contact_block")])
        else:
            person = fake.name()
            email = fake.email()
            fax = fake.numerify(text="(###) ###-####")
            phone = fake.numerify(text="(###) ###-####")
            last4 = fake.numerify(text="####")
            company = fake.company()
            addr = fake.street_address() + ", " + fake.city() + ", " + fake.state_abbr() + " " + fake.zipcode()
            tmpl = random.choice([
                f"From: {person} <{email}>\nTo: Legal\nFax: {fax}\nPhone: {phone}\nCard: {last4}\n{company}\n{addr}",
                f"Client: {person}\nEmail: {email}\nFax: {fax}\nTel: {phone}\nPayment: ****{last4}\nEmployer: {company}\n{addr}",
            ])
            # The fields below are listed in the order they appear in both
            # templates. Searching forward from the end of the previous field
            # prevents a value from matching an earlier coincidental
            # occurrence (e.g. the 4 card digits inside the fax/phone number,
            # or a phone number identical to the fax number).
            fields = [
                (person, "private_person"),
                (email, "private_email"),
                (fax, "fax_number"),
                (phone, "private_phone"),
                (last4, "credit_card_last4"),
                (company, "company_contact_block"),
                (addr, "private_address"),
            ]
            spans = []
            cursor = 0
            for sub, lab in fields:
                idx = tmpl.find(sub, cursor)
                if idx >= 0:
                    spans.append((idx, idx + len(sub), lab))
                    cursor = idx + len(sub)
            add(tmpl, spans)
    return examples
# Maps span labels found in the Nemotron-PII records onto this project's
# categories; labels missing from this table are skipped by the loader.
NEMOTRON_MAP = {
    # people
    "first_name": "private_person",
    "last_name": "private_person",
    "full_name": "private_person",
    "name": "private_person",
    # contact details
    "email": "private_email",
    "phone_number": "private_phone",
    "street_address": "private_address",
    "address": "private_address",
    # dates
    "date_of_birth": "private_date",
    "date": "private_date",
    # account-like identifiers
    "credit_card_number": "account_number",
    "ssn": "account_number",
    # organisations and links
    "company_name": "company_contact_block",
    "url": "private_url",
    # credentials
    "secret": "secret",
    "api_key": "secret",
    "password": "secret",
    "token": "secret",
}
def load_nemotron_split(split, max_examples=10000):
    """Load up to ``max_examples`` labelled examples from ``nvidia/Nemotron-PII``.

    Each record's ``"spans"`` field is used as-is when it is not a string;
    a string is parsed as JSON first and, if that fails, as a Python literal.
    Span labels are translated through ``NEMOTRON_MAP``; spans with an
    unmapped label are dropped, and records left with no spans are skipped.

    Returns a list of ``{"text": str, "spans": [(start, end, label), ...]}``.
    """
    dataset = load_dataset("nvidia/Nemotron-PII", split=split)
    collected = []
    for record in dataset:
        if len(collected) >= max_examples:
            break

        text = record["text"]
        raw_spans = record["spans"]
        if isinstance(raw_spans, str):
            try:
                raw_spans = json.loads(raw_spans)
            except json.JSONDecodeError:
                # Not valid JSON; fall back to Python-literal parsing.
                raw_spans = ast.literal_eval(raw_spans)

        mapped = []
        for span in raw_spans:
            target = NEMOTRON_MAP.get(span["label"])
            if not target:
                continue
            mapped.append((span["start"], span["end"], target))

        if not mapped:
            continue
        collected.append({"text": text, "spans": mapped})
    return collected
def tokenize_and_align(examples, tokenizer):
    """Tokenize example texts and attach per-token label ids.

    Args:
        examples: List of ``{"text": str, "spans": [(start, end, label), ...]}``
            with character-offset spans (``end`` exclusive).
        tokenizer: Callable tokenizer that supports ``return_offsets_mapping``
            and ``return_special_tokens_mask`` (Hugging Face fast tokenizers).

    Returns:
        The tokenizer output with an added ``"labels"`` entry (one list of
        label ids per example) and with the ``"offset_mapping"`` and
        ``"special_tokens_mask"`` helper entries removed.

    Every token overlapping a span is tagged: a single covered token gets
    ``S-``, otherwise the first gets ``B-``, the last ``E-`` and those in
    between ``I-``. Texts are truncated to 512 tokens, so spans that fall
    entirely past the cut are silently left unlabelled. When spans overlap,
    the later span in the list overwrites the earlier one's tags.
    """
    tokenized = tokenizer(
        [ex["text"] for ex in examples],
        truncation=True,
        max_length=512,
        return_offsets_mapping=True,
        # Special tokens report offsets (0, 0), not None, so the offsets
        # alone cannot identify them; ask the tokenizer for an explicit mask.
        return_special_tokens_mask=True,
    )
    all_labels = []
    for i, ex in enumerate(examples):
        offsets = tokenized["offset_mapping"][i]
        special = tokenized["special_tokens_mask"][i]
        labels = ["O"] * len(offsets)
        for start, end, lab in ex["spans"]:
            # Tokens whose character range overlaps [start, end).
            covered = [
                j
                for j, (ts, te) in enumerate(offsets)
                if not special[j]
                and ts is not None
                and te is not None
                and ts < end
                and te > start
            ]
            if not covered:
                continue
            if len(covered) == 1:
                labels[covered[0]] = f"S-{lab}"
            else:
                labels[covered[0]] = f"B-{lab}"
                for idx in covered[1:-1]:
                    labels[idx] = f"I-{lab}"
                labels[covered[-1]] = f"E-{lab}"
        # -100 on special tokens keeps them out of the loss and the metrics.
        label_ids = [
            -100
            if special[j] or (ts is None and te is None)
            else label2id.get(labels[j], 0)
            for j, (ts, te) in enumerate(offsets)
        ]
        all_labels.append(label_ids)
    tokenized["labels"] = all_labels
    # Drop the helper outputs so the result has the same keys as before.
    tokenized.pop("offset_mapping")
    tokenized.pop("special_tokens_mask")
    return tokenized
def main():
    """Fine-tune a token-classification model for PII tagging and push it to the Hub.

    Steps, in order: parse CLI arguments, seed ``random``/NumPy/PyTorch, load
    the tokenizer and model (with the classification head resized to
    ``NUM_LABELS``), build the training set from synthetic examples plus the
    Nemotron-PII ``train`` split, evaluate on the Nemotron-PII ``test`` split,
    train with ``Trainer`` and finally push the result to ``--output_model``.

    Checkpoints are written to ``--output_dir`` (defaults to the previously
    hard-coded ``/app/privacy-filter-checkpoints``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="openai/privacy-filter")
    parser.add_argument("--output_model", default="narcolepticchicken/privacy-filter-enhanced")
    parser.add_argument(
        "--output_dir",
        default="/app/privacy-filter-checkpoints",
        help="Directory where Trainer writes checkpoints.",
    )
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--grad_accum", type=int, default=4)
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--max_synthetic", type=int, default=5000)
    parser.add_argument("--max_nemotron_train", type=int, default=15000)
    parser.add_argument("--max_nemotron_eval", type=int, default=2000)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    # ignore_mismatched_sizes lets the checkpoint load even though its
    # classification head does not match this script's label set.
    model = AutoModelForTokenClassification.from_pretrained(
        args.model,
        num_labels=NUM_LABELS,
        id2label=id2label,
        label2id=label2id,
        trust_remote_code=True,
        ignore_mismatched_sizes=True,
    )

    print("Generating synthetic data...")
    synth = generate_synthetic_examples(args.max_synthetic, args.seed)
    print(f"Synthetic examples: {len(synth)}")

    print("Loading Nemotron-PII...")
    nemotron_train = load_nemotron_split("train", args.max_nemotron_train)
    nemotron_eval = load_nemotron_split("test", args.max_nemotron_eval)
    print(f"Nemotron train: {len(nemotron_train)}, eval: {len(nemotron_eval)}")

    # Synthetic data is used for training only; evaluation is Nemotron-only.
    train_examples = synth + nemotron_train
    eval_examples = nemotron_eval

    print("Tokenizing...")
    train_tok = tokenize_and_align(train_examples, tokenizer)
    eval_tok = tokenize_and_align(eval_examples, tokenizer)
    train_ds = Dataset.from_dict(train_tok)
    eval_ds = Dataset.from_dict(eval_tok)
    data_collator = DataCollatorForTokenClassification(tokenizer)

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        learning_rate=args.lr,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        num_train_epochs=args.epochs,
        weight_decay=0.01,
        # Evaluate and save once per epoch so the best-F1 checkpoint can be
        # restored at the end of training.
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        logging_strategy="steps",
        logging_steps=50,
        logging_first_step=True,
        disable_tqdm=True,
        push_to_hub=True,
        hub_model_id=args.output_model,
        report_to="trackio",
        run_name=f"privacy-filter-enhanced-lr{args.lr}-bs{args.batch_size}-ep{args.epochs}",
        project="privacy-filter-enhanced",
        seed=args.seed,
        bf16=True,
        gradient_accumulation_steps=args.grad_accum,
        dataloader_num_workers=4,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        processing_class=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    print("Starting training...")
    trainer.train()
    print("Pushing to hub...")
    trainer.push_to_hub()
    print("Done!")
# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|