narcolepticchicken's picture
Upload evaluate.py
dd5bf53 verified
#!/usr/bin/env python3
"""
Evaluate privacy-filter-enhanced on fax_number, credit_card_last4, company_contact_block.
Compare against baseline openai/privacy-filter.
"""
import json
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from seqeval.metrics import classification_report
# Hand-written evaluation fixtures.
# Each entry holds the raw input ``text`` and the ``expected`` entities as
# ``(surface_text, label)`` pairs; every surface text is a substring of
# the entry's ``text``.
TEST_CASES = [
    # Stand-alone fax number.
    {
        "text": "Please send the contract via fax to (555) 867-5309. Attn: Legal Dept.",
        "expected": [
            ("(555) 867-5309", "fax_number"),
        ],
    },
    # Last four digits of a payment card.
    {
        "text": "Your Visa card ending in 4242 was charged $99.00 on 2024-01-15.",
        "expected": [
            ("4242", "credit_card_last4"),
        ],
    },
    # Multi-line company contact block; the whole text is one entity.
    {
        "text": "Acme Corp\n123 Main St, Suite 100\nPhone: (555) 987-6543\nEmail: contact@acme.com",
        "expected": [
            (
                "Acme Corp\n123 Main St, Suite 100\nPhone: (555) 987-6543\nEmail: contact@acme.com",
                "company_contact_block",
            ),
        ],
    },
    # Mixed case: person and email labels next to the fax / card labels.
    {
        "text": "From: Alice Smith \u003calice@example.com\u003e\nFax: (212) 555-0199\nCard ending in 8765",
        "expected": [
            ("Alice Smith", "private_person"),
            ("alice@example.com", "private_email"),
            ("(212) 555-0199", "fax_number"),
            ("8765", "credit_card_last4"),
        ],
    },
]
# Loaded (tokenizer, model, device) triples keyed by model name, so repeated
# predict() calls for the same checkpoint do not reload the weights.
# NOTE: every distinct model stays resident for the life of the process.
_MODEL_CACHE = {}


def decode_bioes_spans(labels):
    """Group a BIOES label sequence into token-index spans.

    Args:
        labels: Sequence of label strings such as ``"O"``, ``"S-x"``,
            ``"B-x"``, ``"I-x"``, ``"E-x"``.

    Returns:
        A list of ``(start, end, category)`` tuples where ``start`` is the
        index of the first token of the span and ``end`` is exclusive.

    Decoding rules:
        * ``S-<cat>`` yields a single-token span.
        * ``B-<cat>`` opens a span that extends over the following
          ``I-<cat>`` tokens and is closed by at most one ``E-<cat>``.
          A span with no closing ``E-`` is still emitted.
        * ``O``, orphan ``I-``/``E-`` tokens and labels without a BIOES
          prefix are skipped.
    """
    spans = []
    total = len(labels)
    i = 0
    while i < total:
        label = labels[i]
        if label.startswith("S-"):
            spans.append((i, i + 1, label[2:]))
            i += 1
        elif label.startswith("B-"):
            cat = label[2:]
            inside, end_tag = f"I-{cat}", f"E-{cat}"
            j = i + 1
            while j < total and labels[j] == inside:
                j += 1
            # E- terminates the entity: consume one and stop, so tokens
            # that follow it are never glued onto this span.
            if j < total and labels[j] == end_tag:
                j += 1
            spans.append((i, j, cat))
            i = j
        else:
            i += 1
    return spans


def _load_model(model_name):
    """Return a cached ``(tokenizer, model, device)`` triple for *model_name*."""
    cached = _MODEL_CACHE.get(model_name)
    if cached is None:
        # NOTE(security): trust_remote_code=True executes Python shipped in
        # the model repository; only use it with repositories you trust.
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModelForTokenClassification.from_pretrained(
            model_name, trust_remote_code=True
        )
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)
        model.eval()
        cached = _MODEL_CACHE[model_name] = (tokenizer, model, device)
    return cached


def predict(model_name, text):
    """Run a token-classification model on *text* and return entity spans.

    Args:
        model_name: Hugging Face model id or local path of the checkpoint.
        text: Input string. It is truncated to 512 tokens.

    Returns:
        A list of ``(surface_text, category)`` tuples in order of
        appearance. ``surface_text`` is sliced from *text* through the
        tokenizer's character offsets when a fast tokenizer is available,
        so spacing and punctuation match the input; otherwise it is
        rebuilt with ``tokenizer.convert_tokens_to_string``.

    Raises:
        Whatever ``from_pretrained`` or the forward pass raises (missing
        checkpoint, network failure, out of memory, ...).
    """
    tokenizer, model, device = _load_model(model_name)

    # Character offsets are only provided by fast (Rust-backed) tokenizers.
    use_offsets = bool(getattr(tokenizer, "is_fast", False))
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        return_offsets_mapping=use_offsets,
    )
    # The offset mapping is not a model input; remove it before the forward pass.
    offsets = inputs.pop("offset_mapping")[0].tolist() if use_offsets else None
    inputs = inputs.to(device)

    with torch.no_grad():
        logits = model(**inputs).logits
    label_ids = torch.argmax(logits, dim=2)[0].tolist()

    id2label = model.config.id2label
    labels = [id2label.get(idx, "O") for idx in label_ids]
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    spans = []
    for start, end, cat in decode_bioes_spans(labels):
        if offsets is not None:
            surface = text[offsets[start][0]:offsets[end - 1][1]]
        else:
            surface = tokenizer.convert_tokens_to_string(tokens[start:end])
        surface = surface.strip()
        # Spans that cover only special tokens map to no input characters.
        if surface:
            spans.append((surface, cat))
    return spans
def main():
    """Run every test case through both checkpoints and print a report.

    For each model a banner is printed, then for each case the (truncated)
    text, the expected entities, the predicted entities and how many
    expected entities were matched. An expected entity counts as matched
    when some prediction carries the same label and contains the expected
    text as a substring. A failing prediction is reported and treated as
    "no predictions" so the remaining cases still run.
    """
    models = {
        "baseline": "openai/privacy-filter",
        "enhanced": "narcolepticchicken/privacy-filter-enhanced",
    }
    banner = "=" * 60

    for model_name, model_path in models.items():
        print(f"\n{banner}")
        print(f"Model: {model_name}")
        print(f"{banner}")

        for case in TEST_CASES:
            text, expected = case["text"], case["expected"]

            # Best-effort: one broken case or checkpoint must not abort the run.
            try:
                preds = predict(model_path, text)
            except Exception as err:
                preds = []
                print(f"ERROR: {err}")

            print(f"\nText: {text[:80]}...")
            print(f"Expected: {expected}")
            print(f"Predicted: {preds}")

            # Simple match check: same label, expected text inside prediction.
            matched = sum(
                1
                for exp_text, exp_label in expected
                if any(exp_text in hit[0] and hit[1] == exp_label for hit in preds)
            )
            print(f"Matches: {matched}/{len(expected)}")


if __name__ == "__main__":
    main()