# privacy-filter-enhanced / ensemble_inference.py
# Committed by narcolepticchicken: "Add ensemble inference script" (89c6a38, verified)
#!/usr/bin/env python3
"""Ensemble inference: openai/privacy-filter + sidecar BERT for gap categories."""
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch
# --- Model loading (executes at import time) --------------------------------
# NOTE(review): trust_remote_code=True runs Python code downloaded from the
# openai/privacy-filter repo; consider pinning `revision=` to a reviewed commit.
print("Loading openai/privacy-filter...")
tok_pf = AutoTokenizer.from_pretrained("openai/privacy-filter", trust_remote_code=True)
model_pf = AutoModelForTokenClassification.from_pretrained(
    "openai/privacy-filter", trust_remote_code=True
).eval()  # Module.eval() returns the module itself, switched to inference mode

# Sidecar BERT covering the "gap" categories the privacy-filter does not tag.
print("Loading sidecar BERT...")
tok_sc = AutoTokenizer.from_pretrained("narcolepticchicken/privacy-filter-sidecar-bert")
model_sc = AutoModelForTokenClassification.from_pretrained(
    "narcolepticchicken/privacy-filter-sidecar-bert"
).eval()
# Output labels of the original privacy-filter, indexed by class id.
# BIOES scheme: "O", then B-/I-/E-/S- tags for each of the 8 categories, in
# this exact order (the order must match the model's class indices).
PF_LABELS = ["O"] + [
    f"{prefix}-{category}"
    for category in (
        "account_number",
        "private_address",
        "private_date",
        "private_email",
        "private_person",
        "private_phone",
        "private_url",
        "secret",
    )
    for prefix in "BIES"
]
# Output labels of the sidecar BERT, indexed by class id.
# BIO scheme (no E-/S- tags): "O", then B-/I- for each gap category, in order.
SC_LABELS = ["O"] + [
    f"{prefix}-{category}"
    for category in ("fax_number", "credit_card_last4", "company_contact_block")
    for prefix in "BI"
]
def run_pf(text, max_length=512):
    """Tag ``text`` with the original openai/privacy-filter model.

    Args:
        text: Raw input string.
        max_length: Maximum number of tokens fed to the model. Input beyond
            this limit is truncated and is therefore NOT scanned at all.

    Returns:
        A list of ``(token, label)`` pairs, one per token id in the model
        input (including any special tokens the tokenizer adds). Each label
        comes from ``PF_LABELS``.

    Raises:
        ValueError: If the model's number of output classes differs from
            ``len(PF_LABELS)``, meaning the hard-coded label list cannot be
            mapped onto the model's predictions.
    """
    inputs = tok_pf(text, return_tensors="pt", truncation=True, max_length=max_length)
    with torch.no_grad():
        logits = model_pf(**inputs).logits
    # PF_LABELS is hard-coded rather than read from the model config. Verify
    # that it lines up with the output layer. Otherwise predictions would be
    # silently mislabeled (fewer classes) or fail with a bare IndexError (more).
    num_classes = logits.shape[-1]
    if num_classes != len(PF_LABELS):
        raise ValueError(
            f"openai/privacy-filter predicts {num_classes} classes but PF_LABELS "
            f"has {len(PF_LABELS)} entries; compare with model_pf.config.id2label"
        )
    # Batch size is 1, so take the first (only) sequence.
    preds = torch.argmax(logits, dim=-1)[0].tolist()
    tokens = tok_pf.convert_ids_to_tokens(inputs["input_ids"][0])
    return [(tok, PF_LABELS[p]) for tok, p in zip(tokens, preds)]
def run_sc(text, max_length=512):
    """Tag ``text`` with the sidecar BERT for the gap categories.

    The sidecar covers fax numbers, credit-card last-4 digits and company
    contact blocks (see ``SC_LABELS``).

    Args:
        text: Raw input string.
        max_length: Maximum number of tokens fed to the model. Input beyond
            this limit is truncated and is therefore NOT scanned at all.

    Returns:
        A list of ``(token, label)`` pairs, one per token id in the model
        input (including any special tokens the tokenizer adds). Each label
        comes from ``SC_LABELS``.

    Raises:
        ValueError: If the model's number of output classes differs from
            ``len(SC_LABELS)``.
    """
    inputs = tok_sc(text, return_tensors="pt", truncation=True, max_length=max_length)
    with torch.no_grad():
        logits = model_sc(**inputs).logits
    # SC_LABELS is hard-coded; fail loudly if it does not match the output
    # layer, instead of mislabeling tokens or raising a bare IndexError.
    num_classes = logits.shape[-1]
    if num_classes != len(SC_LABELS):
        raise ValueError(
            f"sidecar model predicts {num_classes} classes but SC_LABELS "
            f"has {len(SC_LABELS)} entries; compare with model_sc.config.id2label"
        )
    # Batch size is 1, so take the first (only) sequence.
    preds = torch.argmax(logits, dim=-1)[0].tolist()
    tokens = tok_sc.convert_ids_to_tokens(inputs["input_ids"][0])
    return [(tok, SC_LABELS[p]) for tok, p in zip(tokens, preds)]
def ensemble(text):
    """Run both models on ``text`` and return their predictions side by side.

    No merging or conflict resolution is performed. The privacy-filter output
    (8 original categories) and the sidecar output (fax number, credit-card
    last-4, company contact block) are returned as two independent lists.
    Each list is produced by its own model's tokenizer, so the two token
    sequences are not necessarily aligned with each other.

    Returns:
        ``{"privacy_filter": [(token, label), ...],
        "sidecar": [(token, label), ...]}``
    """
    pf_results = run_pf(text)
    sc_results = run_sc(text)
    return {"privacy_filter": pf_results, "sidecar": sc_results}
if __name__ == "__main__":
    # Demo inputs exercising both the original categories and the gap categories.
    demo_texts = [
        "Please fax the contract to (555) 123-4567.",
        "Your card ending in 4242 has been charged.",
        "Acme Corp\n123 Main St, Springfield, IL 62701\nPhone: (555) 987-6543\nEmail: legal@acme.com",
        "Contact John Doe at john@example.com or call (555) 111-2222. Fax: (555) 333-4444.",
    ]
    # (section header, key in the ensemble() result), in display order.
    report_sections = (
        ("\n--- Privacy Filter (original 8 categories) ---", "privacy_filter"),
        ("\n--- Sidecar (fax + CC-last4 + contact block) ---", "sidecar"),
    )
    for sample in demo_texts:
        print("\n" + "=" * 60)
        print(f"TEXT: {sample}")
        combined = ensemble(sample)
        for header, key in report_sections:
            print(header)
            # Show only tokens that received an entity tag.
            for token, tag in combined[key]:
                if tag == "O":
                    continue
                print(f" {token}: {tag}")