| |
| """Ensemble inference: openai/privacy-filter + sidecar BERT for gap categories.""" |
| from transformers import AutoTokenizer, AutoModelForTokenClassification |
| import torch |
|
|
| |
# --- Load the primary model: openai/privacy-filter -------------------------
# Covers the original 8 PII categories (see PF_LABELS below).
# NOTE(review): trust_remote_code=True executes model-author code from the
# hub on load — confirm the repo is pinned/audited before production use.
print("Loading openai/privacy-filter...")
tok_pf = AutoTokenizer.from_pretrained("openai/privacy-filter", trust_remote_code=True)
model_pf = AutoModelForTokenClassification.from_pretrained("openai/privacy-filter", trust_remote_code=True)
model_pf.eval()  # inference only — disable dropout etc.


# --- Load the sidecar model: fine-tuned BERT for the 3 gap categories ------
# (fax numbers, credit-card last-4, company contact blocks; see SC_LABELS.)
print("Loading sidecar BERT...")
tok_sc = AutoTokenizer.from_pretrained("narcolepticchicken/privacy-filter-sidecar-bert")
model_sc = AutoModelForTokenClassification.from_pretrained("narcolepticchicken/privacy-filter-sidecar-bert")
model_sc.eval()  # inference only
|
|
| |
# Label inventory for openai/privacy-filter: "O" plus a BIOES tag set
# (Begin / Inside / End / Single) for each of the 8 PII categories.
# The list order must match the model head's class-index order — it is
# indexed directly by argmax id in run_pf().
PF_LABELS = [
    "O", "B-account_number", "I-account_number", "E-account_number", "S-account_number",
    "B-private_address", "I-private_address", "E-private_address", "S-private_address",
    "B-private_date", "I-private_date", "E-private_date", "S-private_date",
    "B-private_email", "I-private_email", "E-private_email", "S-private_email",
    "B-private_person", "I-private_person", "E-private_person", "S-private_person",
    "B-private_phone", "I-private_phone", "E-private_phone", "S-private_phone",
    "B-private_url", "I-private_url", "E-private_url", "S-private_url",
    "B-secret", "I-secret", "E-secret", "S-secret",
]
|
|
| |
# Label inventory for the sidecar model: plain BIO tags (no E-/S- variants,
# unlike PF_LABELS) for the 3 gap categories. Order must match the sidecar
# head's class-index order — indexed directly by argmax id in run_sc().
SC_LABELS = ["O", "B-fax_number", "I-fax_number", "B-credit_card_last4", "I-credit_card_last4",
             "B-company_contact_block", "I-company_contact_block"]
|
|
def run_pf(text):
    """Tag *text* with the original privacy-filter model.

    Returns a list of (token, label) pairs — one per wordpiece, special
    tokens included — where label is drawn from PF_LABELS.
    """
    encoded = tok_pf(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        scores = model_pf(**encoded).logits
    # Highest-scoring class id per token position (batch of one).
    label_ids = scores.argmax(dim=-1)[0].cpu().numpy()
    token_strs = tok_pf.convert_ids_to_tokens(encoded["input_ids"][0])
    return list(zip(token_strs, (PF_LABELS[i] for i in label_ids)))
|
|
def run_sc(text):
    """Tag *text* with the sidecar BERT covering the gap categories.

    Returns a list of (token, label) pairs — one per wordpiece, special
    tokens included — where label is drawn from SC_LABELS.
    """
    encoded = tok_sc(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        scores = model_sc(**encoded).logits
    # Highest-scoring class id per token position (batch of one).
    label_ids = scores.argmax(dim=-1)[0].cpu().numpy()
    token_strs = tok_sc.convert_ids_to_tokens(encoded["input_ids"][0])
    return list(zip(token_strs, (SC_LABELS[i] for i in label_ids)))
|
|
def ensemble(text):
    """Run both models over *text* and return their per-token tags.

    Returns a dict with key "privacy_filter" (original 8 categories) and
    key "sidecar" (the 3 gap categories), each a list of (token, label).
    """
    return {
        "privacy_filter": run_pf(text),
        "sidecar": run_sc(text),
    }
|
|
|
|
if __name__ == "__main__":
    # Smoke-test inputs: each exercises at least one sidecar gap category
    # (fax number, card last-4, contact block) alongside the original PII.
    test_cases = [
        "Please fax the contract to (555) 123-4567.",
        "Your card ending in 4242 has been charged.",
        "Acme Corp\n123 Main St, Springfield, IL 62701\nPhone: (555) 987-6543\nEmail: legal@acme.com",
        "Contact John Doe at john@example.com or call (555) 111-2222. Fax: (555) 333-4444.",
    ]

    # (result key, section header) pairs — drives one shared print loop
    # instead of two copy-pasted ones.
    sections = [
        ("privacy_filter", "\n--- Privacy Filter (original 8 categories) ---"),
        ("sidecar", "\n--- Sidecar (fax + CC-last4 + contact block) ---"),
    ]

    for sample in test_cases:
        print(f"\n{'='*60}")
        print(f"TEXT: {sample}")
        results = ensemble(sample)
        for key, header in sections:
            print(header)
            # Show only tagged tokens; "O" (outside any entity) is noise.
            for token, tag in results[key]:
                if tag != "O":
                    print(f"  {token}: {tag}")
|
|