Spaces:
Running on Zero
Running on Zero
| #!/usr/bin/env python3 | |
| # /// script | |
| # requires-python = ">=3.11,<3.12" | |
| # dependencies = [ | |
| # "gradio>=5.34.0,<6.0", | |
| # "transformers>=4.45.0,<5.0", | |
| # "torch>=2.5.0,<3.0", | |
| # "pillow>=10.0.0,<12.0", | |
| # "huggingface-hub>=0.25.0,<1.0", | |
| # "httpx>=0.27.0,<1.0", | |
| # ] | |
| # /// | |
| """Gradio demo: upload an OPG, see argos-dentsight stage-1 + stage-2 detections. | |
| Local run: | |
| uv run demo/app.py | |
| HF Spaces deploy: see demo/README.md (rename to root + push). | |
| Two private model repos are loaded: | |
| - `Mobe1/argos-dentsight-stage1-fdi-v3` — 32 FDI tooth localizer | |
| - `Mobe1/argos-dentsight-stage2-conditions-v4` — 13 dental-condition detector | |
| Both repos are private, so the demo needs HF_TOKEN with read access. Locally | |
| that means a `.env` at the repo root containing HF_TOKEN=...; on HF Spaces it | |
| means setting HF_TOKEN as a Space secret. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import re | |
| import time | |
| from typing import Any | |
| import gradio as gr | |
| import httpx | |
| import torch | |
| from PIL import Image | |
| from transformers import AutoImageProcessor, AutoModelForObjectDetection | |
try:
    import spaces  # type: ignore[import-not-found]
except ImportError:
    # The `spaces` package is not importable (e.g. local dev without HF Spaces
    # ZeroGPU), so `_gpu` becomes a decorator that changes nothing.
    def _gpu(fn=None, **_kwargs):
        """No-op stand-in for `spaces.GPU`.

        Supports both decorator spellings: bare (`@_gpu`) returns the
        function itself; parameterised (`@_gpu(duration=...)`) returns an
        identity decorator. Keyword arguments are accepted and ignored.
        """

        def _passthrough(f):
            return f

        return fn if callable(fn) else _passthrough
else:
    _gpu = spaces.GPU
# Hugging Face model repos for the two detection stages.
REPO_ID_STAGE1 = "Mobe1/argos-dentsight-stage1-fdi-v3"
REPO_ID_STAGE2 = "Mobe1/argos-dentsight-stage2-conditions-v4"
DEFAULT_THR_STAGE1 = 0.02  # v3 score range ~[0.02, 0.08]
DEFAULT_THR_STAGE2 = 0.005  # v3 cold-start widened the band slightly but still ~[0.005, 0.022]
STAGE2_MAX_DETECTIONS = 30  # cap displayed boxes — model emits 300 raw, most are noise
UNASSIGNED_MAX_DISPLAY = 5  # cap on top-N unassigned conditions surfaced to the user
| # --------------------------------------------------------------------------- | |
| # Vendored from src/argos_dentsight/postproc/{constants.py,geometry.py,assignment.py} | |
| # Demo runs as a self-contained UV-style file; we copy the small subset we need | |
| # rather than packaging src/ into the Space build. | |
| # --------------------------------------------------------------------------- | |
# The 32 permanent-tooth FDI codes (quadrants 1-4, positions 1-8).
FDI_CLASSES: frozenset[str] = frozenset(
    {
        "11", "12", "13", "14", "15", "16", "17", "18",
        "21", "22", "23", "24", "25", "26", "27", "28",
        "31", "32", "33", "34", "35", "36", "37", "38",
        "41", "42", "43", "44", "45", "46", "47", "48",
    }
)
# Third molars, one per quadrant.
WISDOM_TEETH: frozenset[str] = frozenset({"18", "28", "38", "48"})
# Stage-2 classes reported as pathological findings.
PATHOLOGICAL_CLASSES: frozenset[str] = frozenset(
    {
        "caries",
        "calculus",
        "impacted",
        "periapical-radiolucency",
        "root-stump",
        "missing",
        "other-finding",
    }
)
# Stage-2 classes reported as treatments / historical interventions.
TREATMENT_CLASSES: frozenset[str] = frozenset(
    {"RC-treated", "crown", "restoration", "bridge", "implant", "tooth-bud"}
)
ASSIGNMENT_OVERLAP_THRESHOLD: float = 0.5  # IoMin ≥ 0.5 (smaller box ≥50% inside larger)
# Paediatric-class suppression: see filtering.suppress_paediatric_classes_on_adult_opgs.
# Stage-2 misclassifies bright high-density material (crowns, fillings) as
# `tooth-bud` at scores up to 0.028 — comparable to legitimate detections.
# `tooth-bud` is paediatric anatomy; on an adult OPG (≥20 detected FDIs)
# we drop them. Below the threshold (mixed dentition) we leave them alone.
ADULT_OPG_FDI_THRESHOLD: int = 20
PAEDIATRIC_CLASSES: frozenset[str] = frozenset({"tooth-bud"})
def _suppress_paediatric_on_adult(
    teeth_preds: list[dict[str, Any]],
    cond_preds: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Drop paediatric-only condition classes when the OPG looks adult.

    An image counts as adult when at least `ADULT_OPG_FDI_THRESHOLD` teeth
    were detected. Below that count the conditions are returned untouched.
    A new list is returned in both cases; `cond_preds` is never mutated.
    """
    looks_adult = len(teeth_preds) >= ADULT_OPG_FDI_THRESHOLD
    if looks_adult:
        return [
            cond
            for cond in cond_preds
            if cond["class_name"] not in PAEDIATRIC_CLASSES
        ]
    return list(cond_preds)
def _intersection_area(box1: list[float], box2: list[float]) -> float:
    """Return the overlap area of two `[x1, y1, x2, y2]` boxes (0 if disjoint)."""
    a_left, a_top, a_right, a_bottom = box1
    b_left, b_top, b_right, b_bottom = box2
    # Overlap extent along each axis; negative means the boxes do not meet.
    overlap_w = min(a_right, b_right) - max(a_left, b_left)
    overlap_h = min(a_bottom, b_bottom) - max(a_top, b_top)
    return max(0.0, overlap_w) * max(0.0, overlap_h)
def _box_area(box: list[float]) -> float:
    """Return the area of an `[x1, y1, x2, y2]` box; inverted boxes yield 0."""
    left, top, right, bottom = box
    width = max(0.0, right - left)
    height = max(0.0, bottom - top)
    return width * height
def _compute_iomin(box1: list[float], box2: list[float]) -> float:
    """Intersection over the area of the smaller box.

    Vendored from postproc.geometry.compute_iomin. The metric is symmetric
    in its arguments and normalises by whichever box is smaller, so a
    condition box slightly larger than its tooth (common for RC-treated
    detections) still scores high. A one-directional containment ratio
    (intersection / condition area) would fall below 0.5 in that case.

    Returns 0.0 when either box has no positive area.
    """
    smaller_area = min(_box_area(box1), _box_area(box2))
    return (
        0.0
        if smaller_area <= 0.0
        else _intersection_area(box1, box2) / smaller_area
    )
def _assign_diseases_to_teeth(
    teeth_preds: list[dict[str, Any]],
    cond_preds: list[dict[str, Any]],
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Attach each condition to the tooth it overlaps best.

    Vendored from postproc.assignment.assign_diseases_to_teeth.

    A condition goes to the tooth with the highest IoMin overlap, as long as
    that overlap is at least `ASSIGNMENT_OVERLAP_THRESHOLD`; on ties the
    earlier tooth in `teeth_preds` wins. `implant` conditions are never
    assigned (source rule) and therefore always end up unassigned.

    Returns:
        structured_teeth: one `{fdi, score, bbox, conditions}` dict per tooth,
            in the order of `teeth_preds`.
        unassigned: the condition preds that matched no tooth, in input order.
    """
    structured_teeth: list[dict[str, Any]] = []
    for tooth in teeth_preds:
        structured_teeth.append(
            {
                "fdi": tooth["fdi"],
                "score": tooth["score"],
                "bbox": tooth["bbox"],
                "conditions": [],
            }
        )

    unassigned: list[dict[str, Any]] = []
    for cond in cond_preds:
        winner: int | None = None
        if cond["class_name"] != "implant":
            top_overlap = 0.0
            for idx, tooth in enumerate(teeth_preds):
                overlap = _compute_iomin(cond["bbox"], tooth["bbox"])
                if overlap > top_overlap:
                    top_overlap, winner = overlap, idx
            # A best match that is still too weak counts as no match at all.
            if winner is not None and top_overlap < ASSIGNMENT_OVERLAP_THRESHOLD:
                winner = None
        if winner is None:
            unassigned.append(cond)
        else:
            structured_teeth[winner]["conditions"].append(cond)
    return structured_teeth, unassigned
| # --------------------------------------------------------------------------- | |
| # Phase 8 — LLM medical-report generation (vendored from src/argos_dentsight/llm/) | |
| # | |
| # Same vendoring strategy as the postproc block: copy the small subset the | |
| # Space needs rather than packaging src/ into the build. If you change either | |
| # the prompt or the grounding-caption format here, mirror the change in | |
| # src/argos_dentsight/llm/{prompts.py,grounding_caption.py} or the offline | |
| # batch outputs will diverge. | |
| # --------------------------------------------------------------------------- | |
# Ollama Cloud chat endpoint and the per-request timeout applied to it.
OLLAMA_ENDPOINT: str = "https://ollama.com/api/chat"
OLLAMA_TIMEOUT_S: float = 180.0  # gemma4:31b can take 30-90s on long captions
# Models offered in the report tab; anything else is rejected before calling out.
REPORT_MODEL_OPTIONS: list[str] = [
    # Each entry empirically verified against https://ollama.com/api/chat with
    # a 2-token probe. Tags listed on Ollama's model pages but not served on
    # the cloud chat endpoint (e.g. gemma4:26b, gemma4:e4b) 404 here even
    # though they exist for local `ollama run`.
    "gemma4:31b",  # ~7s, balanced default
    "deepseek-v4-flash",  # fastest reasoning option (~1s warm)
    "deepseek-v4-pro",  # most thorough; verbose (~28s end-to-end)
    "kimi-k2.6",  # multimodal-capable, ~2s warm
    "gemini-3-flash-preview",  # Google flash variant, fast
]
DEFAULT_REPORT_MODEL: str = "gemma4:31b"
# Generator-throttle interval. The model emits ~22ms/token (gemma) or ~15ms/
# token (deepseek). Yielding to Gradio per-chunk = 50-70 UI updates/sec,
# which the browser's Markdown re-layout cannot keep up with as the
# accumulated text grows — perceived as a "slowdown midway". Coalescing
# into ~14 yields/sec keeps the stream visually smooth without changing
# total time-to-completion.
PROGRESS_YIELD_INTERVAL_S: float = 0.07
# Vendored verbatim from src/argos_dentsight/llm/prompts.py::REPORT_SYSTEM_PROMPT
# NOTE: the confidence rule below must read "≥ 0.80"; a mis-encoded copy of
# that character would be sent to the LLM as part of the system prompt.
REPORT_SYSTEM_PROMPT: str = """
You are a professional oral radiologist assistant tasked with generating precise and clinically accurate oral panoramic X-ray examination reports based on structured localization data.
The structured data contains all detected teeth and dental conditions. Each condition is associated with a specific tooth number.
If a finding is not directly on a tooth, it will have 'tooth_id': 'unknown' and a 'near_tooth': '[tooth_id]' field, which you should report as "near tooth #[tooth_id]".
Generate a formal and comprehensive oral examination report **ONLY** containing two mandatory sections:
1. **Teeth-Specific Observations**
2. **Clinical Summary & Recommendations**
The **Teeth-Specific Observations** section must comprise three subsections:
1. **General Condition**: Outlines overall dental status, including the count of visualized teeth and wisdom teeth status (e.g., presence or impaction).
2. **Pathological Findings**: Documents dental diseases such as caries, impacted teeth, calculus, or periapical radiolucency.
3. **Historical Interventions**: Details prior treatments like fillings (restorations), crowns, root canal treatments, or implants.
Each finding in the structured data has a confidence score. You must apply the following processing rules **ONLY** for the **Pathological Findings** subsection:
* For confidence scores **< 0.80**: Use terms like "suspicious for...", "suggests...", or "areas of concern noted for..." in the description.
* For confidence scores **≥ 0.80**: Use definitive descriptors such as "sign of...", "shows evidence of...", or "clear indication of...".
The **Historical Interventions** subsection should always use definitive language (e.g., "presence of a crown," "rc-treated tooth noted"), as these are observed facts.
Please strictly follow the following requirements:
* **Adherence to FDI numbering system** (e.g., "#11", "#26").
* **Use professional medical terminology** while maintaining clarity.
* **DO NOT** include or reference the confidence scores in any form in the final report. Their *only* use is to determine the certainty language ("suspicious" vs. "sign of").
* **DO NOT** generate any administrative content like 'Patient Name', 'Date', etc.
* **Generate a new Clinical Summary & Recommendations** section. This section is critical and must be created from the findings. It must include:
1. **Priority Concerns**: The most urgent issues found (e.g., "Deep caries on #28", "Impacted wisdom tooth #18 requiring evaluation").
2. **Preventive Measures**: Recommendations for prevention (e.g., "Monitor areas of suspected calculus", "Reinforce oral hygiene").
3. **Follow-up Protocol**: Specific recall or follow-up actions (e.g., "6-month recall for monitoring", "Referral to endodontist for #26").
Now, generate a new report for the following input:
"""
def _clean_llm_output(raw_text: str) -> str:
    """Return the model output with any reasoning preamble removed.

    Vendored from src/argos_dentsight/llm/output_cleaner.py::clean_llm_output.

    Everything up to and including the first `</think>` tag (matched
    case-insensitively) is discarded and the remainder is stripped of
    surrounding whitespace. Text without a closing tag is only stripped.
    """
    closing_tag = re.search(r"</think>", raw_text, re.IGNORECASE)
    remainder = raw_text if closing_tag is None else raw_text[closing_tag.end():]
    return remainder.strip()
def _format_section(title: str, items: list) -> str:
    """Render one caption section: a titled, bracketed list of items.

    The title line carries the item count. Each item is JSON-encoded on its
    own line with double quotes swapped for single quotes, and lines are
    comma-separated. An empty section renders as a bare pair of brackets.
    """
    if not items:
        return f"{title} (total: 0):\n[\n]"
    rendered: list[str] = []
    for entry in items:
        as_json = json.dumps(entry, ensure_ascii=False)
        rendered.append(" " + as_json.replace('"', "'"))
    body = ",\n".join(rendered)
    return f"{title} (total: {len(items)}):\n[\n{body}\n]"
def _build_grounding_caption(per_tooth_json: list[dict[str, Any]]) -> str:
    """Render the LLM-ready structured caption from the demo's per-tooth JSON.

    `per_tooth_json` is the output of `_per_tooth_outputs`: one dict per
    detected tooth (recognised by its `fdi` key) plus a dict carrying
    `unassigned_top` / `unassigned_total`. This function adapts that shape
    into the multi-section caption consumed alongside `REPORT_SYSTEM_PROMPT`.

    Sections emitted (in order):
        - Teeth visibility with center points
        - Wisdom teeth detection
        - Teeth Not Visualized in This Image (FDIs with no detection)
        - Dental Pathological Findings
        - Historical Treatments

    Conditions whose class is in neither `PATHOLOGICAL_CLASSES` nor
    `TREATMENT_CLASSES` are left out of the caption.

    Returns:
        The caption: an intro sentence followed by the five sections,
        separated by blank lines.
    """
    teeth = [row for row in per_tooth_json if "fdi" in row]
    trailing = next(
        (row for row in per_tooth_json if "unassigned_top" in row),
        {"unassigned_top": [], "unassigned_total": 0},
    )
    unassigned = trailing.get("unassigned_top", [])

    # One pass over the teeth builds the visibility list, the wisdom-tooth
    # list, and the centre lookup used for `near_tooth` derivation.
    visibility: list[dict[str, Any]] = []
    wisdom: list[dict[str, Any]] = []
    tooth_centers: list[tuple[str, float, float]] = []
    for t in teeth:
        x1, y1, x2, y2 = t["bbox_xyxy"]
        cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
        tooth_score = round(t.get("tooth_score", 0), 2)
        tooth_centers.append((t["fdi"], cx, cy))
        visibility.append(
            {
                "point_2d": [round(cx), round(cy)],
                "tooth_id": t["fdi"],
                "score": tooth_score,
            }
        )
        if t["fdi"] in WISDOM_TEETH:
            wisdom.append(
                {
                    "box_2d": [round(x1), round(y1), round(x2), round(y2)],
                    "tooth_id": t["fdi"],
                    "is_impacted": any(
                        c["class_name"] == "impacted"
                        for c in t.get("conditions", [])
                    ),
                    "score": tooth_score,
                }
            )

    # Teeth Not Visualized: FDIs not present in the detector's output.
    # Renamed from the spec's "Missing Teeth" because empirically the LLM
    # treats that label as a clinical absence-finding. With stage-1's known
    # weak classes (e.g. FDI 48 mAP=0.41) and bias-collapse-compressed scores,
    # an FDI absence in the list is "model didn't surface it", NOT "patient is
    # missing it" — a distinction worth keeping rigid for the LLM.
    detected_fdis = {t["fdi"] for t in teeth}
    missing_items = [
        {"tooth_id": fdi} for fdi in sorted(FDI_CLASSES - detected_fdis)
    ]

    def _closest_fdi(box: list[float]) -> str | None:
        """FDI of the tooth whose centre is nearest to the centre of `box`."""
        if not tooth_centers:
            return None
        bx, by = (box[0] + box[2]) / 2, (box[1] + box[3]) / 2
        best_fdi: str | None = None
        best_d = float("inf")
        for fdi, tx, ty in tooth_centers:
            d = (bx - tx) ** 2 + (by - ty) ** 2  # squared distance suffices
            if d < best_d:
                best_d, best_fdi = d, fdi
        return best_fdi

    pathological: list[dict[str, Any]] = []
    historical: list[dict[str, Any]] = []

    def _file_finding(
        cond: dict[str, Any], tooth_id: str, near_tooth: str | None = None
    ) -> None:
        """Build one finding entry and route it to the matching section."""
        entry: dict[str, Any] = {
            "box_2d": [round(b) for b in cond["bbox_xyxy"]],
            "tooth_id": tooth_id,
            "label": cond["class_name"],
            "score": round(cond.get("score", 0), 2),
        }
        if near_tooth is not None:
            entry["near_tooth"] = near_tooth
        if cond["class_name"] in PATHOLOGICAL_CLASSES:
            pathological.append(entry)
        elif cond["class_name"] in TREATMENT_CLASSES:
            historical.append(entry)

    # Assigned conditions inherit their tooth's FDI; unassigned ones get
    # tooth_id='unknown' plus a near_tooth pointer (the prompt tells the LLM
    # to phrase those as "near tooth #X").
    for t in teeth:
        for c in t.get("conditions", []):
            _file_finding(c, t["fdi"])
    for c in unassigned:
        _file_finding(c, "unknown", _closest_fdi(c["bbox_xyxy"]))

    parts: list[str] = [
        "This localization caption provides multi-dimensional spatial analysis "
        "of the patient's panoramic dental radiograph.",
        _format_section("Teeth visibility with center points", visibility),
        _format_section("Wisdom teeth detection", wisdom),
        _format_section(
            "Teeth Not Visualized in This Image (model did not surface a "
            "detection — could be genuinely absent or below detection threshold)",
            missing_items,
        ),
        _format_section("Dental Pathological Findings", pathological),
        _format_section("Historical Treatments", historical),
    ]
    return "\n\n".join(parts)
def _stream_ollama_chat(
    *, model: str, system: str, user: str, api_key: str
):
    """Generator: yield assistant content chunks from Ollama Cloud's streaming
    chat endpoint.

    Each yielded chunk is the *incremental* text since the last chunk (not
    cumulative); callers accumulate. Streaming lets users see the report
    appear progressively instead of waiting for the whole response.

    Args:
        model: Ollama model tag to query.
        system: System-prompt text.
        user: User-message text.
        api_key: Bearer token for the endpoint; never included in errors.

    Yields:
        Non-empty content fragments, in arrival order, until a frame with a
        truthy `done` flag or the end of the response.

    Raises:
        RuntimeError: on a non-200 response, an error frame inside the
            stream, a timeout, or any other `httpx.HTTPError`.
    """
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
        "stream": True,
    }
    try:
        # Context managers close the response and the client even when the
        # consumer abandons the generator early.
        with httpx.Client(timeout=OLLAMA_TIMEOUT_S) as client:
            with client.stream(
                "POST", OLLAMA_ENDPOINT, headers=headers, json=payload
            ) as r:
                if r.status_code != 200:
                    body_bytes = b"".join(r.iter_bytes())
                    try:
                        err = json.loads(body_bytes).get("error", "")
                    except (ValueError, AttributeError):
                        # Not JSON, or JSON that is not an object: fall back
                        # to a short excerpt of the raw body.
                        err = body_bytes.decode("utf-8", "replace")[:200]
                    raise RuntimeError(
                        f"Ollama Cloud HTTP {r.status_code} on `{model}`: {err}"
                    )
                for line in r.iter_lines():
                    if not line:
                        continue
                    try:
                        chunk = json.loads(line)
                    except json.JSONDecodeError:
                        continue  # skip keep-alives / malformed frames
                    if not isinstance(chunk, dict):
                        continue
                    if chunk.get("error"):
                        # A failure after the 200 status line arrives as an
                        # error frame; surface it instead of ending silently.
                        raise RuntimeError(
                            f"Ollama Cloud stream error on `{model}`: "
                            f"{str(chunk['error'])[:200]}"
                        )
                    # `message` may be absent or null on bookkeeping frames.
                    content = (chunk.get("message") or {}).get("content", "")
                    if content:
                        yield content
                    if chunk.get("done"):
                        break
    except httpx.TimeoutException as e:
        # Suggest a tag that is actually offered by this app; gemma4:26b is
        # not served on the cloud chat endpoint.
        raise RuntimeError(
            f"Ollama Cloud timed out after {OLLAMA_TIMEOUT_S:.0f}s on `{model}`. "
            "Try a faster model (e.g. deepseek-v4-flash) or rerun."
        ) from e
    except httpx.HTTPError as e:
        raise RuntimeError(f"Network error calling Ollama Cloud: {type(e).__name__}") from e
def _clean_streaming_view(accumulated: str) -> str | None:
    """Return the text to display for a partially received stream.

    Returns:
        `None` while a `<think>` block has been opened but not yet closed,
        so the caller can keep its placeholder on screen. Otherwise the
        cleaned text: everything after `</think>`, or the stripped
        accumulated text when no think block is present.
    """
    folded = accumulated.lower()
    think_opened = "<think>" in folded
    think_closed = "</think>" in folded
    if think_opened and not think_closed:
        return None
    return _clean_llm_output(accumulated)
| def _load_token() -> str | None: | |
| """Read HF_TOKEN from env. Falls back to a .env file at the repo root.""" | |
| token = os.environ.get("HF_TOKEN") | |
| if token: | |
| return token | |
| env_path = os.path.join(os.path.dirname(__file__), "..", ".env") | |
| if os.path.exists(env_path): | |
| with open(env_path) as f: | |
| for line in f: | |
| line = line.strip() | |
| if line.startswith("HF_TOKEN="): | |
| return line.split("=", 1)[1].strip().strip("\"'") | |
| return None | |
# Import-time side effect: both detectors are loaded once here, at module
# level, and the resulting objects are shared by every later call.
print(f"Loading stage-1: {REPO_ID_STAGE1}...")
TOKEN = _load_token()  # None when no HF_TOKEN is configured
PROC_S1 = AutoImageProcessor.from_pretrained(REPO_ID_STAGE1, token=TOKEN)
MODEL_S1 = AutoModelForObjectDetection.from_pretrained(REPO_ID_STAGE1, token=TOKEN)
MODEL_S1.eval()  # inference only
ID2LABEL_S1 = MODEL_S1.config.id2label  # class id -> label (FDI code) lookup
print(f"Loading stage-2: {REPO_ID_STAGE2}...")
PROC_S2 = AutoImageProcessor.from_pretrained(REPO_ID_STAGE2, token=TOKEN)
MODEL_S2 = AutoModelForObjectDetection.from_pretrained(REPO_ID_STAGE2, token=TOKEN)
MODEL_S2.eval()  # inference only
ID2LABEL_S2 = MODEL_S2.config.id2label  # class id -> condition-name lookup
print(
    f"Loaded. Stage-1: {len(ID2LABEL_S1)} FDIs. "
    f"Stage-2: {len(ID2LABEL_S2)} conditions."
)
def _stage1_outputs(
    res: dict[str, Any], pil: Image.Image
) -> tuple[Any, list[dict[str, Any]], list[dict[str, Any]]]:
    """Reduce stage-1 detections to one box per FDI class (highest score wins).

    Returns (annotated_image_tuple, json_for_ui, raw_preds_for_assignment):
        - `(pil, annotations)`, where each annotation is an integer box plus
          an "FDI (score)" label.
        - UI JSON rows `{fdi, score, bbox_xyxy}`, sorted by FDI.
        - Internal predictions `{fdi, score, bbox: [x1, y1, x2, y2]}` at full
          precision, used when matching conditions to teeth.
    """
    candidates = zip(
        res["boxes"].cpu().tolist(),
        res["scores"].cpu().tolist(),
        res["labels"].cpu().tolist(),
        strict=True,
    )
    # Keep only the top-scoring (score, box) pair seen for each class id.
    top_by_class: dict[int, tuple[float, list[float]]] = {}
    for box, score, raw_id in candidates:
        class_id = int(raw_id)
        incumbent = top_by_class.get(class_id)
        if incumbent is None or score > incumbent[0]:
            top_by_class[class_id] = (score, box)

    annotations: list[tuple[tuple[int, int, int, int], str]] = []
    detections: list[dict[str, Any]] = []
    raw_preds: list[dict[str, Any]] = []
    for class_id, (score, box) in top_by_class.items():
        fdi = ID2LABEL_S1.get(int(class_id), str(class_id))
        left, top, right, bottom = (round(v) for v in box)
        annotations.append(((left, top, right, bottom), f"{fdi} ({score:.2f})"))
        detections.append(
            {
                "fdi": fdi,
                "score": round(score, 3),
                "bbox_xyxy": [round(v, 1) for v in box],
            }
        )
        raw_preds.append({"fdi": fdi, "score": float(score), "bbox": list(box)})
    # Only the UI rows are ordered; annotations and raw preds keep first-seen order.
    detections.sort(key=lambda row: row["fdi"])
    return (pil, annotations), detections, raw_preds
def _stage2_outputs(
    res: dict[str, Any], pil: Image.Image
) -> tuple[Any, list[dict[str, Any]], list[dict[str, Any]]]:
    """Keep the highest-scoring stage-2 detections, best first.

    There is no per-class dedup: several instances of one class are valid
    (multiple crowns, multiple caries lesions). The result is capped at
    `STAGE2_MAX_DETECTIONS` so low-confidence noise does not flood the UI.

    Returns (annotated_image_tuple, json_for_ui, raw_preds_for_assignment).
    raw_preds shape: {class_name, score, bbox: [x1, y1, x2, y2]}.
    """
    ranked = list(
        zip(
            res["boxes"].cpu().tolist(),
            res["scores"].cpu().tolist(),
            res["labels"].cpu().tolist(),
            strict=True,
        )
    )
    # Stable sort on the negated score: descending, ties keep model order.
    ranked.sort(key=lambda item: -item[1])

    annotations: list[tuple[tuple[int, int, int, int], str]] = []
    detections: list[dict[str, Any]] = []
    raw_preds: list[dict[str, Any]] = []
    for box, score, label_id in ranked[:STAGE2_MAX_DETECTIONS]:
        cname = ID2LABEL_S2.get(int(label_id), str(label_id))
        left, top, right, bottom = (round(v) for v in box)
        annotations.append(((left, top, right, bottom), f"{cname} ({score:.3f})"))
        detections.append(
            {
                "condition": cname,
                "score": round(score, 4),
                "bbox_xyxy": [round(v, 1) for v in box],
            }
        )
        raw_preds.append(
            {"class_name": cname, "score": float(score), "bbox": list(box)}
        )
    return (pil, annotations), detections, raw_preds
def _category_for(class_name: str) -> str:
    """Map a stage-2 class name to its per-tooth panel category.

    Returns "pathological", "treatment", or "other" for a class that is in
    neither set. The pathological set is checked first.
    """
    lookup = (
        ("pathological", PATHOLOGICAL_CLASSES),
        ("treatment", TREATMENT_CLASSES),
    )
    for category, members in lookup:
        if class_name in members:
            return category
    return "other"
def _per_tooth_outputs(
    teeth_preds: list[dict[str, Any]],
    cond_preds: list[dict[str, Any]],
    pil: Image.Image,
) -> tuple[Any, str, list[dict[str, Any]]]:
    """Build the per-tooth findings tab outputs.

    Pipeline:
        1. Suppress paediatric classes (tooth-bud) on adult OPGs.
        2. Assign each condition to its best-overlapping tooth via
           `_assign_diseases_to_teeth`.
        3. Render annotations only for teeth WITH findings, so the image is
           not cluttered by a box for every detected tooth.
        4. Emit JSON for downstream LLM-report consumption: one entry per
           tooth (with or without findings, since the LLM needs to know what
           was visualized) plus a trailing unassigned block.

    Returns:
        `((pil, annotations), markdown, json_out)`. The last element of
        `json_out` is always the `{unassigned_top, unassigned_total}` dict.
    """
    suppressed_conds = _suppress_paediatric_on_adult(teeth_preds, cond_preds)
    suppressed_count = len(cond_preds) - len(suppressed_conds)
    structured_teeth, unassigned = _assign_diseases_to_teeth(
        teeth_preds, suppressed_conds
    )
    # Sort teeth by FDI ascending for stable display.
    structured_teeth.sort(key=lambda t: t["fdi"])

    annotations: list[tuple[tuple[int, int, int, int], str]] = []
    md_lines: list[str] = ["## Per-tooth findings", ""]
    json_out: list[dict[str, Any]] = []
    teeth_with_findings = 0
    for tooth in structured_teeth:
        fdi = tooth["fdi"]
        conditions = tooth["conditions"]
        if conditions:
            teeth_with_findings += 1
            x1, y1, x2, y2 = (round(c) for c in tooth["bbox"])
            label = f"{fdi}: {', '.join(c['class_name'] for c in conditions)}"
            md_lines.append(
                f"- **Tooth {fdi}**: "
                + ", ".join(
                    f"{c['class_name']} ({c['score']:.3f})" for c in conditions
                )
            )
            # Only teeth WITH findings are annotated; the Teeth (FDI) tab
            # already shows the full per-class argmax of stage-1.
            annotations.append(((x1, y1, x2, y2), label))
        json_out.append(
            {
                "fdi": fdi,
                "tooth_score": round(tooth["score"], 3),
                "bbox_xyxy": [round(c, 1) for c in tooth["bbox"]],
                "is_wisdom": fdi in WISDOM_TEETH,
                "conditions": [
                    {
                        "class_name": c["class_name"],
                        "score": round(c["score"], 4),
                        "category": _category_for(c["class_name"]),
                        "bbox_xyxy": [round(b, 1) for b in c["bbox"]],
                    }
                    for c in conditions
                ],
            }
        )
    if teeth_with_findings == 0:
        md_lines.append("- _No teeth had overlap-matched stage-2 findings._")
    if suppressed_count > 0:
        md_lines.append(
            f"\n_Adult-OPG heuristic active: {suppressed_count} paediatric-class "
            f"detection(s) (tooth-bud) suppressed because ≥{ADULT_OPG_FDI_THRESHOLD} "
            f"FDIs were detected._"
        )

    # Top-N unassigned conditions (most-confident orphan stage-2 detections).
    unassigned_sorted = sorted(unassigned, key=lambda c: -c["score"])[
        :UNASSIGNED_MAX_DISPLAY
    ]
    md_lines.append("")
    md_lines.append("### Unassigned conditions (likely false positives)")
    if not unassigned_sorted:
        md_lines.append("- _None._")
    else:
        # The tally covers ALL unassigned conditions, not just the top-N shown.
        tally: dict[str, int] = {}
        for c in unassigned:
            tally[c["class_name"]] = tally.get(c["class_name"], 0) + 1
        tally_str = ", ".join(
            f"{name} x {count}" for name, count in sorted(tally.items())
        )
        md_lines.append(f"_All unassigned (incl. low-rank): {tally_str}_")
        md_lines.append("")
        for c in unassigned_sorted:
            md_lines.append(
                f"- {c['class_name']} ({c['score']:.3f}) — bbox "
                f"{[round(b, 1) for b in c['bbox']]}"
            )
    json_out.append(
        {
            "unassigned_top": [
                {
                    "class_name": c["class_name"],
                    "score": round(c["score"], 4),
                    "bbox_xyxy": [round(b, 1) for b in c["bbox"]],
                }
                for c in unassigned_sorted
            ],
            "unassigned_total": len(unassigned),
        }
    )
    md = "\n".join(md_lines)
    return (pil, annotations), md, json_out
def _analyze(
    image: Image.Image | None, thr_s1: float, thr_s2: float
) -> tuple[
    Any,
    list[dict[str, Any]],
    Any,
    list[dict[str, Any]],
    Any,
    str,
    list[dict[str, Any]],
]:
    """Run both detectors on `image` and assemble every tab's output.

    Returns (teeth_ann, teeth_json, conditions_ann, conditions_json,
    per_tooth_ann, per_tooth_md, per_tooth_json). With no image, the
    annotations are `None`, the JSON lists are empty and the markdown is "".
    """
    if image is None:
        return None, [], None, [], None, "", []

    pil = image.convert("RGB")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Post-processing takes (height, width) to scale boxes to the image size.
    target_sizes = torch.tensor([[pil.height, pil.width]], device=device)

    def _detect(model: Any, processor: Any, threshold: float) -> dict[str, Any]:
        """One forward pass plus post-processing for a single stage."""
        on_device = model.to(device)
        inputs = processor(images=pil, return_tensors="pt").to(device)
        with torch.no_grad():
            raw_out = on_device(**inputs)
        return processor.post_process_object_detection(
            raw_out, threshold=threshold, target_sizes=target_sizes
        )[0]

    # Stage 1 (FDI teeth) runs to completion before stage 2 (conditions);
    # each stage does its own forward pass over the same image.
    teeth_ann, teeth_json, teeth_raw = _stage1_outputs(
        _detect(MODEL_S1, PROC_S1, thr_s1), pil
    )
    cond_ann, cond_json, cond_raw = _stage2_outputs(
        _detect(MODEL_S2, PROC_S2, thr_s2), pil
    )
    # Per-tooth assembly: match stage-2 conditions onto stage-1 teeth.
    per_tooth = _per_tooth_outputs(teeth_raw, cond_raw, pil)
    return (teeth_ann, teeth_json, cond_ann, cond_json, *per_tooth)
def _generate_report(
    per_tooth_json: list[dict[str, Any]] | None,
    model: str,
):
    """Streaming generator that orchestrates the LLM AI-radiology-report call.

    Args:
        per_tooth_json: Per-tooth JSON as produced by `_per_tooth_outputs`;
            empty or `None` means no analysis has been run yet.
        model: Ollama model tag; must be one of `REPORT_MODEL_OPTIONS`.

    Yields:
        1. A placeholder immediately (instant feedback).
        2. The report content progressively as tokens stream in. Each yield
           is the full cleaned text so far, throttled to one update per
           `PROGRESS_YIELD_INTERVAL_S`. While a `<think>` block is open the
           placeholder stays on screen.
        3. The final cleaned text once the stream ends.

    Precondition failures and `RuntimeError`s from the stream are yielded as
    Markdown messages rather than raised, so users never see a stack trace.
    """
    if not per_tooth_json:
        yield (
            "_No detection results available. Upload an OPG and click "
            "**Analyze** first, then come back to this tab._"
        )
        return
    api_key = os.environ.get("OLLAMA_API_KEY")
    if not api_key:
        yield (
            "**Configuration error:** `OLLAMA_API_KEY` is not set on this Space. "
            "An admin must add it as a Space secret before this tab can be used."
        )
        return
    if model not in REPORT_MODEL_OPTIONS:
        yield f"**Invalid model:** `{model}` is not in the allowed list."
        return

    placeholder = (
        f"⏳ **Generating AI radiology report with `{model}`…**\n\n"
        "Tokens will appear here as the model produces them. Cold Ollama "
        "Cloud models take 30-60s on the first request, then ~5-10s for "
        "subsequent ones. DeepSeek and Qwen variants are more verbose and "
        "may take 20-40s end-to-end."
    )
    yield placeholder

    caption = _build_grounding_caption(per_tooth_json)
    accumulated = ""
    last_yielded = placeholder
    last_yield_t = time.monotonic()
    try:
        for chunk in _stream_ollama_chat(
            model=model,
            system=REPORT_SYSTEM_PROMPT,
            user=caption,
            api_key=api_key,
        ):
            accumulated += chunk
            now = time.monotonic()
            if now - last_yield_t < PROGRESS_YIELD_INTERVAL_S:
                continue  # accumulate further; UI yield is throttled
            view = _clean_streaming_view(accumulated)
            if not view:
                # None: still inside a <think> block. Empty: nothing to show
                # yet. Either way the current display stays as it is.
                continue
            if view != last_yielded:
                yield view
                last_yielded = view
                last_yield_t = now
    except RuntimeError as e:
        yield f"**Report generation failed.** {e}"
        return

    # Final yield: ensure the last accumulated content (which may have
    # arrived inside the throttle window) is rendered.
    final = _clean_llm_output(accumulated)
    if not final:
        yield (
            f"**Empty response from `{model}`.** The model returned no content "
            "after stripping its reasoning block. Try a different model."
        )
        return
    if final != last_yielded:
        yield final
# Top-level UI definition: builds the Gradio Blocks app bound to `demo`.
# Layout: intro text, then a row with the inputs (left column) and the
# result tabs (right column), then the event wiring and a reference footer.
with gr.Blocks(title="Argos-DentSight: FDI + Conditions Demo") as demo:
    # NOTE(review): the "Stage-2 caveat" below quotes `mAP_50 = 0.53` and a
    # score range of `[0.005, 0.015]`, while the Stage 2 bullet quotes
    # `mAP@.5 = 0.601` and the slider text quotes `[0.005, 0.022]` — confirm
    # which figures apply to the v4 model.
    gr.Markdown(
        """
# Argos-DentSight — Two-Stage OPG Demo

Upload an OPG (orthopantomogram / panoramic dental X-ray) to run both
models on it.

- **Stage 1** — `Mobe1/argos-dentsight-stage1-fdi-v3` — 32 FDI tooth
  positions. Trained eval `mAP@.5 = 0.853`.
- **Stage 2** — `Mobe1/argos-dentsight-stage2-conditions-v4` — 13 dental
  condition labels (caries, calculus, RC-treated, impacted, restoration,
  crown, periapical-radiolucency, root-stump, bridge, implant, tooth-bud,
  missing, other-finding). Trained eval `mAP@.5 = 0.601` (warm-started
  from v3 + DENTEX-c diagnosis labels at 1664×928 input; biggest gains
  on caries 0.035 → 0.071 and impacted 0.685 → 0.729). Caries / calculus
  / periapical-radiolucency still fail per-class clinical gates by
  3–4× — they are scale-bound, not classifier-bound.

> **About the confidence scores:** both models' classification weights
> are well-trained but their *biases* are stuck near the focal-loss
> prior, so absolute sigmoid scores are compressed to roughly
> `[0.02, 0.08]`. The *ranking* is correct — the demo deliberately uses
> low default thresholds and (for stage-1) per-class argmax so you see
> the model's actual top guess. **Treat scores as a confidence
> ordering, not a probability.**

> **Stage-2 caveat:** the bias-collapse issue is more severe here than
> on stage-1 — score range is roughly `[0.005, 0.015]` with little
> between-class separation. Eval mAP_50 = 0.53 means the model gets
> ranking right *on average across the validation set*, but on any
> single OPG it may surface confident-looking detections of conditions
> that aren't there. Strong classes (crown, RC-treated, bridge, implant,
> impacted, tooth-bud, root-stump) are more reliable than weak ones
> (caries, calculus, periapical-radiolucency, other-finding all failed
> the project's per-class acceptance gates by 3-15× margins).
"""
    )

    with gr.Row():
        # Left column: image upload, per-stage score thresholds, run button.
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload OPG", height=400)
            thr_stage1 = gr.Slider(
                minimum=0.0,
                maximum=0.5,
                value=DEFAULT_THR_STAGE1,
                step=0.01,
                label="Stage-1 threshold (FDI teeth)",
                info="v3 score range ~[0.02, 0.08]; default 0.02 surfaces "
                "per-class argmax for most teeth.",
            )
            # NOTE(review): the info text says "v3" but the stage-2 repo
            # named in the intro is `...-conditions-v4` — confirm the label.
            thr_stage2 = gr.Slider(
                minimum=0.0,
                maximum=0.05,
                value=DEFAULT_THR_STAGE2,
                step=0.001,
                label="Stage-2 threshold (conditions)",
                info=f"v3 score range is compressed (~[0.005, 0.022]); default "
                f"0.005 surfaces top guesses. Top {STAGE2_MAX_DETECTIONS} shown "
                f"to avoid noise. Raise to filter; lower to see everything.",
            )
            run_button = gr.Button("Analyze", variant="primary")

        # Right column: one tab per view of the results.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.Tab("Teeth (FDI)"):
                    output_teeth = gr.AnnotatedImage(
                        label="Detected teeth (top-1 per FDI)",
                        height=420,
                    )
                    output_teeth_json = gr.JSON(label="Teeth detections (sorted by FDI)")

                with gr.Tab("Conditions"):
                    output_cond = gr.AnnotatedImage(
                        label="Detected conditions",
                        height=420,
                    )
                    output_cond_json = gr.JSON(label="Condition detections (sorted by score)")

                with gr.Tab("Per-tooth findings"):
                    gr.Markdown(
                        f"Stage-2 condition detections matched onto stage-1 tooth "
                        f"boxes by **IoMin overlap** (intersection / area-of-smaller-box "
                        f"≥ {ASSIGNMENT_OVERLAP_THRESHOLD:.0%}; best-fit wins). IoMin is "
                        f"symmetric in box-size asymmetry, so it works whether the "
                        f"condition bbox is smaller (e.g. caries) or larger "
                        f"(e.g. RC-treated) than the tooth box. Only teeth **with "
                        f"findings** are highlighted in the image (the Teeth (FDI) tab "
                        f"shows all 32 detections separately). On adult OPGs "
                        f"(≥{ADULT_OPG_FDI_THRESHOLD} detected FDIs) `tooth-bud` "
                        f"detections are suppressed — stage-2 misclassifies "
                        f"crowns/fillings as tooth-buds at scores comparable to "
                        f"legitimate findings. Top {UNASSIGNED_MAX_DISPLAY} unassigned "
                        f"detections still shown for transparency."
                    )
                    output_pertooth = gr.AnnotatedImage(
                        label="Teeth with assembled finding labels",
                        height=420,
                    )
                    output_pertooth_md = gr.Markdown(label="Per-tooth findings")
                    # Also serves as the input of the report tab's handler.
                    output_pertooth_json = gr.JSON(label="Per-tooth findings (structured)")

                with gr.Tab("AI Radiology Report"):
                    gr.Markdown(
                        "**Generates a structured AI radiology report** by "
                        "feeding the per-tooth structured findings to a hosted LLM "
                        "(Ollama Cloud). Run **Analyze** first, then pick a model "
                        "and click **Generate report**. Switching the model alone "
                        "does *not* trigger a call — each click costs Ollama Cloud "
                        "credits.\n\n"
                        "_Score-rule guardrails_: confidence < 0.80 → **suspicious "
                        "for…**; confidence ≥ 0.80 → **sign of…**. Numerical scores "
                        "are not surfaced to the patient. Refusals on cost / "
                        "insurance / treatment-alternative questions are enforced "
                        "in the system prompt."
                    )
                    report_model = gr.Dropdown(
                        choices=REPORT_MODEL_OPTIONS,
                        value=DEFAULT_REPORT_MODEL,
                        label="Report model (Ollama Cloud)",
                        info="Larger models are slower but produce more "
                        "coherent clinical phrasing.",
                    )
                    report_button = gr.Button(
                        "Generate report", variant="primary"
                    )
                    output_report_md = gr.Markdown(
                        value="_Run **Analyze** first, then click **Generate report**._",
                        label="AI Radiology Report",
                    )

    # The outputs list is positional: it follows the order of the 7-tuple
    # returned by _analyze (teeth image/JSON, conditions image/JSON,
    # per-tooth image/Markdown/JSON).
    run_button.click(
        fn=_analyze,
        inputs=[input_image, thr_stage1, thr_stage2],
        outputs=[
            output_teeth,
            output_teeth_json,
            output_cond,
            output_cond_json,
            output_pertooth,
            output_pertooth_md,
            output_pertooth_json,
        ],
    )

    # _generate_report is a generator: it yields a placeholder first (so
    # the user sees instant feedback that the request was received), then
    # progressively longer cleaned report text while tokens stream in, and
    # finally the complete report (or an error message). Gradio renders each
    # yield as a Markdown update + automatically shows a busy state on the
    # trigger button. Only the button click triggers a call; changing the
    # dropdown has no listener.
    report_button.click(
        fn=_generate_report,
        inputs=[output_pertooth_json, report_model],
        outputs=output_report_md,
    )

    # NOTE(review): the stage-2 heading below says "v2 eval" while the intro
    # describes the v4 model — confirm these per-class figures are current.
    gr.Markdown(
        """
### Reading FDI numbers

Two-digit code: first digit = quadrant (1=upper-right, 2=upper-left,
3=lower-left, 4=lower-right), second digit = position from midline
(1=central incisor … 8=third molar / wisdom). E.g. `26` = upper-left
first molar; `48` = lower-right wisdom tooth.

### Stage-1 known weak FDI classes (v3 eval)

- `48` (lower-right wisdom): mAP ≈ 0.41
- `31`, `32`, `41` (lower central / lateral incisors): mAP ≈ 0.55-0.58
- `17` (upper-right 2nd molar): mAP ≈ 0.59 (regressed vs v2)

Detections in those classes deserve a second look.

### Stage-2 known weak condition classes (v2 eval — failing per-class floors)

- `caries` mAP 0.04, floor 0.30 — advisory only, expect FPs
- `calculus` mAP 0.01, floor 0.10 — advisory only
- `periapical-radiolucency` mAP 0.07, floor 0.25 — advisory only
- `other-finding` mAP 0.10, floor 0.20 — advisory only

Strong condition classes (mAP > 0.5): `crown`, `RC-treated`, `bridge`,
`implant`, `impacted`, `tooth-bud`, `root-stump`.
"""
    )


if __name__ == "__main__":
    # Start the Gradio server only when run as a script (not on import).
    demo.launch()