"""
OpenAI Privacy Filter — inference module for DLP Paste-Proxy.
Thin wrapper around `transformers.pipeline("token-classification", ...)` for
`openai/privacy-filter`. The pipeline's `aggregation_strategy="simple"` takes
care of BIOES → char-level span reconstruction for us.
Public API is unchanged from earlier revisions:
predict_text(text) -> (source_text, spans)
where each span is {label, start, end, text}.
"""
from __future__ import annotations
import functools
import os
import torch
MODEL_REPO = os.getenv("MODEL_ID", "openai/privacy-filter")
HF_TOKEN = os.getenv("HF_TOKEN", None)
@functools.lru_cache(maxsize=1)
def _get_pipeline():
    """Build (once) and return the token-classification pipeline.

    `lru_cache(maxsize=1)` makes this a lazy singleton: the model is only
    downloaded/loaded on first use, and every later call reuses it.

    Returns:
        A `transformers` pipeline configured with
        `aggregation_strategy="simple"`, which merges BIOES sub-token tags
        into character-level entity spans for us.
    """
    # Imported lazily so merely importing this module stays cheap and does
    # not require `transformers` until inference is actually requested.
    from transformers import pipeline

    # Fall back to CPU when no GPU is present; the previous hard-coded
    # `device=0` raised on CPU-only hosts.
    device = 0 if torch.cuda.is_available() else -1
    return pipeline(
        task="token-classification",
        model=MODEL_REPO,
        aggregation_strategy="simple",
        device=device,
        torch_dtype=torch.bfloat16,
        token=HF_TOKEN,  # None is fine for public repos
    )
def predict_text(text: str) -> tuple[str, list[dict]]:
    """Run the privacy-filter model over `text`.

    Returns:
        `(source_text, spans)` where `spans` is a list of dicts of the form
        `{label, start, end, text}`, with `start`/`end` being character
        offsets into the original `text`.
    """
    # Empty or whitespace-only input: nothing to classify.
    if not text or not text.strip():
        return text, []

    entities = _get_pipeline()(text)

    spans: list[dict] = []
    for entity in entities:
        # Aggregated results carry "entity_group"; fall back to "entity".
        tag = entity.get("entity_group") or entity.get("entity")
        if not tag or tag == "O":
            continue  # untagged / outside-entity token
        begin = int(entity["start"])
        finish = int(entity["end"])
        # Discard spans with nonsensical or out-of-range offsets.
        if begin < 0 or finish <= begin or finish > len(text):
            continue
        spans.append(
            {
                "label": tag,
                "start": begin,
                "end": finish,
                "text": text[begin:finish],
            }
        )
    return text, spans
|