"""Compatibility adapter for the web app inference API. This module bridges the Flask app's expected interface to the improved inference utilities in download_imp/run_inference.py. """ from __future__ import annotations import os from pathlib import Path from typing import Any import cv2 import numpy as np import torch try: from groq import Groq except ImportError: Groq = None try: import cloudinary import cloudinary.uploader import cloudinary.api except ImportError: cloudinary = None from download_imp import run_inference as core ARCH = core.BACKBONE IMG_SIZE = core.IMG_SIZE SUBTYPES = core.SUBTYPES def _parse_fold_selection(value: str | None) -> str | int: """Parse fold selection from env-style values. Accepted values: "ensemble", "best", or an integer fold id. """ raw = (value or "ensemble").strip().lower() if raw in ("", "ensemble", "all"): return "ensemble" if raw == "best": # From B4 performance report per-fold any-AUC table. return 4 if raw.isdigit(): return int(raw) return "ensemble" class _Compose: def __init__(self, transforms: list[Any]): self.transforms = transforms def __call__(self, x: np.ndarray) -> torch.Tensor: out = x for t in self.transforms: out = t(out) return out class _ToPILImage: def __call__(self, x: np.ndarray) -> np.ndarray: # The web app pipeline does not require PIL specifically. return x class _ToTensor: def __call__(self, x: np.ndarray) -> torch.Tensor: arr = np.asarray(x, dtype=np.float32) if arr.ndim != 3: raise ValueError("Expected HWC image array") # Convert HWC -> CHW return torch.from_numpy(np.transpose(arr, (2, 0, 1))) class _Normalize: def __init__(self, mean: list[float], std: list[float]): self.mean = torch.tensor(mean, dtype=torch.float32).view(-1, 1, 1) self.std = torch.tensor(std, dtype=torch.float32).view(-1, 1, 1) def __call__(self, x: torch.Tensor) -> torch.Tensor: return (x - self.mean) / (self.std + 1e-7) class T: Compose = _Compose ToPILImage = _ToPILImage ToTensor = _ToTensor Normalize = _Normalize def build_model(_arch: str | None = None): return core.build_model() def load_runtime_models(device: str, fold_selection: str | None = None): """Load one or many fold models for web inference.""" parsed = _parse_fold_selection(fold_selection) models, loaded_folds = core.load_models(device, fold_selection=parsed) grad_cams = [GradCAM(m) for m in models] return models, grad_cams, loaded_folds class GradCAM(core.GradCAM): def __init__(self, model, _arch: str | None = None): super().__init__(model) def dicom_to_rgb(dcm_path: str, size: int = IMG_SIZE) -> np.ndarray: return core.load_single_dicom_3ch(Path(dcm_path), size=size) def infer_single( img_rgb: np.ndarray, model, grad_cam: GradCAM, transform, device: str, temperature: float, ) -> dict[str, Any]: return infer_batch([img_rgb], model, grad_cam, transform, device, temperature)[0] def infer_batch( images_rgb: list[np.ndarray], model, grad_cam: GradCAM, transform, device: str, temperature: float, ) -> list[dict[str, Any]]: # Build 3ch tensor from the app's transform pipeline, then tile to 9ch # because the trained model expects 2.5D channels. 
def infer_batch(
    images_rgb: list[np.ndarray],
    model,
    grad_cam: GradCAM,
    transform,
    device: str,
    temperature: float,
) -> list[dict[str, Any]]:
    # Build a 3-channel tensor from the app's transform pipeline, then tile it
    # to 9 channels because the trained model expects 2.5D input. Do not wrap
    # this in torch.inference_mode(): Grad-CAM needs a backward pass, and
    # inference tensors cannot participate in autograd.
    t3 = torch.stack([transform(img) for img in images_rgb], dim=0).to(device)
    t9 = torch.cat([t3, t3, t3], dim=1)

    if isinstance(model, list) and isinstance(grad_cam, list):
        # Ensemble path: average logits and CAMs across the loaded folds.
        fold_logits = []
        fold_cams = []
        for _m, cam_obj in zip(model, grad_cam):
            logits_i, cam_i = cam_obj.generate(t9, class_idx=0)
            fold_logits.append(logits_i)
            fold_cams.append(cam_i)
        logits = np.mean(np.stack(fold_logits, axis=0), axis=0)
        cam = np.mean(np.stack(fold_cams, axis=0), axis=0)
    else:
        logits, cam = grad_cam.generate(t9, class_idx=0)

    # Single-image calls may come back without a batch axis; restore it.
    if len(images_rgb) == 1:
        logits = np.atleast_2d(logits)
        if cam.ndim == 2:
            cam = np.expand_dims(cam, axis=0)

    raw_probs = core.sigmoid_np(logits)
    cal_probs = core.sigmoid_np(logits / max(float(temperature), 1e-6))

    results = []
    for idx in range(len(images_rgb)):
        results.append({
            "raw_logits": logits[idx],
            "raw_probs": raw_probs[idx],
            "cal_probs": cal_probs[idx],
            "raw_prob_any": float(np.atleast_1d(raw_probs[idx])[0]),
            "cal_prob_any": float(np.atleast_1d(cal_probs[idx])[0]),
            "cam": cam[idx],
        })
    return results


def generate_medical_summary(
    inference: dict[str, Any],
    calib_cfg: dict[str, Any],
    report: dict[str, Any],
) -> str:
    """Generate a medical summary using the Groq LLM API."""
    if not Groq:
        return "LLM integration not available (groq package not installed)."

    groq_api_key = os.environ.get("GROQ_API_KEY")
    if not groq_api_key:
        return "LLM integration not configured (missing GROQ_API_KEY)."

    try:
        client = Groq(api_key=groq_api_key)

        prob = float(inference.get("cal_prob_any", 0.0))
        threshold = float(calib_cfg.get("threshold_at_spec90", 0.5))
        is_positive = prob >= threshold
        triage = report.get("triage", {})
        action = triage.get("action", "Unknown")
        urgency = triage.get("urgency", "Unknown")

        prompt = f"""
You are an expert AI medical assistant analyzing a CT scan for intracranial hemorrhage.

Scan Results:
- Probability of Hemorrhage: {prob:.2%}
- Decision Threshold: {threshold:.2%}
- AI Assessment: {"Positive for Hemorrhage" if is_positive else "Negative for Hemorrhage"}
- Urgency: {urgency}
- Recommended Action: {action}

Based on this data, write a concise, professional 3-sentence medical triage summary.
Focus strictly on the AI's findings. Do not hallucinate patient data.
"""

        model_name = os.environ.get("LLM_MODEL", "llama-3.1-8b-instant")
        response = client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=model_name,
            temperature=0.2,
            max_tokens=150,
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        return f"Failed to generate LLM summary: {e}"
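
# For reference, the calibration applied in infer_batch is plain temperature
# scaling: p_cal = sigmoid(z / T) for raw logit z. T > 1 softens overconfident
# outputs toward 0.5; T = 1 leaves probabilities unchanged. Worked example
# (T = 1.7 is an illustrative value, not a repo constant):
#   z = 2.0
#   sigmoid(2.0)       ~= 0.881   # raw probability
#   sigmoid(2.0 / 1.7) ~= 0.764   # calibrated, pulled toward 0.5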
""" model_name = os.environ.get("LLM_MODEL", "llama-3.1-8b-instant") response = client.chat.completions.create( messages=[{"role": "user", "content": prompt}], model=model_name, temperature=0.2, max_tokens=150, ) return response.choices[0].message.content.strip() except Exception as e: return f"Failed to generate LLM summary: {str(e)}" def build_report( image_id: str, inference: dict[str, Any], calib_cfg: dict[str, Any], reports_dir: Path, img_rgb: np.ndarray, true_label: int | None = None, ) -> dict[str, Any]: reports_dir.mkdir(parents=True, exist_ok=True) preview_path = reports_dir / f"{image_id}_preview.png" heatmap_path = reports_dir / f"{image_id}_gradcam.png" rgb_u8 = (np.clip(img_rgb, 0.0, 1.0) * 255.0).astype(np.uint8) cv2.imwrite(str(preview_path), cv2.cvtColor(rgb_u8, cv2.COLOR_RGB2BGR)) overlay_rgb = core.make_overlay(rgb_u8, inference["cam"], alpha=0.45) cv2.imwrite(str(heatmap_path), cv2.cvtColor(overlay_rgb, cv2.COLOR_RGB2BGR)) probs_dict = { name: float(inference["cal_probs"][idx]) for idx, name in enumerate(SUBTYPES) } threshold = float(calib_cfg.get("threshold_at_spec90", 0.5)) report = core.build_slice_report( image_id=image_id, patient_id="UNKNOWN", probs=probs_dict, calib_cfg=calib_cfg, threshold=threshold, loaded_folds=[0], report_image_path=str(preview_path), heatmap_path=str(heatmap_path), true_label=true_label, ) report.setdefault("prediction", {}) report["prediction"]["decision_threshold"] = report["prediction"].get("decision_threshold_any", threshold) report["prediction"]["raw_probability"] = round(float(inference["raw_prob_any"]), 6) report["prediction"]["calibrated_probability"] = round(float(inference["cal_prob_any"]), 6) report["llm_summary"] = generate_medical_summary(inference, calib_cfg, report) groq_api_key = os.environ.get("GROQ_API_KEY") if Groq and groq_api_key: report["llm_provider"] = "groq" report["llm_model"] = os.environ.get("LLM_MODEL", "llama-3.1-8b-instant") # Cloudinary Integration cloud_name = os.environ.get("CLOUDINARY_CLOUD_NAME") api_key = os.environ.get("CLOUDINARY_API_KEY") api_secret = os.environ.get("CLOUDINARY_API_SECRET") if cloudinary and cloud_name and api_key and api_secret: try: cloudinary.config( cloud_name=cloud_name, api_key=api_key, api_secret=api_secret, secure=True ) # Upload preview preview_res = cloudinary.uploader.upload(str(preview_path), folder="ich_previews") report["cloudinary_preview_url"] = preview_res.get("secure_url") # Upload heatmap heatmap_res = cloudinary.uploader.upload(str(heatmap_path), folder="ich_heatmaps") report["cloudinary_heatmap_url"] = heatmap_res.get("secure_url") # Delete local copies to save space since we have them in the cloud preview_path.unlink(missing_ok=True) heatmap_path.unlink(missing_ok=True) except Exception as e: print(f"Cloudinary upload failed: {e}") return report