hackathon / src /models /mri_dl_2d.py
mekosotto's picture
fix(mri): try weights_only=True first; fall back only for trusted module pickles
07b00eb
"""Pretrained 2D MRI Alzheimer's classifier (resnet18, 4 classes).
Decision-layer bridge for an externally-trained PyTorch checkpoint. Loads
either a state_dict (default) or a full pickled model, applies the trainer's
preprocessing (resize image_size=160, ImageNet normalisation), and emits the
same dict shape as src.models.mri_model.predict_with_proba so downstream
code paths don't care which backend produced the prediction.
"""
from __future__ import annotations
import pickle
from pathlib import Path
from typing import Any
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms
from src.core.logger import get_logger
logger = get_logger(__name__)

# Label vocabulary of the classifier head: class name -> output index.
# NOTE(review): presumably mirrors the trainer's label encoding (names are in
# alphabetical order, as torchvision's ImageFolder would assign) — confirm.
CLASS_TO_IDX: dict[str, int] = {
    "MildDemented": 0,
    "ModerateDemented": 1,
    "NonDemented": 2,
    "VeryMildDemented": 3,
}
# Inverse lookup: output index -> class name.
IDX_TO_CLASS: dict[int, str] = dict(zip(CLASS_TO_IDX.values(), CLASS_TO_IDX.keys()))

# Square side length (pixels) every input image is resized to.
DEFAULT_IMAGE_SIZE = 160

# Per-channel (R, G, B) ImageNet statistics used for input normalisation.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)

# Inference preprocessing: resize -> float tensor -> channel-wise normalise.
_TRANSFORM = transforms.Compose(
    [
        transforms.Resize((DEFAULT_IMAGE_SIZE,) * 2),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
def _build_resnet18_4class() -> nn.Module:
    """Return an untrained ResNet-18 whose final layer emits one logit per class.

    No pretrained weights are downloaded (``weights=None``); the caller is
    expected to populate the parameters via ``load_state_dict``.
    """
    n_classes = len(CLASS_TO_IDX)
    net = models.resnet18(weights=None)
    # Swap the stock classifier head for one sized to our label vocabulary,
    # keeping the backbone's feature width.
    net.fc = nn.Linear(net.fc.in_features, n_classes)
    return net
def _clean_state_dict(obj: dict[str, Any]) -> dict[str, Any]:
    """Extract the weight mapping from a checkpoint dict and strip DataParallel prefixes."""
    # Training checkpoints often nest the weights alongside epoch/optimizer
    # state. A bare ResNet state_dict never has these keys, so plain
    # state_dicts pass through unchanged.
    for wrapper_key in ("state_dict", "model_state_dict", "model"):
        inner = obj.get(wrapper_key)
        if isinstance(inner, dict):
            obj = inner
            break
    # nn.DataParallel / DDP save parameters as "module.<name>".
    return {k.removeprefix("module."): v for k, v in obj.items()}


def load(path: Path, *, allow_pickle_fallback: bool = True) -> nn.Module:
    """Load the 2D MRI checkpoint at *path* and return the model in eval mode.

    Supported artifact layouts:

    * a bare ``state_dict`` (preferred), optionally with ``module.`` key
      prefixes from DataParallel;
    * a checkpoint dict wrapping the weights under ``"state_dict"``,
      ``"model_state_dict"`` or ``"model"``;
    * a full pickled ``nn.Module``.

    The file is first read with ``weights_only=True``, which refuses arbitrary
    pickle opcodes. If that read raises, the loader retries with
    ``weights_only=False``. The retry EXECUTES pickle code. It fires for any
    weights-only failure, not only for genuine module pickles, so by itself
    the first attempt gives no protection against a malicious file. Pass
    ``allow_pickle_fallback=False`` when *path* is not fully trusted.

    Args:
        path: Checkpoint location (``str`` or ``Path``).
        allow_pickle_fallback: If False, a weights-only failure is re-raised
            instead of retrying with full unpickling. Defaults to True.

    Returns:
        The loaded model with ``eval()`` applied.

    Raises:
        FileNotFoundError: *path* does not exist.
        pickle.UnpicklingError, RuntimeError: the weights-only load failed
            and the fallback is disabled (or the retry also failed).
        ValueError: the loaded object is neither a dict nor an ``nn.Module``.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"MRI 2D checkpoint not found: {path}")

    try:
        obj = torch.load(str(path), map_location="cpu", weights_only=True)
    except (pickle.UnpicklingError, RuntimeError) as e:
        if not allow_pickle_fallback:
            raise
        # SECURITY: weights_only=False runs arbitrary code embedded in the
        # pickle. Only reach this line for artifacts you trust.
        logger.warning(
            "MRI 2D checkpoint at %s is not a state_dict (weights_only=True failed: %s); "
            "falling back to weights_only=False — only safe with trusted artifacts.",
            path, e,
        )
        obj = torch.load(str(path), map_location="cpu", weights_only=False)

    if isinstance(obj, nn.Module):
        model = obj
    elif isinstance(obj, dict):
        model = _build_resnet18_4class()
        model.load_state_dict(_clean_state_dict(obj), strict=True)
    else:
        raise ValueError(
            f"MRI 2D checkpoint at {path} has unexpected type {type(obj).__name__}; "
            "expected state_dict (dict) or a full nn.Module pickle."
        )
    model.eval()
    return model
def predict_image(model: nn.Module, image_path: Path) -> dict[str, Any]:
    """Classify one MRI image with *model*.

    The output mirrors ``mri_model.predict_with_proba``: the arg-max class
    index (``label``), its name (``label_text``), its softmax probability
    (``confidence``), and a per-class ``probabilities`` list.

    Args:
        model: Classifier returning logits of shape ``(1, len(IDX_TO_CLASS))``
            for a single preprocessed image. Callers should pass it in eval
            mode; ``load()`` returns it that way.
        image_path: Path to an image file readable by PIL.

    Raises:
        FileNotFoundError: *image_path* does not exist.
        ValueError: the model's output size does not match the label vocabulary.
    """
    image_path = Path(image_path)
    if not image_path.exists():
        raise FileNotFoundError(f"MRI image not found: {image_path}")

    # The context manager releases the file handle deterministically.
    # convert() returns a new, fully loaded image that remains valid after
    # the source file is closed.
    with Image.open(str(image_path)) as src:
        img = src.convert("RGB")
    tensor = _TRANSFORM(img).unsqueeze(0)

    # Feed the input on the model's device. load() yields a CPU model, but
    # callers may have moved it (e.g. to CUDA).
    first_param = next(model.parameters(), None)
    if first_param is not None:
        tensor = tensor.to(first_param.device)

    with torch.inference_mode():
        logits = model(tensor)
    probs = torch.softmax(logits, dim=1).squeeze(0).cpu().numpy()

    # A mismatched head (e.g. a foreign full-module pickle) would otherwise
    # mislabel silently or fail with a bare KeyError in IDX_TO_CLASS.
    if probs.shape != (len(IDX_TO_CLASS),):
        raise ValueError(
            f"MRI 2D model produced output of shape {tuple(probs.shape)}; "
            f"expected ({len(IDX_TO_CLASS)},) to match IDX_TO_CLASS."
        )

    label_idx = int(np.argmax(probs))
    return {
        "label": label_idx,
        "label_text": IDX_TO_CLASS[label_idx],
        "confidence": float(probs[label_idx]),
        "probabilities": [
            {"label": i, "label_text": IDX_TO_CLASS[i], "probability": float(p)}
            for i, p in enumerate(probs)
        ],
    }