#!/usr/bin/env python3
"""Ensemble inference: openai/privacy-filter + sidecar BERT for gap categories.

Runs two token-classification models over the same text:
  * ``openai/privacy-filter`` — the original 8 PII categories (BIOES scheme).
  * ``narcolepticchicken/privacy-filter-sidecar-bert`` — 3 "gap" categories
    the original model does not cover (fax numbers, credit-card last-4,
    company contact blocks; BIO scheme).

Both models are loaded eagerly at import time (network/disk I/O).
"""
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch

# Load original privacy-filter
print("Loading openai/privacy-filter...")
tok_pf = AutoTokenizer.from_pretrained("openai/privacy-filter", trust_remote_code=True)
model_pf = AutoModelForTokenClassification.from_pretrained(
    "openai/privacy-filter", trust_remote_code=True
)
model_pf.eval()

# Load sidecar for gap categories
print("Loading sidecar BERT...")
tok_sc = AutoTokenizer.from_pretrained("narcolepticchicken/privacy-filter-sidecar-bert")
model_sc = AutoModelForTokenClassification.from_pretrained(
    "narcolepticchicken/privacy-filter-sidecar-bert"
)
model_sc.eval()

# Original privacy-filter labels (BIOES tagging: Begin / Inside / End / Single + O).
PF_LABELS = [
    "O",
    "B-account_number", "I-account_number", "E-account_number", "S-account_number",
    "B-private_address", "I-private_address", "E-private_address", "S-private_address",
    "B-private_date", "I-private_date", "E-private_date", "S-private_date",
    "B-private_email", "I-private_email", "E-private_email", "S-private_email",
    "B-private_person", "I-private_person", "E-private_person", "S-private_person",
    "B-private_phone", "I-private_phone", "E-private_phone", "S-private_phone",
    "B-private_url", "I-private_url", "E-private_url", "S-private_url",
    "B-secret", "I-secret", "E-secret", "S-secret",
]

# Sidecar labels (BIO tagging) for the three gap categories.
SC_LABELS = [
    "O",
    "B-fax_number", "I-fax_number",
    "B-credit_card_last4", "I-credit_card_last4",
    "B-company_contact_block", "I-company_contact_block",
]


def _run_ner(text, tokenizer, model, labels):
    """Shared token-classification path for both models.

    Tokenizes *text* (truncated to 512 tokens), runs the model without
    gradient tracking, and maps the argmax class per token onto *labels*.

    Returns a list of ``(token, label)`` pairs, one per wordpiece
    (special tokens like [CLS]/[SEP] included).
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        logits = model(**inputs).logits
    preds = torch.argmax(logits, dim=2)[0].cpu().numpy()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    # Guard against a model head with more classes than our hard-coded list:
    # any out-of-range prediction is treated as "O" instead of raising IndexError.
    return [
        (tok, labels[p] if p < len(labels) else "O")
        for tok, p in zip(tokens, preds)
    ]


def run_pf(text):
    """Run original privacy-filter; returns list of (token, label) pairs."""
    return _run_ner(text, tok_pf, model_pf, PF_LABELS)


def run_sc(text):
    """Run sidecar BERT for gap categories; returns list of (token, label) pairs."""
    return _run_ner(text, tok_sc, model_sc, SC_LABELS)


def ensemble(text):
    """Combine both models: PF for original 8 cats, sidecar for 3 gap cats.

    Returns a dict with keys ``"privacy_filter"`` and ``"sidecar"``, each a
    list of (token, label) pairs from the respective model.
    """
    pf_results = run_pf(text)
    sc_results = run_sc(text)
    return {"privacy_filter": pf_results, "sidecar": sc_results}


if __name__ == "__main__":
    test_cases = [
        "Please fax the contract to (555) 123-4567.",
        "Your card ending in 4242 has been charged.",
        "Acme Corp\n123 Main St, Springfield, IL 62701\nPhone: (555) 987-6543\nEmail: legal@acme.com",
        "Contact John Doe at john@example.com or call (555) 111-2222. Fax: (555) 333-4444.",
    ]
    for text in test_cases:
        print(f"\n{'='*60}")
        print(f"TEXT: {text}")
        results = ensemble(text)
        # Only non-"O" tokens are interesting; everything else is background.
        print("\n--- Privacy Filter (original 8 categories) ---")
        for tok, label in results["privacy_filter"]:
            if label != "O":
                print(f"  {tok}: {label}")
        print("\n--- Sidecar (fax + CC-last4 + contact block) ---")
        for tok, label in results["sidecar"]:
            if label != "O":
                print(f"  {tok}: {label}")