File size: 1,570 Bytes
754de8e
 
 
f949461
 
 
754de8e
f949461
 
 
754de8e
 
 
 
 
 
 
 
 
f949461
754de8e
 
 
 
f949461
 
 
 
 
 
 
 
 
 
754de8e
 
 
f949461
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""
OpenAI Privacy Filter — inference module for DLP Paste-Proxy.

Thin wrapper around `transformers.pipeline("token-classification", ...)` for
`openai/privacy-filter`. The pipeline's `aggregation_strategy="simple"` takes
care of BIOES → char-level span reconstruction for us.

Public API is unchanged from earlier revisions:
    predict_text(text) -> (source_text, spans)
where each span is {label, start, end, text}.
"""

from __future__ import annotations

import functools
import os

import torch

# Hugging Face model repo to load; override with the MODEL_ID env var.
MODEL_REPO = os.getenv("MODEL_ID", "openai/privacy-filter")
# Optional Hugging Face access token (for gated/private repos); None means anonymous.
HF_TOKEN = os.getenv("HF_TOKEN", None)


@functools.lru_cache(maxsize=1)
def _get_pipeline():
    """Build and memoize the token-classification pipeline for MODEL_REPO.

    `transformers` is imported lazily so importing this module stays cheap;
    `lru_cache(maxsize=1)` ensures the model is loaded exactly once per
    process.

    Returns:
        A `transformers` pipeline with `aggregation_strategy="simple"`,
        which merges sub-token predictions into char-level entity spans.
    """
    from transformers import pipeline

    # BUG FIX: device=0 was hard-coded, which raises on CPU-only hosts.
    # Fall back to CPU (-1) when no CUDA device is available.
    device = 0 if torch.cuda.is_available() else -1
    # bfloat16 halves GPU memory; on CPU, bf16 kernel support is uneven,
    # so stick with float32 there.
    dtype = torch.bfloat16 if device == 0 else torch.float32
    return pipeline(
        task="token-classification",
        model=MODEL_REPO,
        aggregation_strategy="simple",
        device=device,
        torch_dtype=dtype,
        token=HF_TOKEN,
    )


def predict_text(text: str) -> tuple[str, list[dict]]:
    """Run the privacy-filter model over *text*.

    Returns ``(source_text, spans)`` where each span is a dict with keys
    ``{label, start, end, text}`` and ``start``/``end`` are character
    offsets into *text*. Blank or whitespace-only input short-circuits
    to an empty span list without ever loading the model.
    """
    if not text or not text.strip():
        return text, []

    predictions = _get_pipeline()(text)
    limit = len(text)
    spans: list[dict] = []
    for hit in predictions:
        tag = hit.get("entity_group") or hit.get("entity")
        # Skip unlabeled results and the "outside" class.
        if not tag or tag == "O":
            continue
        start, end = int(hit["start"]), int(hit["end"])
        # Discard degenerate or out-of-bounds offsets the aggregator
        # may occasionally emit.
        if not (0 <= start < end <= limit):
            continue
        spans.append(
            {"label": tag, "start": start, "end": end, "text": text[start:end]}
        )
    return text, spans