File size: 3,610 Bytes
dd5bf53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#!/usr/bin/env python3
"""
Evaluate privacy-filter-enhanced on fax_number, credit_card_last4, company_contact_block.
Compare against baseline openai/privacy-filter.
"""

import json
from functools import lru_cache

import torch
from seqeval.metrics import classification_report
from transformers import AutoTokenizer, AutoModelForTokenClassification

# Each case pairs an input text with the (substring, label) spans the model is
# expected to tag. fax_number / credit_card_last4 / company_contact_block are
# the new categories under evaluation; the last case mixes them with the
# baseline's existing categories (private_person, private_email) to check for
# regressions. Matching in main() is substring containment, so "expected" only
# needs the span text, not token offsets.
TEST_CASES = [
    {
        "text": "Please send the contract via fax to (555) 867-5309. Attn: Legal Dept.",
        "expected": [("(555) 867-5309", "fax_number")]
    },
    {
        "text": "Your Visa card ending in 4242 was charged $99.00 on 2024-01-15.",
        "expected": [("4242", "credit_card_last4")]
    },
    {
        # Multi-line span: the whole address/contact footer is one entity.
        "text": "Acme Corp\n123 Main St, Suite 100\nPhone: (555) 987-6543\nEmail: contact@acme.com",
        "expected": [("Acme Corp\n123 Main St, Suite 100\nPhone: (555) 987-6543\nEmail: contact@acme.com", "company_contact_block")]
    },
    {
        # \u003c / \u003e are escaped angle brackets around the email address.
        "text": "From: Alice Smith \u003calice@example.com\u003e\nFax: (212) 555-0199\nCard ending in 8765",
        "expected": [("Alice Smith", "private_person"), ("alice@example.com", "private_email"),
                     ("(212) 555-0199", "fax_number"), ("8765", "credit_card_last4")]
    },
]


@lru_cache(maxsize=8)
def _load_model(model_name):
    """Load and cache tokenizer + model for *model_name*.

    The original code re-downloaded/reloaded the model on every predict()
    call; with main() iterating models x test cases that meant 8 full loads
    instead of 2. lru_cache keeps one (tokenizer, model, device) triple per
    model name for the life of the process.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForTokenClassification.from_pretrained(model_name, trust_remote_code=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return tokenizer, model.to(device), device


def _decode_bioes(tokens, labels):
    """Decode a BIOES tag sequence into (span_text, category) tuples.

    S-CAT emits the single token as a span; B-CAT consumes following
    I-CAT/E-CAT tokens of the same category and joins them, stripping
    WordPiece "##" continuation markers. "O" tags and orphan I-/E- tags
    (no preceding B- of that category) are skipped.
    """
    spans = []
    i = 0
    while i < len(labels):
        lab = labels[i]
        if lab.startswith("S-"):
            spans.append((tokens[i], lab[2:]))
            i += 1
        elif lab.startswith("B-"):
            cat = lab[2:]
            j = i + 1
            while j < len(labels) and labels[j] in (f"I-{cat}", f"E-{cat}"):
                j += 1
            spans.append(("".join(tokens[i:j]).replace("##", ""), cat))
            i = j
        else:
            # "O", orphan I-/E-, or any unexpected tag: advance.
            i += 1
    return spans


def predict(model_name, text):
    """Run token classification on *text* with *model_name*.

    Returns a list of (span_text, category) tuples decoded from the
    model's BIOES predictions. Input is truncated to 512 tokens.
    """
    tokenizer, model, device = _load_model(model_name)

    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    preds = torch.argmax(logits, dim=2)[0].cpu().numpy()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    # Unknown prediction ids fall back to "O" (ignored by the decoder).
    labels = [model.config.id2label.get(p, "O") for p in preds]

    return _decode_bioes(tokens, labels)


def main():
    """Evaluate baseline and enhanced models on TEST_CASES and print results.

    For each model, every test case is run through predict(); prediction
    errors are reported but do not abort the run. A match counts when an
    expected span text appears as a substring of a predicted span with the
    same label.
    """
    models = {
        "baseline": "openai/privacy-filter",
        "enhanced": "narcolepticchicken/privacy-filter-enhanced",
    }

    banner = "=" * 60
    for label, path in models.items():
        print(f"\n{banner}")
        print(f"Model: {label}")
        print(banner)

        for case in TEST_CASES:
            text, expected = case["text"], case["expected"]

            try:
                predicted = predict(path, text)
            except Exception as err:
                predicted = []
                print(f"ERROR: {err}")

            print(f"\nText: {text[:80]}...")
            print(f"Expected: {expected}")
            print(f"Predicted: {predicted}")

            # Substring containment with matching label counts as a hit.
            hits = sum(
                1
                for exp_text, exp_label in expected
                if any(exp_text in ptxt and plab == exp_label for ptxt, plab in predicted)
            )
            print(f"Matches: {hits}/{len(expected)}")


if __name__ == "__main__":
    main()