# ICH-Detection-Pipeline / run_interface.py
# Author: Harshit Ghosh
# Last commit: f772e3c — "fix timezone and model scaler variable"
"""Compatibility adapter for the web app inference API.
This module bridges the Flask app's expected interface to the improved
inference utilities in download_imp/run_inference.py.
"""
from __future__ import annotations
import os
from pathlib import Path
from typing import Any
import cv2
import numpy as np
import torch
try:
from groq import Groq
except ImportError:
Groq = None
try:
import cloudinary
import cloudinary.uploader
import cloudinary.api
except ImportError:
cloudinary = None
from download_imp import run_inference as core
# Re-exported core configuration so the web app can import them from here.
ARCH = core.BACKBONE  # backbone architecture identifier
IMG_SIZE = core.IMG_SIZE  # model input resolution (square)
SUBTYPES = core.SUBTYPES  # ordered label names used to key per-class probabilities
def _parse_fold_selection(value: str | None) -> str | int:
"""Parse fold selection from env-style values.
Accepted values: "ensemble", "best", or an integer fold id.
"""
raw = (value or "ensemble").strip().lower()
if raw in ("", "ensemble", "all"):
return "ensemble"
if raw == "best":
# From B4 performance report per-fold any-AUC table.
return 4
if raw.isdigit():
return int(raw)
return "ensemble"
class _Compose:
def __init__(self, transforms: list[Any]):
self.transforms = transforms
def __call__(self, x: np.ndarray) -> torch.Tensor:
out = x
for t in self.transforms:
out = t(out)
return out
class _ToPILImage:
def __call__(self, x: np.ndarray) -> np.ndarray:
# The web app pipeline does not require PIL specifically.
return x
class _ToTensor:
def __call__(self, x: np.ndarray) -> torch.Tensor:
arr = np.asarray(x, dtype=np.float32)
if arr.ndim != 3:
raise ValueError("Expected HWC image array")
# Convert HWC -> CHW
return torch.from_numpy(np.transpose(arr, (2, 0, 1)))
class _Normalize:
def __init__(self, mean: list[float], std: list[float]):
self.mean = torch.tensor(mean, dtype=torch.float32).view(-1, 1, 1)
self.std = torch.tensor(std, dtype=torch.float32).view(-1, 1, 1)
def __call__(self, x: torch.Tensor) -> torch.Tensor:
return (x - self.mean) / (self.std + 1e-7)
class T:
    """Alias namespace mimicking ``torchvision.transforms`` (``T.Compose`` etc.)."""
    Compose = _Compose
    ToPILImage = _ToPILImage
    ToTensor = _ToTensor
    Normalize = _Normalize
def build_model(_arch: str | None = None):
    """Construct the core model; ``_arch`` is accepted for API compatibility but ignored."""
    model = core.build_model()
    return model
def load_runtime_models(device: str, fold_selection: str | None = None):
    """Load one or many fold models for web inference.

    Returns a tuple of (models, grad_cams, loaded_folds) where grad_cams is
    one GradCAM wrapper per loaded model.
    """
    selection = _parse_fold_selection(fold_selection)
    models, loaded_folds = core.load_models(device, fold_selection=selection)
    cams = [GradCAM(loaded) for loaded in models]
    return models, cams, loaded_folds
class GradCAM(core.GradCAM):
    """Thin wrapper over ``core.GradCAM``.

    ``_arch`` is accepted for backward compatibility with the web app's call
    signature but is ignored.
    """
    def __init__(self, model, _arch: str | None = None):
        super().__init__(model)
def dicom_to_rgb(dcm_path: str, size: int = IMG_SIZE) -> np.ndarray:
    """Load a DICOM file as a 3-channel image resized to ``size``."""
    path = Path(dcm_path)
    return core.load_single_dicom_3ch(path, size=size)
def infer_single(
    img_rgb: np.ndarray,
    model,
    grad_cam: GradCAM,
    transform,
    device: str,
    temperature: float,
) -> dict[str, Any]:
    """Run inference on a single image by delegating to the batch path."""
    batch_results = infer_batch(
        [img_rgb], model, grad_cam, transform, device, temperature
    )
    return batch_results[0]
def infer_batch(
    images_rgb: list[np.ndarray],
    model,
    grad_cam: GradCAM,
    transform,
    device: str,
    temperature: float,
) -> list[dict[str, Any]]:
    """Run inference + Grad-CAM over a batch of RGB images.

    Parameters
    ----------
    images_rgb: HWC image arrays as produced by ``dicom_to_rgb``.
    model / grad_cam: a single model and GradCAM wrapper, or two parallel
        lists of them for fold ensembling (logits and CAMs are averaged).
    transform: the preprocessing pipeline mapping an HWC array to a CHW tensor.
    device: torch device string ("cpu" or "cuda").
    temperature: temperature-scaling divisor for calibrated probabilities.

    Returns
    -------
    One dict per input image with raw/calibrated logits, probabilities,
    the "any" probability (index 0) and the Grad-CAM map.
    """
    # Build a 3ch batch tensor from the app's transform pipeline, then tile
    # to 9ch because the trained model expects 2.5D channels.
    # BUGFIX: the old code built the CUDA batch under torch.inference_mode().
    # Tensors created in inference mode cannot participate in autograd, and
    # Grad-CAM's generate() relies on a backward pass — so the CUDA branch
    # could fail. A single, mode-free path also removes the duplicated
    # cuda/cpu branches that were otherwise identical.
    t3 = torch.stack([transform(img) for img in images_rgb], dim=0).to(device)
    t9 = torch.cat([t3, t3, t3], dim=1)

    if isinstance(model, list) and isinstance(grad_cam, list):
        # Ensemble: average logits and CAMs across all loaded folds.
        fold_logits = []
        fold_cams = []
        for _m, cam_obj in zip(model, grad_cam):
            logits_i, cam_i = cam_obj.generate(t9, class_idx=0)
            fold_logits.append(logits_i)
            fold_cams.append(cam_i)
        logits = np.mean(np.stack(fold_logits, axis=0), axis=0)
        cam = np.mean(np.stack(fold_cams, axis=0), axis=0)
    else:
        logits, cam = grad_cam.generate(t9, class_idx=0)

    # Normalize shapes so single-image calls still index like a batch.
    if len(images_rgb) == 1:
        logits = np.atleast_2d(logits)
    if cam.ndim == 2:
        cam = np.expand_dims(cam, axis=0)

    raw_probs = core.sigmoid_np(logits)
    # Temperature scaling; the max() guards against non-positive temperature.
    cal_probs = core.sigmoid_np(logits / max(float(temperature), 1e-6))

    results = []
    for idx in range(len(images_rgb)):
        results.append({
            "raw_logits": logits[idx],
            "raw_probs": raw_probs[idx],
            "cal_probs": cal_probs[idx],
            # Index 0 is presumably the "any hemorrhage" head (matches the
            # class_idx=0 used for the CAM) — confirm against SUBTYPES order.
            "raw_prob_any": float(np.atleast_1d(raw_probs[idx])[0]),
            "cal_prob_any": float(np.atleast_1d(cal_probs[idx])[0]),
            "cam": cam[idx],
        })
    return results
def generate_medical_summary(inference: dict[str, Any], calib_cfg: dict[str, Any], report: dict[str, Any]) -> str:
    """Generate a medical summary using Groq LLM API.

    Degrades gracefully: returns a plain explanatory string when the groq
    package is missing, the API key is not configured, or the API call fails.
    """
    if not Groq:
        return "LLM integration not available (groq package not installed)."
    groq_api_key = os.environ.get("GROQ_API_KEY")
    if not groq_api_key:
        return "LLM integration not configured (Missing GROQ_API_KEY)."
    try:
        client = Groq(api_key=groq_api_key)
        prob = float(inference.get("cal_prob_any", 0.0))
        # NOTE(review): falls back to a 0.5 threshold when the calibration
        # config lacks "threshold_at_spec90" — confirm that default is intended.
        threshold = float(calib_cfg.get("threshold_at_spec90", 0.5))
        is_positive = prob >= threshold
        triage = report.get("triage", {})
        action = triage.get("action", "Unknown")
        urgency = triage.get("urgency", "Unknown")
        prompt = f"""
You are an expert AI medical assistant analyzing a CT scan for Intracranial Hemorrhage.
Scan Results:
- Probability of Hemorrhage: {prob:.2%}
- Decision Threshold: {threshold:.2%}
- AI Assessment: {"Positive for Hemorrhage" if is_positive else "Negative for Hemorrhage"}
- Urgency: {urgency}
- Recommended Action: {action}
Based on this data, write a concise, professional 3-sentence medical triage summary.
Focus strictly on the AI's findings. Do not hallucinate patient data.
"""
        model_name = os.environ.get("LLM_MODEL", "llama-3.1-8b-instant")
        # Low temperature and a small token budget keep the summary brief
        # and mostly deterministic.
        response = client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=model_name,
            temperature=0.2,
            max_tokens=150,
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        # Best-effort: surface the failure in the report instead of raising.
        return f"Failed to generate LLM summary: {str(e)}"
def build_report(
    image_id: str,
    inference: dict[str, Any],
    calib_cfg: dict[str, Any],
    reports_dir: Path,
    img_rgb: np.ndarray,
    true_label: int | None = None,
) -> dict[str, Any]:
    """Persist preview/heatmap images and assemble the per-slice report dict.

    Side effects: writes two PNGs under *reports_dir*, generates an LLM
    summary, and — when Cloudinary credentials are configured — uploads the
    images and deletes the local copies.
    """
    reports_dir.mkdir(parents=True, exist_ok=True)
    preview_path = reports_dir / f"{image_id}_preview.png"
    heatmap_path = reports_dir / f"{image_id}_gradcam.png"
    # assumes img_rgb is float RGB in [0, 1] (as produced upstream) — TODO confirm;
    # converted to uint8 and BGR for OpenCV's imwrite.
    rgb_u8 = (np.clip(img_rgb, 0.0, 1.0) * 255.0).astype(np.uint8)
    cv2.imwrite(str(preview_path), cv2.cvtColor(rgb_u8, cv2.COLOR_RGB2BGR))
    overlay_rgb = core.make_overlay(rgb_u8, inference["cam"], alpha=0.45)
    cv2.imwrite(str(heatmap_path), cv2.cvtColor(overlay_rgb, cv2.COLOR_RGB2BGR))
    # Map calibrated per-class probabilities onto subtype names by position.
    probs_dict = {
        name: float(inference["cal_probs"][idx])
        for idx, name in enumerate(SUBTYPES)
    }
    threshold = float(calib_cfg.get("threshold_at_spec90", 0.5))
    report = core.build_slice_report(
        image_id=image_id,
        patient_id="UNKNOWN",
        probs=probs_dict,
        calib_cfg=calib_cfg,
        threshold=threshold,
        # NOTE(review): fold ids are hard-coded to [0] even when an ensemble
        # was used — confirm whether the real loaded_folds should flow in here.
        loaded_folds=[0],
        report_image_path=str(preview_path),
        heatmap_path=str(heatmap_path),
        true_label=true_label,
    )
    # Enrich the prediction section with raw/calibrated "any" scores.
    report.setdefault("prediction", {})
    report["prediction"]["decision_threshold"] = report["prediction"].get("decision_threshold_any", threshold)
    report["prediction"]["raw_probability"] = round(float(inference["raw_prob_any"]), 6)
    report["prediction"]["calibrated_probability"] = round(float(inference["cal_prob_any"]), 6)
    report["llm_summary"] = generate_medical_summary(inference, calib_cfg, report)
    groq_api_key = os.environ.get("GROQ_API_KEY")
    if Groq and groq_api_key:
        report["llm_provider"] = "groq"
        report["llm_model"] = os.environ.get("LLM_MODEL", "llama-3.1-8b-instant")
    # Cloudinary Integration (only when the package and all credentials exist)
    cloud_name = os.environ.get("CLOUDINARY_CLOUD_NAME")
    api_key = os.environ.get("CLOUDINARY_API_KEY")
    api_secret = os.environ.get("CLOUDINARY_API_SECRET")
    if cloudinary and cloud_name and api_key and api_secret:
        try:
            cloudinary.config(
                cloud_name=cloud_name,
                api_key=api_key,
                api_secret=api_secret,
                secure=True
            )
            # Upload preview
            preview_res = cloudinary.uploader.upload(str(preview_path), folder="ich_previews")
            report["cloudinary_preview_url"] = preview_res.get("secure_url")
            # Upload heatmap
            heatmap_res = cloudinary.uploader.upload(str(heatmap_path), folder="ich_heatmaps")
            report["cloudinary_heatmap_url"] = heatmap_res.get("secure_url")
            # Delete local copies to save space since we have them in the cloud
            preview_path.unlink(missing_ok=True)
            heatmap_path.unlink(missing_ok=True)
        except Exception as e:
            # Best-effort: the local report remains valid if the upload fails.
            print(f"Cloudinary upload failed: {e}")
    return report