Mobe1's picture
Swap stage-2 v3 → v4 (mAP@.5 0.572 → 0.601; caries 0.035 → 0.071 from DENTEX-c diagnosis labels at 1664x928)
45f026a verified
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.11,<3.12"
# dependencies = [
# "gradio>=5.34.0,<6.0",
# "transformers>=4.45.0,<5.0",
# "torch>=2.5.0,<3.0",
# "pillow>=10.0.0,<12.0",
# "huggingface-hub>=0.25.0,<1.0",
# "httpx>=0.27.0,<1.0",
# ]
# ///
"""Gradio demo: upload an OPG, see argos-dentsight stage-1 + stage-2 detections.
Local run:
uv run demo/app.py
HF Spaces deploy: see demo/README.md (rename to root + push).
Two private model repos are loaded:
- `Mobe1/argos-dentsight-stage1-fdi-v3` — 32 FDI tooth localizer
- `Mobe1/argos-dentsight-stage2-conditions-v4` — 13 dental-condition detector
Both repos are private, so the demo needs HF_TOKEN with read access. Locally
that means a `.env` at the repo root containing HF_TOKEN=...; on HF Spaces it
means setting HF_TOKEN as a Space secret.
"""
from __future__ import annotations
import json
import os
import re
import time
from typing import Any
import gradio as gr
import httpx
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForObjectDetection
try:
    import spaces  # type: ignore[import-not-found]
except ImportError:
    # Running outside HF Spaces (no ZeroGPU package): substitute a decorator
    # that leaves the wrapped function untouched.
    def _gpu(fn=None, **_kwargs):
        """No-op stand-in for `spaces.GPU`.

        Supports both the bare form (`@_gpu`) and the parameterised form
        (`@_gpu(...)`): a callable first argument is returned as-is,
        anything else yields an identity decorator.
        """
        if not callable(fn):
            return lambda f: f
        return fn
else:
    _gpu = spaces.GPU
# Hugging Face model repos, one per pipeline stage (stage-1: FDI teeth,
# stage-2: dental conditions).
REPO_ID_STAGE1 = "Mobe1/argos-dentsight-stage1-fdi-v3"
REPO_ID_STAGE2 = "Mobe1/argos-dentsight-stage2-conditions-v4"
# Default score thresholds; used as the initial values of the UI sliders.
DEFAULT_THR_STAGE1 = 0.02 # v3 score range ~[0.02, 0.08]
# NOTE(review): the range below is quoted for v3, but REPO_ID_STAGE2 now points
# at v4 — confirm the band still holds.
DEFAULT_THR_STAGE2 = 0.005 # v3 cold-start widened the band slightly but still ~[0.005, 0.022]
STAGE2_MAX_DETECTIONS = 30 # cap displayed boxes — model emits 300 raw, most are noise
UNASSIGNED_MAX_DISPLAY = 5 # cap on top-N unassigned conditions surfaced to the user
# ---------------------------------------------------------------------------
# Vendored from src/argos_dentsight/postproc/{constants.py,geometry.py,assignment.py}
# Demo runs as a self-contained UV-style file; we copy the small subset we need
# rather than packaging src/ into the Space build.
# ---------------------------------------------------------------------------
# The 32 FDI tooth codes (quadrants 1-4, positions 1-8). Used to derive which
# FDIs are absent from a set of detections.
FDI_CLASSES: frozenset[str] = frozenset(
{
"11", "12", "13", "14", "15", "16", "17", "18",
"21", "22", "23", "24", "25", "26", "27", "28",
"31", "32", "33", "34", "35", "36", "37", "38",
"41", "42", "43", "44", "45", "46", "47", "48",
}
)
# FDI codes treated as wisdom teeth (position 8 of each quadrant).
WISDOM_TEETH: frozenset[str] = frozenset({"18", "28", "38", "48"})
# Stage-2 class names reported as pathological findings.
PATHOLOGICAL_CLASSES: frozenset[str] = frozenset(
{
"caries",
"calculus",
"impacted",
"periapical-radiolucency",
"root-stump",
"missing",
"other-finding",
}
)
# Stage-2 class names reported as historical treatments.
TREATMENT_CLASSES: frozenset[str] = frozenset(
{"RC-treated", "crown", "restoration", "bridge", "implant", "tooth-bud"}
)
# Minimum IoMin overlap for a condition to be assigned to a tooth.
ASSIGNMENT_OVERLAP_THRESHOLD: float = 0.5 # IoMin ≥ 0.5 (smaller box ≥50% inside larger)
# Paediatric-class suppression: see filtering.suppress_paediatric_classes_on_adult_opgs.
# Stage-2 misclassifies bright high-density material (crowns, fillings) as
# `tooth-bud` at scores up to 0.028 — comparable to legitimate detections.
# `tooth-bud` is paediatric anatomy; on an adult OPG (≥20 detected FDIs)
# we drop them. Below the threshold (mixed dentition) we leave them alone.
ADULT_OPG_FDI_THRESHOLD: int = 20
PAEDIATRIC_CLASSES: frozenset[str] = frozenset({"tooth-bud"})
def _suppress_paediatric_on_adult(
    teeth_preds: list[dict[str, Any]],
    cond_preds: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Drop paediatric-class conditions when the image looks like an adult OPG.

    The image counts as adult when at least `ADULT_OPG_FDI_THRESHOLD` teeth
    were detected. Below that count every condition is kept. A new list is
    always returned; `cond_preds` itself is never mutated.
    """
    looks_adult = len(teeth_preds) >= ADULT_OPG_FDI_THRESHOLD
    kept: list[dict[str, Any]] = []
    for cond in cond_preds:
        if looks_adult and cond["class_name"] in PAEDIATRIC_CLASSES:
            continue
        kept.append(cond)
    return kept
def _intersection_area(box1: list[float], box2: list[float]) -> float:
    """Return the overlap area of two `[x1, y1, x2, y2]` boxes (0 if disjoint)."""
    a_left, a_top, a_right, a_bottom = box1
    b_left, b_top, b_right, b_bottom = box2
    # Clamp each extent at zero so disjoint boxes contribute no area.
    overlap_w = max(0.0, min(a_right, b_right) - max(a_left, b_left))
    overlap_h = max(0.0, min(a_bottom, b_bottom) - max(a_top, b_top))
    return overlap_w * overlap_h
def _box_area(box: list[float]) -> float:
    """Return the area of an `[x1, y1, x2, y2]` box; inverted boxes give 0."""
    left, top, right, bottom = box
    width = max(0.0, right - left)
    height = max(0.0, bottom - top)
    return width * height
def _compute_iomin(box1: list[float], box2: list[float]) -> float:
    """Intersection over min(area1, area2). Symmetric, range [0, 1].

    Vendored from postproc.geometry.compute_iomin. It replaces an earlier
    one-directional containment metric (intersection / cond_area): RC-treated
    detections often have bboxes slightly larger than the tooth, so that
    metric fell below 0.5 even when the condition fully covered the tooth.
    Normalising by the smaller box removes the dependence on which box is
    bigger.

    Returns 0.0 when the smaller box has no area.
    """
    areas = (_box_area(box1), _box_area(box2))
    denominator = min(areas)
    if denominator <= 0.0:
        return 0.0
    return _intersection_area(box1, box2) / denominator
def _assign_diseases_to_teeth(
    teeth_preds: list[dict[str, Any]],
    cond_preds: list[dict[str, Any]],
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Vendored from postproc.assignment.assign_diseases_to_teeth.

    Each condition goes to the tooth with the highest IoMin overlap
    (intersection / area-of-smaller-box), provided that overlap is at least
    `ASSIGNMENT_OVERLAP_THRESHOLD`. On ties the earlier tooth wins.
    `implant` detections are never assigned.

    Returns:
        structured_teeth: one `{fdi, score, bbox, conditions: [...]}` dict per
            tooth, in the order of `teeth_preds`.
        unassigned: the conditions that matched no tooth, in their original
            order.
    """
    structured_teeth: list[dict[str, Any]] = []
    for tooth in teeth_preds:
        structured_teeth.append(
            {
                "fdi": tooth["fdi"],
                "score": tooth["score"],
                "bbox": tooth["bbox"],
                "conditions": [],
            }
        )

    unassigned: list[dict[str, Any]] = []
    for cond in cond_preds:
        if cond["class_name"] == "implant":
            unassigned.append(cond)
            continue

        winner: dict[str, Any] | None = None
        top_overlap = 0.0
        # structured_teeth is index-aligned with teeth_preds.
        for slot, tooth in zip(structured_teeth, teeth_preds):
            overlap = _compute_iomin(cond["bbox"], tooth["bbox"])
            if overlap > top_overlap:
                top_overlap, winner = overlap, slot

        if winner is not None and top_overlap >= ASSIGNMENT_OVERLAP_THRESHOLD:
            winner["conditions"].append(cond)
        else:
            unassigned.append(cond)

    return structured_teeth, unassigned
# ---------------------------------------------------------------------------
# Phase 8 — LLM medical-report generation (vendored from src/argos_dentsight/llm/)
#
# Same vendoring strategy as the postproc block: copy the small subset the
# Space needs rather than packaging src/ into the build. If you change either
# the prompt or the grounding-caption format here, mirror the change in
# src/argos_dentsight/llm/{prompts.py,grounding_caption.py} or the offline
# batch outputs will diverge.
# ---------------------------------------------------------------------------
# URL the streaming chat request is POSTed to.
OLLAMA_ENDPOINT: str = "https://ollama.com/api/chat"
# Timeout (seconds) handed to the HTTP client for the chat request.
OLLAMA_TIMEOUT_S: float = 180.0 # gemma4:31b can take 30-90s on long captions
# Allow-list of model tags; report generation rejects any model not listed here.
REPORT_MODEL_OPTIONS: list[str] = [
# Each entry empirically verified against https://ollama.com/api/chat with
# a 2-token probe. Tags listed on Ollama's model pages but not served on
# the cloud chat endpoint (e.g. gemma4:26b, gemma4:e4b) 404 here even
# though they exist for local `ollama run`.
"gemma4:31b", # ~7s, balanced default
"deepseek-v4-flash", # fastest reasoning option (~1s warm)
"deepseek-v4-pro", # most thorough; verbose (~28s end-to-end)
"kimi-k2.6", # multimodal-capable, ~2s warm
"gemini-3-flash-preview", # Google flash variant, fast
]
DEFAULT_REPORT_MODEL: str = "gemma4:31b"
# Generator-throttle interval. The model emits ~22ms/token (gemma) or ~15ms/
# token (deepseek). Yielding to Gradio per-chunk = 50-70 UI updates/sec,
# which the browser's Markdown re-layout cannot keep up with as the
# accumulated text grows — perceived as a "slowdown midway". Coalescing
# into ~14 yields/sec keeps the stream visually smooth without changing
# total time-to-completion.
PROGRESS_YIELD_INTERVAL_S: float = 0.07
# Vendored verbatim from src/argos_dentsight/llm/prompts.py::REPORT_SYSTEM_PROMPT
# System prompt sent as the `system` message of the report-generation chat
# request. The user message is the grounding caption built from the detections.
# The "≥ 0.80" confidence rule must contain a real U+2265 character: a
# mis-encoded byte sequence there leaves the rule unreadable to the model.
REPORT_SYSTEM_PROMPT: str = """
You are a professional oral radiologist assistant tasked with generating precise and clinically accurate oral panoramic X-ray examination reports based on structured localization data.
The structured data contains all detected teeth and dental conditions. Each condition is associated with a specific tooth number.
If a finding is not directly on a tooth, it will have 'tooth_id': 'unknown' and a 'near_tooth': '[tooth_id]' field, which you should report as "near tooth #[tooth_id]".
Generate a formal and comprehensive oral examination report **ONLY** containing two mandatory sections:
1. **Teeth-Specific Observations**
2. **Clinical Summary & Recommendations**
The **Teeth-Specific Observations** section must comprise three subsections:
1. **General Condition**: Outlines overall dental status, including the count of visualized teeth and wisdom teeth status (e.g., presence or impaction).
2. **Pathological Findings**: Documents dental diseases such as caries, impacted teeth, calculus, or periapical radiolucency.
3. **Historical Interventions**: Details prior treatments like fillings (restorations), crowns, root canal treatments, or implants.
Each finding in the structured data has a confidence score. You must apply the following processing rules **ONLY** for the **Pathological Findings** subsection:
* For confidence scores **< 0.80**: Use terms like "suspicious for...", "suggests...", or "areas of concern noted for..." in the description.
* For confidence scores **≥ 0.80**: Use definitive descriptors such as "sign of...", "shows evidence of...", or "clear indication of...".
The **Historical Interventions** subsection should always use definitive language (e.g., "presence of a crown," "rc-treated tooth noted"), as these are observed facts.
Please strictly follow the following requirements:
* **Adherence to FDI numbering system** (e.g., "#11", "#26").
* **Use professional medical terminology** while maintaining clarity.
* **DO NOT** include or reference the confidence scores in any form in the final report. Their *only* use is to determine the certainty language ("suspicious" vs. "sign of").
* **DO NOT** generate any administrative content like 'Patient Name', 'Date', etc.
* **Generate a new Clinical Summary & Recommendations** section. This section is critical and must be created from the findings. It must include:
1. **Priority Concerns**: The most urgent issues found (e.g., "Deep caries on #28", "Impacted wisdom tooth #18 requiring evaluation").
2. **Preventive Measures**: Recommendations for prevention (e.g., "Monitor areas of suspected calculus", "Reinforce oral hygiene").
3. **Follow-up Protocol**: Specific recall or follow-up actions (e.g., "6-month recall for monitoring", "Referral to endodontist for #26").
Now, generate a new report for the following input:
"""
def _clean_llm_output(raw_text: str) -> str:
    """Strip a leading `<think>...</think>` block (Gemma 4 reasoning trace).

    Vendored from src/argos_dentsight/llm/output_cleaner.py::clean_llm_output.
    Everything up to and including the first `</think>` (matched
    case-insensitively) is discarded. Text without that tag is only
    whitespace-trimmed.
    """
    closing_tag = re.search(r"</think>", raw_text, re.IGNORECASE)
    remainder = raw_text if closing_tag is None else raw_text[closing_tag.end():]
    return remainder.strip()
def _format_section(title: str, items: list) -> str:
    """Render one caption section: a titled, bracketed list of JSON-ish rows.

    Each item is JSON-encoded (non-ASCII kept as-is), its double quotes are
    swapped for single quotes, and it is written on its own line with a
    one-space prefix. An empty `items` produces an empty bracket pair under
    a `total: 0` header.
    """
    if not items:
        return f"{title} (total: 0):\n[\n]"
    rows = []
    for item in items:
        encoded = json.dumps(item, ensure_ascii=False)
        rows.append(" " + encoded.replace('"', "'"))
    joined_rows = ",\n".join(rows)
    return f"{title} (total: {len(items)}):\n[\n{joined_rows}\n]"
def _build_grounding_caption(per_tooth_json: list[dict[str, Any]]) -> str:
    """Render the LLM-ready structured caption from the demo's per-tooth JSON.

    `per_tooth_json` is the output of `_per_tooth_outputs`: one dict per
    detected tooth (recognised by its `fdi` key) plus a dict carrying
    `unassigned_top` / `unassigned_total`. This function adapts that shape
    into the multi-section caption expected by `REPORT_SYSTEM_PROMPT`.

    Sections emitted (in order):
      - Teeth visibility with center points
      - Wisdom teeth detection
      - Teeth Not Visualized in This Image (FDIs absent from the detections)
      - Dental Pathological Findings
      - Historical Treatments

    Conditions assigned to a tooth inherit that tooth's FDI. Unassigned ones
    get `tooth_id='unknown'` plus a `near_tooth` pointer to the tooth whose
    center is closest, when any tooth was detected. A condition whose class
    is in neither `PATHOLOGICAL_CLASSES` nor `TREATMENT_CLASSES` is omitted.
    """
    teeth = [t for t in per_tooth_json if "fdi" in t]
    trailing = next((t for t in per_tooth_json if "unassigned_top" in t), {})
    unassigned = trailing.get("unassigned_top", [])

    def _center(box: list[float]) -> tuple[float, float]:
        """Center point of an `[x1, y1, x2, y2]` box."""
        x1, y1, x2, y2 = box
        return (x1 + x2) / 2, (y1 + y2) / 2

    # (fdi, cx, cy) per tooth; reused for the visibility section and for the
    # `near_tooth` lookup on unassigned conditions.
    tooth_centers = [(t["fdi"], *_center(t["bbox_xyxy"])) for t in teeth]

    visibility = [
        {
            "point_2d": [round(cx), round(cy)],
            "tooth_id": fdi,
            "score": round(t.get("tooth_score", 0), 2),
        }
        for t, (fdi, cx, cy) in zip(teeth, tooth_centers)
    ]

    # Wisdom teeth: the detected teeth whose FDI is in WISDOM_TEETH.
    wisdom = [
        {
            "box_2d": [round(v) for v in t["bbox_xyxy"]],
            "tooth_id": t["fdi"],
            "is_impacted": any(
                c["class_name"] == "impacted" for c in t.get("conditions", [])
            ),
            "score": round(t.get("tooth_score", 0), 2),
        }
        for t in teeth
        if t["fdi"] in WISDOM_TEETH
    ]

    # Teeth Not Visualized: FDIs not present in the detector's output.
    # Deliberately not labelled "Missing Teeth": empirically the LLM treats
    # that label as a clinical absence-finding, whereas an FDI absent from
    # this list only means the model did not surface it.
    detected_fdis = {t["fdi"] for t in teeth}
    missing_items = [{"tooth_id": fdi} for fdi in sorted(FDI_CLASSES - detected_fdis)]

    def _closest_fdi(box: list[float]) -> str | None:
        """FDI of the tooth whose center is nearest to `box`'s center."""
        if not tooth_centers:
            return None
        cx, cy = _center(box)
        nearest = min(
            tooth_centers,
            key=lambda c: (cx - c[1]) ** 2 + (cy - c[2]) ** 2,
        )
        return nearest[0]

    pathological: list[dict[str, Any]] = []
    historical: list[dict[str, Any]] = []

    def _file_finding(
        cond: dict[str, Any], tooth_id: str, near_tooth: str | None = None
    ) -> None:
        """Build the caption entry for `cond` and append it to its section."""
        entry: dict[str, Any] = {
            "box_2d": [round(b) for b in cond["bbox_xyxy"]],
            "tooth_id": tooth_id,
            "label": cond["class_name"],
            "score": round(cond.get("score", 0), 2),
        }
        if near_tooth is not None:
            entry["near_tooth"] = near_tooth
        if cond["class_name"] in PATHOLOGICAL_CLASSES:
            pathological.append(entry)
        elif cond["class_name"] in TREATMENT_CLASSES:
            historical.append(entry)

    for t in teeth:
        for c in t.get("conditions", []):
            _file_finding(c, t["fdi"])
    for c in unassigned:
        _file_finding(c, "unknown", _closest_fdi(c["bbox_xyxy"]))

    parts: list[str] = [
        "This localization caption provides multi-dimensional spatial analysis "
        "of the patient's panoramic dental radiograph.",
        _format_section("Teeth visibility with center points", visibility),
        _format_section("Wisdom teeth detection", wisdom),
        _format_section(
            "Teeth Not Visualized in This Image (model did not surface a "
            "detection — could be genuinely absent or below detection threshold)",
            missing_items,
        ),
        _format_section("Dental Pathological Findings", pathological),
        _format_section("Historical Treatments", historical),
    ]
    return "\n\n".join(parts)
def _stream_ollama_chat(
    *, model: str, system: str, user: str, api_key: str
):
    """Generator: yield assistant content chunks from Ollama Cloud's streaming
    chat endpoint.

    Each yielded chunk is the *incremental* text since the last chunk (not
    cumulative); callers accumulate. Streaming lets users watch the report
    appear token-by-token instead of waiting for the whole response.

    Args:
        model: Model tag placed in the request payload.
        system: Content of the `system` message.
        user: Content of the `user` message.
        api_key: Bearer token for the `Authorization` header. It is never
            included in any raised error message.

    Raises:
        RuntimeError: on a non-200 response, on an `error` object delivered
            mid-stream, on timeout, or on any other `httpx` transport
            failure. Messages are safe to show to users.
    """
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
        "stream": True,
    }
    try:
        with httpx.Client(timeout=OLLAMA_TIMEOUT_S) as client:
            with client.stream(
                "POST", OLLAMA_ENDPOINT, headers=headers, json=payload
            ) as r:
                if r.status_code != 200:
                    body_bytes = b"".join(r.iter_bytes())
                    try:
                        err = json.loads(body_bytes).get("error", "")
                    except (ValueError, AttributeError):
                        # Body is not a JSON object: show a short raw excerpt.
                        err = body_bytes.decode("utf-8", "replace")[:200]
                    raise RuntimeError(
                        f"Ollama Cloud HTTP {r.status_code} on `{model}`: {err}"
                    )
                for line in r.iter_lines():
                    if not line:
                        continue
                    try:
                        chunk = json.loads(line)
                    except json.JSONDecodeError:
                        continue
                    if not isinstance(chunk, dict):
                        continue
                    # A failure after the 200 status line arrives as an
                    # `error` object inside the stream; surface it instead of
                    # silently ending with a truncated report.
                    if chunk.get("error"):
                        raise RuntimeError(
                            f"Ollama Cloud error on `{model}`: {chunk['error']}"
                        )
                    content = (chunk.get("message") or {}).get("content", "")
                    if content:
                        yield content
                    if chunk.get("done"):
                        break
    except httpx.TimeoutException as e:
        # Suggest a tag that is actually served by the cloud endpoint.
        raise RuntimeError(
            f"Ollama Cloud timed out after {OLLAMA_TIMEOUT_S:.0f}s on `{model}`. "
            "Try a faster model (e.g. deepseek-v4-flash) or rerun."
        ) from e
    except httpx.HTTPError as e:
        raise RuntimeError(f"Network error calling Ollama Cloud: {type(e).__name__}") from e
def _clean_streaming_view(accumulated: str) -> str | None:
    """Cleaned output for progressive display while a stream is in flight.

    Returns:
        `None` while a `<think>` tag has been seen without its closing
        `</think>` (the caller should keep its placeholder visible).
        Otherwise the result of `_clean_llm_output` on the accumulated text.
    """
    folded = accumulated.lower()
    think_opened = "<think>" in folded
    think_closed = "</think>" in folded
    if think_opened and not think_closed:
        return None
    return _clean_llm_output(accumulated)
def _load_token() -> str | None:
    """Return HF_TOKEN from the environment, else from `../.env`, else None.

    The `.env` file is looked up one directory above this file. The first
    line starting with `HF_TOKEN=` wins; surrounding whitespace and quote
    characters are stripped from its value.
    """
    from_env = os.environ.get("HF_TOKEN")
    if from_env:
        return from_env

    dotenv_path = os.path.join(os.path.dirname(__file__), "..", ".env")
    if not os.path.exists(dotenv_path):
        return None

    prefix = "HF_TOKEN="
    with open(dotenv_path) as handle:
        for raw_line in handle:
            entry = raw_line.strip()
            if entry.startswith(prefix):
                return entry[len(prefix):].strip().strip("\"'")
    return None
# Module-level model loading: runs once at import time, before the UI is
# built. Both processors and models are fetched with the same token.
print(f"Loading stage-1: {REPO_ID_STAGE1}...")
TOKEN = _load_token()
PROC_S1 = AutoImageProcessor.from_pretrained(REPO_ID_STAGE1, token=TOKEN)
MODEL_S1 = AutoModelForObjectDetection.from_pretrained(REPO_ID_STAGE1, token=TOKEN)
MODEL_S1.eval()
# Class-id -> FDI label mapping taken from the model config.
ID2LABEL_S1 = MODEL_S1.config.id2label
print(f"Loading stage-2: {REPO_ID_STAGE2}...")
PROC_S2 = AutoImageProcessor.from_pretrained(REPO_ID_STAGE2, token=TOKEN)
MODEL_S2 = AutoModelForObjectDetection.from_pretrained(REPO_ID_STAGE2, token=TOKEN)
MODEL_S2.eval()
# Class-id -> condition-name mapping taken from the model config.
ID2LABEL_S2 = MODEL_S2.config.id2label
print(
f"Loaded. Stage-1: {len(ID2LABEL_S1)} FDIs. "
f"Stage-2: {len(ID2LABEL_S2)} conditions."
)
def _stage1_outputs(
    res: dict[str, Any], pil: Image.Image
) -> tuple[Any, list[dict[str, Any]], list[dict[str, Any]]]:
    """Per-class top-1 dedup for stage-1: one box per FDI (the class argmax).

    Returns `(annotated_image_tuple, json_for_ui, raw_preds_for_assignment)`.
    The third element is internal-only. It holds the predictions in
    postproc-compatible shape, `{fdi, score, bbox: [x1, y1, x2, y2]}`, for
    matching conditions onto teeth. `json_for_ui` is sorted by FDI; the
    other two keep first-seen class order.
    """
    all_boxes = res["boxes"].cpu().tolist()
    all_scores = res["scores"].cpu().tolist()
    all_labels = res["labels"].cpu().tolist()

    # Keep only the highest-scoring box per class id.
    top_by_class: dict[int, tuple[float, list[float]]] = {}
    for box, score, raw_id in zip(all_boxes, all_scores, all_labels, strict=True):
        class_id = int(raw_id)
        incumbent = top_by_class.get(class_id)
        if incumbent is None or score > incumbent[0]:
            top_by_class[class_id] = (score, box)

    annotations: list[tuple[tuple[int, int, int, int], str]] = []
    ui_rows: list[dict[str, Any]] = []
    raw_preds: list[dict[str, Any]] = []
    for class_id, (score, box) in top_by_class.items():
        left, top, right, bottom = map(round, box)
        fdi = ID2LABEL_S1.get(int(class_id), str(class_id))
        annotations.append(((left, top, right, bottom), f"{fdi} ({score:.2f})"))
        ui_rows.append(
            {
                "fdi": fdi,
                "score": round(score, 3),
                "bbox_xyxy": [round(v, 1) for v in box],
            }
        )
        raw_preds.append({"fdi": fdi, "score": float(score), "bbox": list(box)})

    detections = sorted(ui_rows, key=lambda row: row["fdi"])
    return (pil, annotations), detections, raw_preds
def _stage2_outputs(
    res: dict[str, Any], pil: Image.Image
) -> tuple[Any, list[dict[str, Any]], list[dict[str, Any]]]:
    """Stage-2: keep the top-N detections by score, highest first.

    There is no per-class dedup, because several instances of one class are
    valid (multiple crowns, multiple caries lesions). The list is capped at
    `STAGE2_MAX_DETECTIONS` so the UI doesn't drown in low-confidence noise.

    Returns `(annotated_image_tuple, json_for_ui, raw_preds_for_assignment)`,
    where each raw pred is `{class_name, score, bbox: [x1, y1, x2, y2]}`.
    """
    triples = zip(
        res["boxes"].cpu().tolist(),
        res["scores"].cpu().tolist(),
        res["labels"].cpu().tolist(),
        strict=True,
    )
    ranked = sorted(triples, key=lambda item: item[1], reverse=True)
    shortlist = ranked[:STAGE2_MAX_DETECTIONS]

    annotations: list[tuple[tuple[int, int, int, int], str]] = []
    detections: list[dict[str, Any]] = []
    raw_preds: list[dict[str, Any]] = []
    for box, score, raw_id in shortlist:
        left, top, right, bottom = map(round, box)
        condition = ID2LABEL_S2.get(int(raw_id), str(raw_id))
        annotations.append(
            ((left, top, right, bottom), f"{condition} ({score:.3f})")
        )
        detections.append(
            {
                "condition": condition,
                "score": round(score, 4),
                "bbox_xyxy": [round(v, 1) for v in box],
            }
        )
        raw_preds.append(
            {"class_name": condition, "score": float(score), "bbox": list(box)}
        )
    return (pil, annotations), detections, raw_preds
def _category_for(class_name: str) -> str:
    """Map a stage-2 class name to its per-tooth panel category.

    Returns "pathological", "treatment", or "other" for names in neither set.
    """
    buckets = (
        ("pathological", PATHOLOGICAL_CLASSES),
        ("treatment", TREATMENT_CLASSES),
    )
    for category, members in buckets:
        if class_name in members:
            return category
    return "other"
def _per_tooth_outputs(
    teeth_preds: list[dict[str, Any]],
    cond_preds: list[dict[str, Any]],
    pil: Image.Image,
) -> tuple[Any, str, list[dict[str, Any]]]:
    """Build the per-tooth findings tab outputs.

    Pipeline:
      1. Suppress paediatric classes (tooth-bud) on adult OPGs.
      2. Match conditions to teeth by best IoMin overlap
         (`_assign_diseases_to_teeth`).
      3. Render annotations only for teeth WITH findings, so the image is not
         cluttered by tooth boxes the Teeth (FDI) tab already shows.
      4. Emit JSON for downstream LLM-report consumption: one entry per tooth
         (with or without findings, since the LLM needs to know what was
         visualized) plus a trailing unassigned block.

    Args:
        teeth_preds: stage-1 predictions, `{fdi, score, bbox}` each.
        cond_preds: stage-2 predictions, `{class_name, score, bbox}` each.
        pil: the image the annotations are drawn on.

    Returns:
        `((pil, annotations), markdown, json_out)`. The last element of
        `json_out` is always `{unassigned_top, unassigned_total}`.
    """
    suppressed_conds = _suppress_paediatric_on_adult(teeth_preds, cond_preds)
    suppressed_count = len(cond_preds) - len(suppressed_conds)
    structured_teeth, unassigned = _assign_diseases_to_teeth(teeth_preds, suppressed_conds)
    # Sort teeth by FDI ascending for stable display.
    structured_teeth.sort(key=lambda t: t["fdi"])

    annotations: list[tuple[tuple[int, int, int, int], str]] = []
    md_lines: list[str] = ["## Per-tooth findings", ""]
    json_out: list[dict[str, Any]] = []
    teeth_with_findings = 0
    for tooth in structured_teeth:
        fdi = tooth["fdi"]
        x1, y1, x2, y2 = (round(c) for c in tooth["bbox"])
        cond_names = [c["class_name"] for c in tooth["conditions"]]
        if cond_names:
            teeth_with_findings += 1
            label = f"{fdi}: {', '.join(cond_names)}"
            md_lines.append(
                f"- **Tooth {fdi}**: "
                + ", ".join(
                    f"{c['class_name']} ({c['score']:.3f})"
                    for c in tooth["conditions"]
                )
            )
            # Only teeth WITH findings are annotated; the Teeth (FDI) tab
            # already shows the full per-class argmax of stage-1.
            annotations.append(((x1, y1, x2, y2), label))
        json_out.append(
            {
                "fdi": fdi,
                "tooth_score": round(tooth["score"], 3),
                "bbox_xyxy": [round(c, 1) for c in tooth["bbox"]],
                "is_wisdom": fdi in WISDOM_TEETH,
                "conditions": [
                    {
                        "class_name": c["class_name"],
                        "score": round(c["score"], 4),
                        "category": _category_for(c["class_name"]),
                        "bbox_xyxy": [round(b, 1) for b in c["bbox"]],
                    }
                    for c in tooth["conditions"]
                ],
            }
        )
    if teeth_with_findings == 0:
        md_lines.append("- _No teeth had overlap-matched stage-2 findings._")
    if suppressed_count > 0:
        md_lines.append(
            f"\n_Adult-OPG heuristic active: {suppressed_count} paediatric-class "
            f"detection(s) (tooth-bud) suppressed because ≥{ADULT_OPG_FDI_THRESHOLD} "
            f"FDIs were detected._"
        )

    # Top-N unassigned conditions (most-confident orphan stage-2 detections).
    unassigned_sorted = sorted(unassigned, key=lambda c: -c["score"])[:UNASSIGNED_MAX_DISPLAY]
    md_lines.append("")
    md_lines.append("### Unassigned conditions (likely false positives)")
    if not unassigned_sorted:
        md_lines.append("- _None._")
    else:
        # Tally ALL unassigned detections by class, not just the displayed
        # top-N, for a brief summary line.
        tally: dict[str, int] = {}
        for c in unassigned:
            tally[c["class_name"]] = tally.get(c["class_name"], 0) + 1
        tally_str = ", ".join(
            f"{name} x {count}" for name, count in sorted(tally.items())
        )
        md_lines.append(f"_All unassigned (incl. low-rank): {tally_str}_")
        md_lines.append("")
        for c in unassigned_sorted:
            md_lines.append(
                f"- {c['class_name']} ({c['score']:.3f}) — bbox "
                f"{[round(b, 1) for b in c['bbox']]}"
            )
    json_out.append(
        {
            "unassigned_top": [
                {
                    "class_name": c["class_name"],
                    "score": round(c["score"], 4),
                    "bbox_xyxy": [round(b, 1) for b in c["bbox"]],
                }
                for c in unassigned_sorted
            ],
            "unassigned_total": len(unassigned),
        }
    )
    md = "\n".join(md_lines)
    return (pil, annotations), md, json_out
@_gpu
def _analyze(
    image: Image.Image | None, thr_s1: float, thr_s2: float
) -> tuple[
    Any,
    list[dict[str, Any]],
    Any,
    list[dict[str, Any]],
    Any,
    str,
    list[dict[str, Any]],
]:
    """Run both detectors on `image` and assemble every tab's outputs.

    Returns `(teeth_ann, teeth_json, conditions_ann, conditions_json,
    per_tooth_ann, per_tooth_md, per_tooth_json)`. With no image, the
    annotation slots are `None`, the JSON slots are empty lists and the
    Markdown is an empty string.
    """
    if image is None:
        return None, [], None, [], None, "", []

    pil = image.convert("RGB")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    target_sizes = torch.tensor([[pil.height, pil.width]], device=device)

    def _detect(model: Any, processor: Any, threshold: float) -> dict[str, Any]:
        """One forward pass plus post-processing; returns the single result."""
        placed = model.to(device)
        batch = processor(images=pil, return_tensors="pt").to(device)
        with torch.no_grad():
            raw = placed(**batch)
        return processor.post_process_object_detection(
            raw, threshold=threshold, target_sizes=target_sizes
        )[0]

    # Stage-1: FDI tooth detection.
    teeth_ann, teeth_json, teeth_raw = _stage1_outputs(
        _detect(MODEL_S1, PROC_S1, thr_s1), pil
    )
    # Stage-2: condition detection (same image, fresh forward pass).
    cond_ann, cond_json, cond_raw = _stage2_outputs(
        _detect(MODEL_S2, PROC_S2, thr_s2), pil
    )
    # Per-tooth assembly: match stage-2 conditions onto stage-1 teeth.
    per_tooth_ann, per_tooth_md, per_tooth_json = _per_tooth_outputs(
        teeth_raw, cond_raw, pil
    )
    return (
        teeth_ann,
        teeth_json,
        cond_ann,
        cond_json,
        per_tooth_ann,
        per_tooth_md,
        per_tooth_json,
    )
def _generate_report(
    per_tooth_json: list[dict[str, Any]] | None,
    model: str,
):
    """Streaming generator that orchestrates the LLM AI-radiology-report call.

    Yields, in order:
      1. A placeholder immediately (instant feedback).
      2. The report content progressively as Ollama Cloud streams tokens.
         Each yield carries the latest cleaned accumulated text, throttled
         to one update per `PROGRESS_YIELD_INTERVAL_S`. While a `<think>`
         block is still open nothing new is yielded, so the placeholder
         stays visible.
      3. The final cleaned text once the stream ends, if it differs from the
         last thing shown.

    Problems (no detections, missing API key, disallowed model, stream
    failure, empty response) are yielded as Markdown messages rather than
    raised, so Gradio doesn't surface stack traces to users.
    """
    if not per_tooth_json:
        yield (
            "_No detection results available. Upload an OPG and click "
            "**Analyze** first, then come back to this tab._"
        )
        return

    api_key = os.environ.get("OLLAMA_API_KEY")
    if not api_key:
        yield (
            "**Configuration error:** `OLLAMA_API_KEY` is not set on this Space. "
            "An admin must add it as a Space secret before this tab can be used."
        )
        return

    if model not in REPORT_MODEL_OPTIONS:
        yield f"**Invalid model:** `{model}` is not in the allowed list."
        return

    placeholder = (
        f"⏳ **Generating AI radiology report with `{model}`…**\n\n"
        "Tokens will appear here as the model produces them. Cold Ollama "
        "Cloud models take 30-60s on the first request, then ~5-10s for "
        "subsequent ones. DeepSeek and Qwen variants are more verbose and "
        "may take 20-40s end-to-end."
    )
    yield placeholder

    caption = _build_grounding_caption(per_tooth_json)
    buffer = ""
    shown = placeholder
    shown_at = time.monotonic()
    try:
        for chunk in _stream_ollama_chat(
            model=model,
            system=REPORT_SYSTEM_PROMPT,
            user=caption,
            api_key=api_key,
        ):
            buffer += chunk
            tick = time.monotonic()
            if tick - shown_at >= PROGRESS_YIELD_INTERVAL_S:
                view = _clean_streaming_view(buffer)
                # A None/empty view means we are still inside a <think>
                # block (or have nothing yet): keep what is on screen.
                if view and view != shown:
                    yield view
                    shown, shown_at = view, tick
    except RuntimeError as e:
        yield f"**Report generation failed.** {e}"
        return

    # Flush: the tail of the stream may have arrived inside the throttle
    # window and never been rendered.
    final = _clean_llm_output(buffer)
    if not final:
        yield (
            f"**Empty response from `{model}`.** The model returned no content "
            "after stripping its reasoning block. Try a different model."
        )
        return
    if final != shown:
        yield final
# ---------------------------------------------------------------------------
# UI layout and event wiring.
#
# Left column: image upload, the two score-threshold sliders and the Analyze
# button. Right column: one tab per view (teeth, conditions, per-tooth
# findings, LLM report). The Markdown bodies are kept at column 0 so their
# rendering does not depend on any dedenting of the string.
# ---------------------------------------------------------------------------
with gr.Blocks(title="Argos-DentSight: FDI + Conditions Demo") as demo:
    gr.Markdown(
        """
# Argos-DentSight — Two-Stage OPG Demo

Upload an OPG (orthopantomogram / panoramic dental X-ray) to run both
models on it.

- **Stage 1** — `Mobe1/argos-dentsight-stage1-fdi-v3` — 32 FDI tooth
  positions. Trained eval `mAP@.5 = 0.853`.
- **Stage 2** — `Mobe1/argos-dentsight-stage2-conditions-v4` — 13 dental
  condition labels (caries, calculus, RC-treated, impacted, restoration,
  crown, periapical-radiolucency, root-stump, bridge, implant, tooth-bud,
  missing, other-finding). Trained eval `mAP@.5 = 0.601` (warm-started
  from v3 + DENTEX-c diagnosis labels at 1664×928 input; biggest gains
  on caries 0.035 → 0.071 and impacted 0.685 → 0.729). Caries / calculus
  / periapical-radiolucency still fail per-class clinical gates by
  3–4× — they are scale-bound, not classifier-bound.

> **About the confidence scores:** both models' classification weights
> are well-trained but their *biases* are stuck near the focal-loss
> prior, so absolute sigmoid scores are compressed to roughly
> `[0.02, 0.08]`. The *ranking* is correct — the demo deliberately uses
> low default thresholds and (for stage-1) per-class argmax so you see
> the model's actual top guess. **Treat scores as a confidence
> ordering, not a probability.**

> **Stage-2 caveat:** the bias-collapse issue is more severe here than
> on stage-1 — score range is roughly `[0.005, 0.015]` with little
> between-class separation. Eval mAP@.5 = 0.601 means the model gets
> ranking right *on average across the validation set*, but on any
> single OPG it may surface confident-looking detections of conditions
> that aren't there. Strong classes (crown, RC-treated, bridge, implant,
> impacted, tooth-bud, root-stump) are more reliable than weak ones
> (caries, calculus, periapical-radiolucency, other-finding all failed
> the project's per-class acceptance gates by 3-15× margins).
"""
    )

    with gr.Row():
        # ---- Inputs ------------------------------------------------------
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload OPG", height=400)
            thr_stage1 = gr.Slider(
                minimum=0.0,
                maximum=0.5,
                value=DEFAULT_THR_STAGE1,
                step=0.01,
                label="Stage-1 threshold (FDI teeth)",
                # The default is interpolated from the constant so the help
                # text cannot drift from the slider's actual initial value.
                info=(
                    f"v3 score range ~[0.02, 0.08]; default "
                    f"{DEFAULT_THR_STAGE1} surfaces "
                    "per-class argmax for most teeth."
                ),
            )
            thr_stage2 = gr.Slider(
                minimum=0.0,
                maximum=0.05,
                value=DEFAULT_THR_STAGE2,
                step=0.001,
                label="Stage-2 threshold (conditions)",
                info=(
                    "v3 score range is compressed (~[0.005, 0.022]); default "
                    f"{DEFAULT_THR_STAGE2} surfaces top guesses. "
                    f"Top {STAGE2_MAX_DETECTIONS} shown "
                    "to avoid noise. Raise to filter; lower to see everything."
                ),
            )
            run_button = gr.Button("Analyze", variant="primary")

        # ---- Outputs -----------------------------------------------------
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.Tab("Teeth (FDI)"):
                    output_teeth = gr.AnnotatedImage(
                        label="Detected teeth (top-1 per FDI)",
                        height=420,
                    )
                    output_teeth_json = gr.JSON(
                        label="Teeth detections (sorted by FDI)"
                    )

                with gr.Tab("Conditions"):
                    output_cond = gr.AnnotatedImage(
                        label="Detected conditions",
                        height=420,
                    )
                    output_cond_json = gr.JSON(
                        label="Condition detections (sorted by score)"
                    )

                with gr.Tab("Per-tooth findings"):
                    gr.Markdown(
                        f"Stage-2 condition detections matched onto stage-1 tooth "
                        f"boxes by **IoMin overlap** (intersection / area-of-smaller-box "
                        f"≥ {ASSIGNMENT_OVERLAP_THRESHOLD:.0%}; best-fit wins). IoMin is "
                        f"symmetric in box-size asymmetry, so it works whether the "
                        f"condition bbox is smaller (e.g. caries) or larger "
                        f"(e.g. RC-treated) than the tooth box. Only teeth **with "
                        f"findings** are highlighted in the image (the Teeth (FDI) tab "
                        f"shows all 32 detections separately). On adult OPGs "
                        f"(≥{ADULT_OPG_FDI_THRESHOLD} detected FDIs) `tooth-bud` "
                        f"detections are suppressed — stage-2 misclassifies "
                        f"crowns/fillings as tooth-buds at scores comparable to "
                        f"legitimate findings. Top {UNASSIGNED_MAX_DISPLAY} unassigned "
                        f"detections still shown for transparency."
                    )
                    output_pertooth = gr.AnnotatedImage(
                        label="Teeth with assembled finding labels",
                        height=420,
                    )
                    output_pertooth_md = gr.Markdown(label="Per-tooth findings")
                    output_pertooth_json = gr.JSON(
                        label="Per-tooth findings (structured)"
                    )

                with gr.Tab("AI Radiology Report"):
                    gr.Markdown(
                        "**Generates a structured AI radiology report** by "
                        "feeding the per-tooth structured findings to a hosted LLM "
                        "(Ollama Cloud). Run **Analyze** first, then pick a model "
                        "and click **Generate report**. Switching the model alone "
                        "does *not* trigger a call — each click costs Ollama Cloud "
                        "credits.\n\n"
                        "_Score-rule guardrails_: confidence < 0.80 → **suspicious "
                        "for…**; confidence ≥ 0.80 → **sign of…**. Numerical scores "
                        "are not surfaced to the patient. Refusals on cost / "
                        "insurance / treatment-alternative questions are enforced "
                        "in the system prompt."
                    )
                    report_model = gr.Dropdown(
                        choices=REPORT_MODEL_OPTIONS,
                        value=DEFAULT_REPORT_MODEL,
                        label="Report model (Ollama Cloud)",
                        info="Larger models are slower but produce more "
                        "coherent clinical phrasing.",
                    )
                    report_button = gr.Button(
                        "Generate report", variant="primary"
                    )
                    output_report_md = gr.Markdown(
                        value="_Run **Analyze** first, then click **Generate report**._",
                        label="AI Radiology Report",
                    )

    # Analyze: one call fills all three detection tabs. The order of
    # `outputs` must match the order of values returned by `_analyze`.
    run_button.click(
        fn=_analyze,
        inputs=[input_image, thr_stage1, thr_stage2],
        outputs=[
            output_teeth,
            output_teeth_json,
            output_cond,
            output_cond_json,
            output_pertooth,
            output_pertooth_md,
            output_pertooth_json,
        ],
    )

    # _generate_report is a generator: it yields a placeholder first (so the
    # user sees instant feedback that the request was received), then
    # throttled partial views of the report while tokens stream in, and
    # finally the cleaned report markdown (or an error / empty-response
    # message). Gradio renders each yield as a Markdown update +
    # automatically shows a busy state on the trigger button.
    # Only the button click triggers a call; changing the dropdown does not.
    report_button.click(
        fn=_generate_report,
        inputs=[output_pertooth_json, report_model],
        outputs=output_report_md,
    )

    gr.Markdown(
        """
### Reading FDI numbers

Two-digit code: first digit = quadrant (1=upper-right, 2=upper-left,
3=lower-left, 4=lower-right), second digit = position from midline
(1=central incisor … 8=third molar / wisdom). E.g. `26` = upper-left
first molar; `48` = lower-right wisdom tooth.

### Stage-1 known weak FDI classes (v3 eval)

- `48` (lower-right wisdom): mAP ≈ 0.41
- `31`, `32`, `41` (lower central / lateral incisors): mAP ≈ 0.55-0.58
- `17` (upper-right 2nd molar): mAP ≈ 0.59 (regressed vs v2)

Detections in those classes deserve a second look.

### Stage-2 known weak condition classes (v2 eval — failing per-class floors)

- `caries` mAP 0.04, floor 0.30 — advisory only, expect FPs
- `calculus` mAP 0.01, floor 0.10 — advisory only
- `periapical-radiolucency` mAP 0.07, floor 0.25 — advisory only
- `other-finding` mAP 0.10, floor 0.20 — advisory only

Strong condition classes (mAP > 0.5): `crown`, `RC-treated`, `bridge`,
`implant`, `impacted`, `tooth-bud`, `root-stump`.
"""
    )
# Script entry point: start the Gradio server only when this file is executed
# directly, not when the module is imported.
if __name__ == "__main__":
    demo.launch()