"""GLiNER (urchade/gliner_medium-v2.1) typed-entity extraction over the
RAG retriever's top paragraphs.
Adds structured fields to the reconciler's grounding context. For each
RAG chunk the specialist emits, GLiNER produces a list of typed spans
with one of five labels:
nyc_location (e.g. "Coney Island")
dollar_amount (e.g. "$5.6 million")
date_range (e.g. "fiscal year 2025-2027")
agency (e.g. "NYC DEP")
infrastructure_project (e.g. "Bluebelt expansion")
The doc_id for emission is `gliner_<source>` where `<source>` is the
RAG chunk's doc_id stripped of its `rag_` prefix. So `rag_comptroller`
becomes `gliner_comptroller`. The reconciler can then cite typed
fields with `[gliner_comptroller]`.
License: Apache-2.0 — `urchade/gliner_medium-v2.1` (NOT the
`gliner_base` variant, which is CC-BY-NC-4.0). See
experiments/shared/licenses.md.
"""
from __future__ import annotations
import logging
import os
from dataclasses import dataclass
# Module-level logger for the GLiNER specialist.
log = logging.getLogger("riprap.gliner")
# The five entity types GLiNER is asked for; see the module docstring
# for an example span of each label.
ENTITY_LABELS = [
    "nyc_location",
    "dollar_amount",
    "date_range",
    "agency",
    "infrastructure_project",
]
# Minimum confidence for a span to be kept; overridable via env var.
# NOTE(review): float() will raise at import time on a malformed
# RIPRAP_GLINER_THRESHOLD value — confirm that is acceptable.
DEFAULT_THRESHOLD = float(os.environ.get("RIPRAP_GLINER_THRESHOLD", "0.45"))
# Hugging Face model id (Apache-2.0 variant; see module docstring re licensing).
MODEL_NAME = os.environ.get("RIPRAP_GLINER_MODEL", "urchade/gliner_medium-v2.1")
# Kill switch: any value other than 1/true/yes disables the specialist.
ENABLE = os.environ.get("RIPRAP_GLINER_ENABLE", "1").lower() in ("1", "true", "yes")
# Lazily-populated GLiNER instance, managed by _ensure_model():
# None = not loaded yet, False = a load attempt failed (sentinel).
_MODEL = None  # lazy
@dataclass
class Extraction:
    """One typed entity span extracted from a RAG chunk."""
    # One of ENTITY_LABELS (e.g. "dollar_amount").
    label: str
    # The matched span text as returned by the model.
    text: str
    # Model confidence score; remote results are filtered against the
    # threshold before construction, local results are thresholded by
    # predict_entities itself.
    score: float
def _ensure_model():
    """Lazily load and memoize the GLiNER model.

    Returns the loaded model, or None when the specialist is disabled
    (RIPRAP_GLINER_ENABLE) or when a load attempt — past or present —
    failed, so callers can silently fall back to a no-op.
    """
    global _MODEL
    if not ENABLE:
        return None
    if _MODEL is not None:
        # _MODEL may be a loaded model OR the `False` sentinel left by an
        # earlier failed load. The original `return _MODEL` leaked `False`
        # to callers that only check `is None` (and then crashed calling
        # .predict_entities on it); normalize the sentinel to None here.
        return _MODEL or None
    try:
        # Imported lazily: gliner is a heavy, optional dependency.
        from gliner import GLiNER
        log.info("gliner: loading %s", MODEL_NAME)
        _MODEL = GLiNER.from_pretrained(MODEL_NAME)
    except Exception:
        log.exception("gliner: load failed; specialist will no-op")
        _MODEL = False  # sentinel: do not retry the load on every call
    return _MODEL or None
def warm():
    """Eagerly trigger the lazy model load (e.g. at process startup)."""
    _ensure_model()
def _source_short(rag_doc_id: str) -> str:
"""`rag_comptroller` -> `comptroller`. Anything not prefixed `rag_`
passes through unchanged."""
return rag_doc_id[4:] if rag_doc_id.startswith("rag_") else rag_doc_id
def extract_for_chunk(text: str, threshold: float = DEFAULT_THRESHOLD) -> list[Extraction]:
    """Run typed-entity extraction over one chunk of text.

    Tries the remote MI300X inference service first (v0.4.5+ — the
    remote hosts its own GLiNER load, so cpu-basic surfaces need not
    bake gliner into the image), then falls back to the local model.
    Returns [] for empty input or when no backend is available.
    """
    if not text:
        return []
    # Import the optional service module in its own guarded try. In the
    # original code the import lived inside the remote try-block, so an
    # ImportError there made Python evaluate `except _inf.RemoteUnreachable`
    # with `_inf` unbound — raising NameError instead of falling back.
    try:
        from app import inference as _inf
    except Exception:
        _inf = None
    if _inf is not None:
        try:
            if _inf.remote_enabled():
                remote = _inf.gliner_extract(text, ENTITY_LABELS)
                if remote.get("ok"):
                    # Remote entities are filtered against the threshold
                    # here; missing scores are treated as 0 (dropped).
                    return [
                        Extraction(label=e["label"], text=e["text"],
                                   score=float(e.get("score", 0)))
                        for e in remote.get("entities", [])
                        if e.get("score", 0) >= threshold
                    ]
        except _inf.RemoteUnreachable as e:
            log.info("gliner: remote unreachable (%s); local fallback", e)
        except Exception:
            log.exception("gliner: remote call failed; local fallback")
    # Local fallback: lazy-loaded model; no-op when disabled or load failed.
    model = _ensure_model()
    if model is None:
        return []
    raw = model.predict_entities(text, ENTITY_LABELS, threshold=threshold)
    return [Extraction(label=r["label"], text=r["text"],
                       score=float(r["score"])) for r in raw]
def extract_for_rag_hits(hits: list[dict],
                         threshold: float = DEFAULT_THRESHOLD,
                         max_hits: int = 3) -> dict[str, dict]:
    """Run GLiNER over the first `max_hits` RAG hits.

    Returns a mapping from short source id (e.g. "comptroller") to the
    structured payload that the FSM stores into state["gliner"] and
    that reconcile.build_documents() consumes. Sources yielding no
    entities are omitted.
    """
    results: dict[str, dict] = {}
    if not hits:
        return results
    for hit in hits[:max_hits]:
        short_id = _source_short(hit.get("doc_id", "rag_unknown"))
        chunk_text = hit.get("text", "")
        extractions = extract_for_chunk(chunk_text, threshold=threshold)
        if not extractions:
            continue
        # Drop verbatim repeats, case-insensitively per label (agency
        # PDFs repeat "DEP" over and over in methodology sections);
        # the first occurrence wins.
        unique: list[Extraction] = []
        seen = set()
        for ent in extractions:
            fingerprint = (ent.label, ent.text.lower())
            if fingerprint not in seen:
                seen.add(fingerprint)
                unique.append(ent)
        excerpt = chunk_text[:240]
        if len(chunk_text) > 240:
            excerpt += "…"
        results[short_id] = {
            "rag_doc_id": hit.get("doc_id"),
            "title": hit.get("title"),
            "paragraph_excerpt": excerpt,
            "n_entities": len(unique),
            "entities": [
                {"label": ent.label, "text": ent.text,
                 "score": round(ent.score, 3)}
                for ent in unique
            ],
        }
    return results