File size: 3,494 Bytes
89c6a38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#!/usr/bin/env python3
"""Ensemble inference: openai/privacy-filter + sidecar BERT for gap categories."""
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch

# Load original privacy-filter
# NOTE: trust_remote_code=True executes custom code shipped in the model repo
# on load — acceptable only because the repo is a known/trusted source.
print("Loading openai/privacy-filter...")
tok_pf = AutoTokenizer.from_pretrained("openai/privacy-filter", trust_remote_code=True)
model_pf = AutoModelForTokenClassification.from_pretrained("openai/privacy-filter", trust_remote_code=True)
model_pf.eval()  # inference mode: disables dropout for deterministic predictions

# Load sidecar for gap categories
# The sidecar is a plain BERT token classifier, so no trust_remote_code needed.
print("Loading sidecar BERT...")
tok_sc = AutoTokenizer.from_pretrained("narcolepticchicken/privacy-filter-sidecar-bert")
model_sc = AutoModelForTokenClassification.from_pretrained("narcolepticchicken/privacy-filter-sidecar-bert")
model_sc.eval()  # inference mode for the sidecar as well

# Original privacy-filter labels: BIESO tagging — each category expands to
# B-(begin), I-(inside), E-(end), S-(single-token) tags, plus the shared "O".
_PF_CATEGORIES = (
    "account_number",
    "private_address",
    "private_date",
    "private_email",
    "private_person",
    "private_phone",
    "private_url",
    "secret",
)
PF_LABELS = ["O"] + [
    f"{prefix}-{category}" for category in _PF_CATEGORIES for prefix in "BIES"
]

# Sidecar labels: plain BIO scheme over the three gap categories.
_SC_CATEGORIES = ("fax_number", "credit_card_last4", "company_contact_block")
SC_LABELS = ["O"] + [
    f"{prefix}-{category}" for category in _SC_CATEGORIES for prefix in "BI"
]

def run_pf(text):
    """Tag *text* with the original privacy-filter model.

    Returns a list of (token_string, label_string) pairs, one per
    wordpiece token (input truncated to 512 tokens).
    """
    encoded = tok_pf(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        scores = model_pf(**encoded).logits
    # Per-token argmax over the label dimension; [0] drops the batch axis.
    label_ids = scores.argmax(dim=2)[0].cpu().numpy()
    token_strs = tok_pf.convert_ids_to_tokens(encoded["input_ids"][0])
    return [(token, PF_LABELS[idx]) for token, idx in zip(token_strs, label_ids)]

def run_sc(text):
    """Tag *text* with the sidecar BERT covering the gap categories.

    Returns a list of (token_string, label_string) pairs, one per
    wordpiece token (input truncated to 512 tokens).
    """
    encoded = tok_sc(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        scores = model_sc(**encoded).logits
    # Per-token argmax over the label dimension; [0] drops the batch axis.
    label_ids = scores.argmax(dim=2)[0].cpu().numpy()
    token_strs = tok_sc.convert_ids_to_tokens(encoded["input_ids"][0])
    return [(token, SC_LABELS[idx]) for token, idx in zip(token_strs, label_ids)]

def ensemble(text):
    """Run both models on *text* and return their tagged tokens side by side.

    The privacy-filter model covers its original 8 categories; the sidecar
    covers the 3 gap categories. Results are kept separate under the keys
    "privacy_filter" and "sidecar" (each a list of (token, label) pairs).
    """
    return {
        "privacy_filter": run_pf(text),
        "sidecar": run_sc(text),
    }


if __name__ == "__main__":
    # Smoke-test inputs: one per gap category, plus a mixed case that should
    # trigger both the original model and the sidecar.
    samples = [
        "Please fax the contract to (555) 123-4567.",
        "Your card ending in 4242 has been charged.",
        "Acme Corp\n123 Main St, Springfield, IL 62701\nPhone: (555) 987-6543\nEmail: legal@acme.com",
        "Contact John Doe at john@example.com or call (555) 111-2222. Fax: (555) 333-4444.",
    ]

    # (section header, result key) pairs — printed in this fixed order.
    sections = (
        ("\n--- Privacy Filter (original 8 categories) ---", "privacy_filter"),
        ("\n--- Sidecar (fax + CC-last4 + contact block) ---", "sidecar"),
    )

    for sample in samples:
        print(f"\n{'='*60}")
        print(f"TEXT: {sample}")
        combined = ensemble(sample)
        for header, key in sections:
            print(header)
            for token, label in combined[key]:
                if label != "O":  # only show tokens flagged as sensitive
                    print(f"  {token}: {label}")