| """Pretrained 2D MRI Alzheimer's classifier (resnet18, 4 classes). |
| |
| Decision-layer bridge for an externally-trained PyTorch checkpoint. Loads |
| either a state_dict (default) or a full pickled model, applies the trainer's |
| preprocessing (resize image_size=160, ImageNet normalisation), and emits the |
| same dict shape as src.models.mri_model.predict_with_proba so downstream |
| code paths don't care which backend produced the prediction. |
| """ |
| from __future__ import annotations |
|
|
| import pickle |
| from pathlib import Path |
| from typing import Any |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from PIL import Image |
| from torchvision import models, transforms |
|
|
| from src.core.logger import get_logger |
|
|
# Module-level logger; load() uses it to warn when it has to fall back to the
# unsafe weights_only=False checkpoint path.
logger = get_logger(__name__)
|
|
# Class name -> logit index for the 4-way classifier head. The names are listed
# in alphabetical order; presumably this mirrors the index order the trainer
# used (e.g. torchvision ImageFolder sorts class folders) — confirm against the
# training code before changing it.
CLASS_TO_IDX: dict[str, int] = {
    name: idx
    for idx, name in enumerate(
        ("MildDemented", "ModerateDemented", "NonDemented", "VeryMildDemented")
    )
}
# Reverse lookup used to turn an argmax index back into a human-readable label.
IDX_TO_CLASS: dict[int, str] = {idx: name for name, idx in CLASS_TO_IDX.items()}


# Square side length (pixels) the trainer resized inputs to.
DEFAULT_IMAGE_SIZE = 160
# ImageNet per-channel (R, G, B) statistics passed to Normalize in _TRANSFORM.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)
|
|
# Inference-time preprocessing matching the trainer: resize to a fixed
# DEFAULT_IMAGE_SIZE square, convert the PIL image to a float tensor, then
# normalise each channel with the ImageNet mean/std.
_TRANSFORM = transforms.Compose(
    [
        transforms.Resize(size=(DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)
|
|
|
|
def _build_resnet18_4class() -> nn.Module:
    """Build a randomly initialised resnet18 with a len(CLASS_TO_IDX)-way head.

    No pretrained weights are fetched (``weights=None``); load() fills the
    parameters from a checkpoint state_dict afterwards.
    """
    net = models.resnet18(weights=None)
    head_inputs = net.fc.in_features
    net.fc = nn.Linear(head_inputs, len(CLASS_TO_IDX))
    return net
|
|
|
|
def load(path: Path) -> nn.Module:
    """Load the MRI 2D checkpoint and return a CPU model in eval mode.

    Accepted checkpoint layouts:
      * a plain state_dict for resnet18 with a 4-class head (preferred);
      * a trainer checkpoint dict nesting that state_dict under
        ``"state_dict"`` or ``"model_state_dict"``;
      * a full pickled ``nn.Module`` (optionally wrapped in
        ``nn.DataParallel`` / ``DistributedDataParallel``).

    Tries ``weights_only=True`` first (safe; refuses arbitrary pickle
    opcodes) and falls back to ``weights_only=False`` only when that fails,
    which is what happens for full ``nn.Module`` pickles. The fallback
    executes pickle code and must only be used with trusted artifacts.

    Args:
        path: Checkpoint location (a ``str`` is accepted too).

    Returns:
        The model with tensors mapped to CPU, switched to ``eval()`` mode.

    Raises:
        FileNotFoundError: ``path`` does not exist.
        ValueError: the checkpoint is neither a dict nor an ``nn.Module``.
        RuntimeError: the state_dict does not match resnet18 with a
            4-class head (strict ``load_state_dict``).
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"MRI 2D checkpoint not found: {path}")
    try:
        obj = torch.load(str(path), map_location="cpu", weights_only=True)
    except (pickle.UnpicklingError, RuntimeError) as e:
        logger.warning(
            "MRI 2D checkpoint at %s is not a state_dict (weights_only=True failed: %s); "
            "falling back to weights_only=False — only safe with trusted artifacts.",
            path, e,
        )
        obj = torch.load(str(path), map_location="cpu", weights_only=False)

    if isinstance(obj, nn.Module):
        # A model pickled while wrapped for multi-GPU training keeps its wrapper.
        # DataParallel.forward insists its parameters sit on the original CUDA
        # device, which breaks CPU inference, so unwrap to the inner module
        # (the module-path counterpart of stripping "module." below).
        if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
            obj = obj.module
        model = obj
    elif isinstance(obj, dict):
        state = obj
        # Trainer checkpoints often store weights next to epoch/optimizer info.
        # A bare resnet state_dict never has these keys, so this is a no-op for it.
        for key in ("state_dict", "model_state_dict"):
            nested = state.get(key)
            if isinstance(nested, dict):
                state = nested
                break
        model = _build_resnet18_4class()
        # Strip the "module." prefix that DataParallel/DDP adds to every key.
        clean = {k.removeprefix("module."): v for k, v in state.items()}
        model.load_state_dict(clean, strict=True)
    else:
        raise ValueError(
            f"MRI 2D checkpoint at {path} has unexpected type {type(obj).__name__}; "
            "expected state_dict (dict) or a full nn.Module pickle."
        )
    model.eval()
    return model
|
|
|
|
def predict_image(model: nn.Module, image_path: Path) -> dict[str, Any]:
    """Classify one MRI slice image. Output mirrors mri_model.predict_with_proba.

    Args:
        model: Classifier returned by load() (or any module producing one
            logit per entry of IDX_TO_CLASS for a (1, 3, H, W) input).
        image_path: Image file to classify (a ``str`` is accepted too).

    Returns:
        ``{"label": int, "label_text": str, "confidence": float,
        "probabilities": [{"label", "label_text", "probability"}, ...]}``
        where ``confidence`` is the softmax probability of the argmax class.

    Raises:
        FileNotFoundError: ``image_path`` does not exist.
        ValueError: the model's output size does not match IDX_TO_CLASS.
        Pillow errors propagate for unreadable images.
    """
    image_path = Path(image_path)
    if not image_path.exists():
        raise FileNotFoundError(f"MRI image not found: {image_path}")
    # Context manager guarantees the file handle is released; Pillow keeps it
    # open after decoding for some formats (e.g. multi-frame TIFF).
    with Image.open(str(image_path)) as raw:
        img = raw.convert("RGB")
    tensor = _TRANSFORM(img).unsqueeze(0)

    # load() returns a CPU model, but callers may have moved it to a GPU;
    # feed the input on whatever device the weights live on.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        tensor = tensor.to(first_param.device)

    with torch.inference_mode():
        logits = model(tensor)
        probs = torch.softmax(logits, dim=1).squeeze(0).cpu().numpy()

    # A full-module checkpoint may carry a head of a different size; without
    # this check fewer classes would be silently mislabelled and more would
    # surface as a bare KeyError.
    if probs.shape != (len(IDX_TO_CLASS),):
        raise ValueError(
            f"MRI 2D model produced class scores of shape {tuple(probs.shape)}; "
            f"expected ({len(IDX_TO_CLASS)},)."
        )

    label_idx = int(np.argmax(probs))
    return {
        "label": label_idx,
        "label_text": IDX_TO_CLASS[label_idx],
        "confidence": float(probs[label_idx]),
        "probabilities": [
            {"label": i, "label_text": IDX_TO_CLASS[i], "probability": float(p)}
            for i, p in enumerate(probs)
        ],
    }
|
|