File size: 5,029 Bytes
6a82282
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abcf7cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a82282
abcf7cd
6a82282
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""GLiNER (urchade/gliner_medium-v2.1) typed-entity extraction over the
RAG retriever's top paragraphs.

Adds structured fields to the reconciler's grounding context. For each
RAG chunk the specialist emits, GLiNER produces a list of typed spans
with one of five labels:

    nyc_location          (e.g. "Coney Island")
    dollar_amount         (e.g. "$5.6 million")
    date_range            (e.g. "fiscal year 2025-2027")
    agency                (e.g. "NYC DEP")
    infrastructure_project (e.g. "Bluebelt expansion")

The doc_id for emission is `gliner_<source>` where `<source>` is the
RAG chunk's doc_id stripped of its `rag_` prefix. So `rag_comptroller`
becomes `gliner_comptroller`. The reconciler can then cite typed
fields with `[gliner_comptroller]`.

License: Apache-2.0 — `urchade/gliner_medium-v2.1` (NOT the
`gliner_base` variant, which is CC-BY-NC-4.0). See
experiments/shared/licenses.md.
"""

from __future__ import annotations

import logging
import os
from dataclasses import dataclass

log = logging.getLogger("riprap.gliner")

# The five entity types requested from GLiNER (see module docstring for
# examples of each).
ENTITY_LABELS = [
    "nyc_location",
    "dollar_amount",
    "date_range",
    "agency",
    "infrastructure_project",
]

# Minimum score for a span to be kept; overridable via env for tuning.
DEFAULT_THRESHOLD = float(os.environ.get("RIPRAP_GLINER_THRESHOLD", "0.45"))
# HF model id — keep to the Apache-2.0 medium variant (see license note above).
MODEL_NAME = os.environ.get("RIPRAP_GLINER_MODEL", "urchade/gliner_medium-v2.1")
# Kill switch: set RIPRAP_GLINER_ENABLE=0 to make this specialist a no-op.
ENABLE = os.environ.get("RIPRAP_GLINER_ENABLE", "1").lower() in ("1", "true", "yes")

_MODEL = None  # lazy-loaded GLiNER instance; set to False after a failed load


@dataclass
class Extraction:
    """One typed span produced by GLiNER for a single chunk of text."""
    # One of ENTITY_LABELS (e.g. "dollar_amount").
    label: str
    # The verbatim matched span from the source paragraph.
    text: str
    # GLiNER confidence score; spans below the threshold are filtered upstream.
    score: float


def _ensure_model():
    """Lazily load the GLiNER model, attempting the load at most once.

    Returns the loaded model, or None when extraction is disabled
    (RIPRAP_GLINER_ENABLE) or the load attempt failed, so callers can
    silently fall back to a no-op.
    """
    global _MODEL
    if not ENABLE:
        return None
    if _MODEL is None:
        # First call: try the (slow) load exactly once.
        try:
            from gliner import GLiNER
            log.info("gliner: loading %s", MODEL_NAME)
            _MODEL = GLiNER.from_pretrained(MODEL_NAME)
        except Exception:
            log.exception("gliner: load failed; specialist will no-op")
            _MODEL = False  # sentinel: don't retry the load on later calls
    # BUG FIX: the fast path used to `return _MODEL` directly, so after a
    # failed load the False sentinel leaked to callers on the *second*
    # call — their `is None` guard missed it and they crashed on
    # False.predict_entities. Normalize the sentinel to None here.
    return _MODEL or None


def warm() -> None:
    """Eagerly trigger the lazy GLiNER load (e.g. at process start-up)."""
    _ensure_model()


def _source_short(rag_doc_id: str) -> str:
    """`rag_comptroller` -> `comptroller`. Anything not prefixed `rag_`
    passes through unchanged."""
    return rag_doc_id[4:] if rag_doc_id.startswith("rag_") else rag_doc_id


def extract_for_chunk(text: str, threshold: float = DEFAULT_THRESHOLD) -> list[Extraction]:
    """Run typed-entity extraction over one paragraph of text.

    Tries the remote inference service first, then falls back to the
    locally loaded GLiNER model. Returns [] for empty input, when
    extraction is disabled, or when no backend is available.
    """
    if not text:
        return []

    # v0.4.5 — try the MI300X service first. The remote handles its
    # own GLiNER load; this lets cpu-basic surfaces run typed
    # extraction without baking gliner into the image.
    #
    # BUG FIX: the import used to live inside the same try as the remote
    # call, so when `app.inference` was missing the ImportError hit the
    # `except _inf.RemoteUnreachable` clause while `_inf` was still
    # unbound — raising NameError out of this function instead of
    # falling back. Import separately, then guard the call.
    try:
        from app import inference as _inf
    except Exception:
        _inf = None  # remote module unavailable; go straight to local
    if _inf is not None:
        try:
            if _inf.remote_enabled():
                remote = _inf.gliner_extract(text, ENTITY_LABELS)
                if remote.get("ok"):
                    # Apply the caller's threshold client-side so the
                    # remote path matches local predict_entities output.
                    return [
                        Extraction(label=e["label"], text=e["text"],
                                   score=float(e.get("score", 0)))
                        for e in remote.get("entities", [])
                        if e.get("score", 0) >= threshold
                    ]
        except _inf.RemoteUnreachable as e:
            log.info("gliner: remote unreachable (%s); local fallback", e)
        except Exception:
            log.exception("gliner: remote call failed; local fallback")

    model = _ensure_model()
    if model is None:
        return []
    raw = model.predict_entities(text, ENTITY_LABELS, threshold=threshold)
    return [Extraction(label=r["label"], text=r["text"],
                       score=float(r["score"])) for r in raw]


def extract_for_rag_hits(hits: list[dict],
                         threshold: float = DEFAULT_THRESHOLD,
                         max_hits: int = 3) -> dict[str, dict]:
    """Run GLiNER on the top-`max_hits` RAG hits.

    Returns a dict keyed by short source id (e.g. "comptroller") with the
    structured payload that the FSM stores into state["gliner"] and that
    reconcile.build_documents() consumes. Hits yielding no entities are
    omitted from the result.
    """
    out: dict[str, dict] = {}
    for hit in (hits or [])[:max_hits]:
        text = hit.get("text", "")
        entities = extract_for_chunk(text, threshold=threshold)
        if not entities:
            continue
        # Dedup verbatim repeats (common in agency PDFs that repeat
        # "DEP" 13 times in a methodology section). Insertion-ordered
        # dict keyed on (label, lowercased text) keeps first occurrence.
        unique: dict[tuple[str, str], Extraction] = {}
        for ent in entities:
            unique.setdefault((ent.label, ent.text.lower()), ent)
        deduped = list(unique.values())
        excerpt = text[:240] + ("…" if len(text) > 240 else "")
        out[_source_short(hit.get("doc_id", "rag_unknown"))] = {
            "rag_doc_id": hit.get("doc_id"),
            "title": hit.get("title"),
            "paragraph_excerpt": excerpt,
            "n_entities": len(deduped),
            "entities": [{"label": ent.label, "text": ent.text,
                          "score": round(ent.score, 3)} for ent in deduped],
        }
    return out