| """Compatibility adapter for the web app inference API. |
| |
| This module bridges the Flask app's expected interface to the improved |
| inference utilities in download_imp/run_inference.py. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import os |
| from pathlib import Path |
| from typing import Any |
|
|
| import cv2 |
| import numpy as np |
| import torch |
|
|
| try: |
| from groq import Groq |
| except ImportError: |
| Groq = None |
|
|
| try: |
| import cloudinary |
| import cloudinary.uploader |
| import cloudinary.api |
| except ImportError: |
| cloudinary = None |
|
|
| from download_imp import run_inference as core |
|
|
# Aliases for settings defined in download_imp.run_inference (imported as `core`),
# exposed here under the names this adapter's callers use.
ARCH, IMG_SIZE, SUBTYPES = core.BACKBONE, core.IMG_SIZE, core.SUBTYPES
|
|
|
|
| def _parse_fold_selection(value: str | None) -> str | int: |
| """Parse fold selection from env-style values. |
| |
| Accepted values: "ensemble", "best", or an integer fold id. |
| """ |
| raw = (value or "ensemble").strip().lower() |
| if raw in ("", "ensemble", "all"): |
| return "ensemble" |
| if raw == "best": |
| |
| return 4 |
| if raw.isdigit(): |
| return int(raw) |
| return "ensemble" |
|
|
|
|
| class _Compose: |
| def __init__(self, transforms: list[Any]): |
| self.transforms = transforms |
|
|
| def __call__(self, x: np.ndarray) -> torch.Tensor: |
| out = x |
| for t in self.transforms: |
| out = t(out) |
| return out |
|
|
|
|
| class _ToPILImage: |
| def __call__(self, x: np.ndarray) -> np.ndarray: |
| |
| return x |
|
|
|
|
| class _ToTensor: |
| def __call__(self, x: np.ndarray) -> torch.Tensor: |
| arr = np.asarray(x, dtype=np.float32) |
| if arr.ndim != 3: |
| raise ValueError("Expected HWC image array") |
| |
| return torch.from_numpy(np.transpose(arr, (2, 0, 1))) |
|
|
|
|
| class _Normalize: |
| def __init__(self, mean: list[float], std: list[float]): |
| self.mean = torch.tensor(mean, dtype=torch.float32).view(-1, 1, 1) |
| self.std = torch.tensor(std, dtype=torch.float32).view(-1, 1, 1) |
|
|
| def __call__(self, x: torch.Tensor) -> torch.Tensor: |
| return (x - self.mean) / (self.std + 1e-7) |
|
|
|
|
class T:
    """Namespace exposing the local transform shims under torchvision-style names.

    Lets callers write ``T.Compose([...])``, ``T.ToTensor()``,
    ``T.Normalize(mean, std)`` and ``T.ToPILImage()``, backed by the shim
    classes defined in this module.
    """

    Compose = _Compose
    ToPILImage = _ToPILImage
    ToTensor = _ToTensor
    Normalize = _Normalize
|
|
|
|
def build_model(_arch: str | None = None):
    """Build a fresh model via ``core.build_model()``.

    ``_arch`` is accepted only for signature compatibility and is ignored; the
    architecture is decided inside ``core.build_model`` (presumably
    ``core.BACKBONE`` — confirm in download_imp/run_inference.py).
    """
    del _arch  # intentionally unused
    return core.build_model()
|
|
|
|
def load_runtime_models(device: str, fold_selection: str | None = None):
    """Load one or many fold models for web inference and wrap each in GradCAM.

    Args:
        device: Device string forwarded to ``core.load_models``.
        fold_selection: Env-style selection parsed by ``_parse_fold_selection``
            ("ensemble"/"all", "best", or a fold id); ``None`` means ensemble.

    Returns:
        ``(models, grad_cams, loaded_folds)`` where ``grad_cams[i]`` wraps
        ``models[i]`` and ``loaded_folds`` is what ``core.load_models`` reports.
    """
    selection = _parse_fold_selection(fold_selection)
    models, loaded_folds = core.load_models(device, fold_selection=selection)
    grad_cams = list(map(GradCAM, models))
    return models, grad_cams, loaded_folds
|
|
|
|
class GradCAM(core.GradCAM):
    """``core.GradCAM`` with a constructor that also accepts an ``_arch`` argument.

    ``_arch`` is accepted for call-site compatibility and ignored; only
    ``model`` is forwarded to the base class.
    """

    def __init__(self, model, _arch: str | None = None):
        del _arch  # kept in the signature only for compatibility
        super().__init__(model)
|
|
|
|
def dicom_to_rgb(dcm_path: str, size: int = IMG_SIZE) -> np.ndarray:
    """Load one DICOM file as a 3-channel image via ``core.load_single_dicom_3ch``.

    Args:
        dcm_path: Path to the DICOM file; converted to ``pathlib.Path``.
        size: Target size forwarded to the loader (defaults to ``IMG_SIZE``).

    Returns:
        Whatever ``core.load_single_dicom_3ch`` returns for that file.
    """
    path = Path(dcm_path)
    return core.load_single_dicom_3ch(path, size=size)
|
|
|
|
def infer_single(
    img_rgb: np.ndarray,
    model,
    grad_cam: GradCAM,
    transform,
    device: str,
    temperature: float,
) -> dict[str, Any]:
    """Score one image by running ``infer_batch`` on a batch of size one.

    All arguments are forwarded unchanged; returns the single result dict.
    """
    single_image_batch = [img_rgb]
    batch_results = infer_batch(
        single_image_batch, model, grad_cam, transform, device, temperature
    )
    return batch_results[0]
|
|
|
|
def infer_batch(
    images_rgb: list[np.ndarray],
    model,
    grad_cam: GradCAM | list[GradCAM],
    transform,
    device: str,
    temperature: float,
) -> list[dict[str, Any]]:
    """Run Grad-CAM inference on a batch of images.

    Each image goes through ``transform``. The results are stacked into a
    batch, moved to ``device``, and tiled three times along the channel axis
    to form the model input. When ``grad_cam`` is a list/tuple, the logits and
    CAMs of every entry are averaged (fold ensemble). Otherwise the single
    ``grad_cam`` is used.

    Args:
        images_rgb: Images to score; an empty list yields ``[]``.
        model: Model or list of fold models. Not used for computation (the
            GradCAM objects already wrap the models); kept for interface
            compatibility.
        grad_cam: A GradCAM, or a list/tuple of them for an ensemble.
        transform: Callable mapping one image to a tensor (e.g. ``T.Compose``).
        device: Device the stacked input batch is moved to.
        temperature: Temperature for the calibrated probabilities; clamped to
            at least 1e-6 to avoid division by zero.

    Returns:
        One dict per input image with keys ``raw_logits``, ``raw_probs``,
        ``cal_probs``, ``raw_prob_any``, ``cal_prob_any`` (the first output,
        presumably the "any" class) and ``cam``.

    Raises:
        ValueError: if ``grad_cam`` is an empty list/tuple.
    """
    if not images_rgb:
        # torch.stack() rejects an empty list; an empty batch has no results.
        return []

    # Build the batch the same way on every device. The former CUDA-only
    # branch wrapped this identical code in torch.inference_mode(), which
    # gained nothing here and produced inference tensors (which autograd
    # refuses to save for backward) right before the Grad-CAM pass.
    t3 = torch.stack([transform(img) for img in images_rgb], dim=0).to(device)
    # Tile the (presumably 3-channel) slice three times along dim 1 to build
    # the 9-channel input (t9) the model consumes.
    t9 = torch.cat([t3, t3, t3], dim=1)

    if isinstance(grad_cam, (list, tuple)):
        if not grad_cam:
            raise ValueError("grad_cam ensemble must contain at least one GradCAM")
        # Fold ensemble: average logits and CAMs across every GradCAM given.
        fold_logits, fold_cams = [], []
        for cam_obj in grad_cam:
            logits_i, cam_i = cam_obj.generate(t9, class_idx=0)
            fold_logits.append(logits_i)
            fold_cams.append(cam_i)
        logits = np.mean(np.stack(fold_logits, axis=0), axis=0)
        cam = np.mean(np.stack(fold_cams, axis=0), axis=0)
    else:
        logits, cam = grad_cam.generate(t9, class_idx=0)

    batch_size = len(images_rgb)
    if batch_size == 1:
        # generate() may drop the batch axis for a single image; restore it so
        # per-image indexing below works uniformly.
        logits = np.atleast_2d(logits)
        if cam.ndim == 2:
            cam = np.expand_dims(cam, axis=0)

    raw_probs = core.sigmoid_np(logits)
    # Temperature scaling; the clamp guards against zero/negative temperatures.
    cal_probs = core.sigmoid_np(logits / max(float(temperature), 1e-6))

    return [
        {
            "raw_logits": logits[i],
            "raw_probs": raw_probs[i],
            "cal_probs": cal_probs[i],
            "raw_prob_any": float(np.atleast_1d(raw_probs[i])[0]),
            "cal_prob_any": float(np.atleast_1d(cal_probs[i])[0]),
            "cam": cam[i],
        }
        for i in range(batch_size)
    ]
|
|
|
|
def generate_medical_summary(inference: dict[str, Any], calib_cfg: dict[str, Any], report: dict[str, Any]) -> str:
    """Generate a short triage summary for one scan using the Groq chat API.

    Never raises. A missing ``groq`` package, a missing ``GROQ_API_KEY``, an
    empty model response, or any API error is reported as a human-readable
    string instead, so report generation can continue without the LLM.

    Args:
        inference: Result dict from ``infer_batch``; only ``cal_prob_any`` is read.
        calib_cfg: Calibration config; ``threshold_at_spec90`` (default 0.5) is
            the fallback decision threshold.
        report: Slice report. ``report["prediction"]["decision_threshold"]`` is
            preferred as the threshold when present, so the summary's
            positive/negative call agrees with the report. ``report["triage"]``
            supplies ``urgency`` and ``action``.

    Returns:
        The model's summary text, or an explanatory message when no summary
        could be produced. The model name comes from ``LLM_MODEL``, defaulting
        to ``llama-3.1-8b-instant``.
    """
    if not Groq:
        return "LLM integration not available (groq package not installed)."

    groq_api_key = os.environ.get("GROQ_API_KEY")
    if not groq_api_key:
        return "LLM integration not configured (Missing GROQ_API_KEY)."

    try:
        client = Groq(api_key=groq_api_key)

        prob = float(inference.get("cal_prob_any", 0.0))
        # Prefer the threshold recorded on the report (build_report stores it
        # under prediction.decision_threshold) so the summary cannot contradict
        # the report; fall back to the calibration config otherwise.
        prediction = report.get("prediction") or {}
        threshold = float(
            prediction.get("decision_threshold", calib_cfg.get("threshold_at_spec90", 0.5))
        )
        assessment = "Positive for Hemorrhage" if prob >= threshold else "Negative for Hemorrhage"

        triage = report.get("triage") or {}
        action = triage.get("action", "Unknown")
        urgency = triage.get("urgency", "Unknown")

        prompt = f"""
You are an expert AI medical assistant analyzing a CT scan for Intracranial Hemorrhage.

Scan Results:
- Probability of Hemorrhage: {prob:.2%}
- Decision Threshold: {threshold:.2%}
- AI Assessment: {assessment}
- Urgency: {urgency}
- Recommended Action: {action}

Based on this data, write a concise, professional 3-sentence medical triage summary.
Focus strictly on the AI's findings. Do not hallucinate patient data.
"""

        model_name = os.environ.get("LLM_MODEL", "llama-3.1-8b-instant")
        response = client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=model_name,
            temperature=0.2,
            max_tokens=150,
        )

        # The message content can be None/empty; report that explicitly
        # instead of surfacing an opaque AttributeError from .strip().
        content = response.choices[0].message.content
        if not content:
            return "Failed to generate LLM summary: model returned an empty response."
        return content.strip()
    except Exception as e:  # best-effort: the LLM must never break report generation
        return f"Failed to generate LLM summary: {e}"
|
|
|
|
def _imwrite_rgb(path: Path, rgb: np.ndarray) -> None:
    """Write an RGB uint8 image to *path* with OpenCV (which expects BGR).

    Raises:
        OSError: if ``cv2.imwrite`` reports failure (it returns False rather
            than raising).
    """
    if not cv2.imwrite(str(path), cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)):
        raise OSError(f"cv2.imwrite failed to write {path}")


def _upload_report_images(report: dict[str, Any], preview_path: Path, heatmap_path: Path) -> None:
    """Best-effort upload of the preview and heatmap PNGs to Cloudinary.

    Runs only when the cloudinary package imported and CLOUDINARY_CLOUD_NAME,
    CLOUDINARY_API_KEY and CLOUDINARY_API_SECRET are all set. Each secure URL
    is stored in ``report`` as its upload succeeds. The local files are
    deleted only after both uploads succeed. Any exception is printed, not
    raised.
    """
    cloud_name = os.environ.get("CLOUDINARY_CLOUD_NAME")
    api_key = os.environ.get("CLOUDINARY_API_KEY")
    api_secret = os.environ.get("CLOUDINARY_API_SECRET")
    if not (cloudinary and cloud_name and api_key and api_secret):
        return

    try:
        cloudinary.config(
            cloud_name=cloud_name,
            api_key=api_key,
            api_secret=api_secret,
            secure=True,
        )

        preview_res = cloudinary.uploader.upload(str(preview_path), folder="ich_previews")
        report["cloudinary_preview_url"] = preview_res.get("secure_url")

        heatmap_res = cloudinary.uploader.upload(str(heatmap_path), folder="ich_heatmaps")
        report["cloudinary_heatmap_url"] = heatmap_res.get("secure_url")

        preview_path.unlink(missing_ok=True)
        heatmap_path.unlink(missing_ok=True)
    except Exception as e:  # deliberate best-effort: uploads must not fail the report
        print(f"Cloudinary upload failed: {e}")


def build_report(
    image_id: str,
    inference: dict[str, Any],
    calib_cfg: dict[str, Any],
    reports_dir: Path,
    img_rgb: np.ndarray,
    true_label: int | None = None,
    loaded_folds: list[int] | None = None,
) -> dict[str, Any]:
    """Write report images, build the slice report and enrich it.

    Steps:
    1. Save ``<image_id>_preview.png`` and a Grad-CAM overlay
       ``<image_id>_gradcam.png`` into *reports_dir*.
    2. Build the report via ``core.build_slice_report``.
    3. Record the decision threshold and the raw/calibrated probabilities.
    4. Attach an LLM summary.
    5. Upload the images to Cloudinary if it is configured.

    Args:
        image_id: Identifier used in the report and in output file names.
            NOTE(review): it is interpolated into paths unchecked; if it can
            come from user input, reject path separators / ``..`` upstream.
        inference: One result dict from ``infer_batch`` (reads ``cam``,
            ``cal_probs``, ``raw_prob_any``, ``cal_prob_any``).
        calib_cfg: Calibration config; ``threshold_at_spec90`` (default 0.5).
        reports_dir: Output directory; created if missing.
        img_rgb: Image, presumably float in [0, 1]; it is clipped to that
            range before conversion to 8-bit.
        true_label: Optional ground-truth label forwarded to the core report.
        loaded_folds: Fold ids recorded in the report. Defaults to ``[0]``
            (the value previously hard-coded); pass the folds returned by
            ``load_runtime_models`` when ensembling.

    Returns:
        The report dict.

    Raises:
        OSError: if either report image cannot be written.
    """
    reports_dir.mkdir(parents=True, exist_ok=True)
    preview_path = reports_dir / f"{image_id}_preview.png"
    heatmap_path = reports_dir / f"{image_id}_gradcam.png"

    # Clip before scaling so out-of-range values saturate instead of wrapping
    # when cast to uint8.
    rgb_u8 = (np.clip(img_rgb, 0.0, 1.0) * 255.0).astype(np.uint8)
    _imwrite_rgb(preview_path, rgb_u8)
    overlay_rgb = core.make_overlay(rgb_u8, inference["cam"], alpha=0.45)
    _imwrite_rgb(heatmap_path, overlay_rgb)

    probs_dict = {
        name: float(inference["cal_probs"][idx])
        for idx, name in enumerate(SUBTYPES)
    }
    threshold = float(calib_cfg.get("threshold_at_spec90", 0.5))

    report = core.build_slice_report(
        image_id=image_id,
        patient_id="UNKNOWN",
        probs=probs_dict,
        calib_cfg=calib_cfg,
        threshold=threshold,
        loaded_folds=[0] if loaded_folds is None else list(loaded_folds),
        report_image_path=str(preview_path),
        heatmap_path=str(heatmap_path),
        true_label=true_label,
    )

    prediction = report.setdefault("prediction", {})
    prediction["decision_threshold"] = prediction.get("decision_threshold_any", threshold)
    prediction["raw_probability"] = round(float(inference["raw_prob_any"]), 6)
    prediction["calibrated_probability"] = round(float(inference["cal_prob_any"]), 6)

    report["llm_summary"] = generate_medical_summary(inference, calib_cfg, report)
    if Groq and os.environ.get("GROQ_API_KEY"):
        report["llm_provider"] = "groq"
        report["llm_model"] = os.environ.get("LLM_MODEL", "llama-3.1-8b-instant")

    _upload_report_images(report, preview_path, heatmap_path)
    return report
|
|