#!/usr/bin/env python3
"""
Evaluate privacy-filter-enhanced on fax_number, credit_card_last4, company_contact_block.
Compare against baseline openai/privacy-filter.
"""
import functools
import json
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from seqeval.metrics import classification_report

# NOTE(review): json and classification_report are imported but not used below.

TEST_CASES = [
    {
        "text": "Please send the contract via fax to (555) 867-5309. Attn: Legal Dept.",
        "expected": [("(555) 867-5309", "fax_number")]
    },
    {
        "text": "Your Visa card ending in 4242 was charged $99.00 on 2024-01-15.",
        "expected": [("4242", "credit_card_last4")]
    },
    {
        "text": "Acme Corp\n123 Main St, Suite 100\nPhone: (555) 987-6543\nEmail: contact@acme.com",
        "expected": [("Acme Corp\n123 Main St, Suite 100\nPhone: (555) 987-6543\nEmail: contact@acme.com", "company_contact_block")]
    },
    {
        "text": "From: Alice Smith \u003calice@example.com\u003e\nFax: (212) 555-0199\nCard ending in 8765",
        "expected": [("Alice Smith", "private_person"), ("alice@example.com", "private_email"), ("(212) 555-0199", "fax_number"), ("8765", "credit_card_last4")]
    },
]


@functools.lru_cache(maxsize=1)
def _load(model_name):
    """Load the tokenizer and model for *model_name* and move the model to the best device.

    Cached so each checkpoint is loaded once rather than once per test case.
    maxsize=1 means the cache drops the previous model as soon as a different
    model_name is requested, so the cache itself never pins two models at once.
    Failed loads raise and are not cached.

    Returns:
        (tokenizer, model, device) tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForTokenClassification.from_pretrained(model_name, trust_remote_code=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    model.eval()  # inference only: make sure dropout etc. are disabled
    return tokenizer, model, device


def _decode_bioes(labels):
    """Return (start, end, category) token-index ranges (end exclusive) for BIOES labels.

    A span is S-x alone, or B-x followed by any number of I-x and at most one
    closing E-x. O labels, orphan I-/E- labels (no preceding B-) and labels
    without a BIOES prefix are skipped.
    """
    spans = []
    n = len(labels)
    i = 0
    while i < n:
        lab = labels[i]
        if lab.startswith("S-"):
            spans.append((i, i + 1, lab[2:]))
            i += 1
        elif lab.startswith("B-"):
            cat = lab[2:]
            j = i + 1
            while j < n and labels[j] == f"I-{cat}":
                j += 1
            if j < n and labels[j] == f"E-{cat}":
                # E- closes the entity; anything after it belongs to a new span.
                j += 1
            spans.append((i, j, cat))
            i = j
        else:
            i += 1
    return spans


def _span_text(text, tokens, offsets, start, end, tokenizer):
    """Return the surface string covered by tokens[start:end].

    With character offsets the span is sliced directly from *text*, which is
    exact regardless of the tokenizer's subword markers ("##", "Ġ", "▁", ...).
    Without offsets, fall back to the tokenizer's own detokenizer.
    """
    if offsets is not None:
        # Special tokens ([CLS], <s>, ...) map to zero-width offsets; ignore them.
        covered = [(s, e) for s, e in offsets[start:end] if e > s]
        if not covered:
            return ""
        char_start = min(s for s, _ in covered)
        char_end = max(e for _, e in covered)
        # strip(): some byte-level tokenizers include the leading space in the offset.
        return text[char_start:char_end].strip()
    return tokenizer.convert_tokens_to_string(tokens[start:end]).strip()


def predict(model_name, text):
    """Run *model_name* on *text* and return the predicted entity spans.

    Args:
        model_name: Hugging Face model id or local path.
        text: Input string; tokenized with truncation to 512 tokens.

    Returns:
        List of (span_text, category) tuples decoded from the model's BIOES labels.
    """
    tokenizer, model, device = _load(model_name)
    # Only fast (Rust-backed) tokenizers can return character offsets.
    use_offsets = bool(getattr(tokenizer, "is_fast", False))
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        return_offsets_mapping=use_offsets,
    )
    # The model's forward() does not accept offset_mapping, so remove it first.
    offsets = inputs.pop("offset_mapping")[0].tolist() if use_offsets else None
    inputs = inputs.to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    preds = torch.argmax(logits, dim=2)[0].tolist()  # plain ints for id2label lookup
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())
    labels = [model.config.id2label.get(p, "O") for p in preds]

    spans = []
    for start, end, cat in _decode_bioes(labels):
        span = _span_text(text, tokens, offsets, start, end, tokenizer)
        if span:
            spans.append((span, cat))
    return spans


def main():
    """Run every TEST_CASE through the baseline and enhanced models and print match counts."""
    models = {
        "baseline": "openai/privacy-filter",
        "enhanced": "narcolepticchicken/privacy-filter-enhanced",
    }
    for model_name, model_path in models.items():
        print(f"\n{'='*60}")
        print(f"Model: {model_name}")
        print(f"{'='*60}")
        total_matched = 0
        total_expected = 0
        for case in TEST_CASES:
            text = case["text"]
            expected = case["expected"]
            try:
                preds = predict(model_path, text)
            except Exception as e:
                # Script boundary: report the failure and keep evaluating the rest.
                preds = []
                print(f"ERROR: {e}")
            preview = text if len(text) <= 80 else f"{text[:80]}..."
            print(f"\nText: {preview}")
            print(f"Expected: {expected}")
            print(f"Predicted: {preds}")
            # Simple (lenient) match check: an expected entity counts as found when
            # some predicted span with the same label contains its text.
            matched = sum(
                any(exp_text in span and label == exp_label for span, label in preds)
                for exp_text, exp_label in expected
            )
            print(f"Matches: {matched}/{len(expected)}")
            total_matched += matched
            total_expected += len(expected)
        print(f"\nTotal matches: {total_matched}/{total_expected}")


if __name__ == "__main__":
    main()