File size: 3,726 Bytes
11f62d8
 
 
 
 
 
 
 
 
 
07b00eb
11f62d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
07b00eb
 
 
 
 
 
 
11f62d8
 
 
07b00eb
 
 
 
 
 
 
 
 
11f62d8
 
07b00eb
11f62d8
 
 
07b00eb
 
 
 
 
11f62d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""Pretrained 2D MRI Alzheimer's classifier (resnet18, 4 classes).

Decision-layer bridge for an externally-trained PyTorch checkpoint. Loads
either a state_dict (default) or a full pickled model, applies the trainer's
preprocessing (resize image_size=160, ImageNet normalisation), and emits the
same dict shape as src.models.mri_model.predict_with_proba so downstream
code paths don't care which backend produced the prediction.
"""
from __future__ import annotations

import pickle
from pathlib import Path
from typing import Any

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms

from src.core.logger import get_logger

# Module-scoped logger, named after this module via ``__name__``.
logger = get_logger(__name__)

# Label name -> class index. The tuple order below *is* the index order
# (0..3); it must match the label encoding used by the model's output layer.
CLASS_TO_IDX: dict[str, int] = {
    name: idx
    for idx, name in enumerate(
        ("MildDemented", "ModerateDemented", "NonDemented", "VeryMildDemented")
    )
}
# Reverse lookup (class index -> label name), in the same order as above.
IDX_TO_CLASS: dict[int, str] = dict(zip(CLASS_TO_IDX.values(), CLASS_TO_IDX.keys()))

# Square side length (pixels) every input image is resized to before inference.
DEFAULT_IMAGE_SIZE = 160
# Per-channel (R, G, B) ImageNet statistics used for input normalisation.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)

# Inference-time preprocessing matching the trainer: resize to a fixed
# DEFAULT_IMAGE_SIZE x DEFAULT_IMAGE_SIZE square, convert the PIL image to a
# float tensor, then normalise each channel with the ImageNet mean/std.
_TRANSFORM = transforms.Compose(
    [
        transforms.Resize(size=(DEFAULT_IMAGE_SIZE,) * 2),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)


def _build_resnet18_4class() -> nn.Module:
    """Create an uninitialised torchvision resnet18 with a per-class output head.

    ``weights=None`` means no pretrained weights are fetched; the final fully
    connected layer is swapped for one emitting ``len(CLASS_TO_IDX)`` logits.
    """
    net = models.resnet18(weights=None)
    feature_dim = net.fc.in_features
    net.fc = nn.Linear(feature_dim, len(CLASS_TO_IDX))
    return net


def _unwrap_state_dict(obj: dict) -> dict:
    """Extract the model weights from a checkpoint dict and strip wrapper prefixes.

    Training scripts often save ``{"state_dict": ...}`` or
    ``{"model_state_dict": ...}`` next to optimizer/epoch metadata. resnet18
    parameter names (``conv1.weight``, ``fc.bias``, ...) never collide with
    those keys, so unwrapping them is unambiguous.
    """
    for key in ("state_dict", "model_state_dict"):
        nested = obj.get(key)
        if isinstance(nested, dict):
            obj = nested
            break
    # nn.DataParallel / DistributedDataParallel prefix every name with "module.".
    return {k.removeprefix("module."): v for k, v in obj.items()}


def load(path: Path) -> nn.Module:
    """Load the 2D MRI classifier from a checkpoint file.

    Accepts a plain state_dict (preferred), a training checkpoint dict that
    nests the state_dict under ``"state_dict"`` or ``"model_state_dict"``, or
    a full pickled ``nn.Module``.

    Tries ``weights_only=True`` first (safe; refuses arbitrary pickle opcodes)
    and falls back to ``weights_only=False`` only when that attempt fails,
    e.g. for a full ``nn.Module`` pickle. The fallback executes pickle code
    and should only be used with trusted artifacts.

    Args:
        path: Checkpoint location (``str`` or ``Path``).

    Returns:
        The model in eval mode, loaded onto CPU.

    Raises:
        FileNotFoundError: ``path`` does not exist.
        ValueError: the checkpoint is neither a dict nor an ``nn.Module``.
        RuntimeError: the state_dict does not match the resnet18 4-class
            architecture.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"MRI 2D checkpoint not found: {path}")
    try:
        obj = torch.load(str(path), map_location="cpu", weights_only=True)
    except (pickle.UnpicklingError, RuntimeError) as e:
        logger.warning(
            "MRI 2D checkpoint at %s is not a state_dict (weights_only=True failed: %s); "
            "falling back to weights_only=False — only safe with trusted artifacts.",
            path, e,
        )
        # SECURITY: weights_only=False runs arbitrary pickle code from the file.
        obj = torch.load(str(path), map_location="cpu", weights_only=False)

    if isinstance(obj, nn.Module):
        model = obj
    elif isinstance(obj, dict):
        model = _build_resnet18_4class()
        state_dict = _unwrap_state_dict(obj)
        try:
            model.load_state_dict(state_dict, strict=True)
        except RuntimeError as e:
            # Add the checkpoint path; torch's message only lists key/shape diffs.
            raise RuntimeError(
                f"MRI 2D checkpoint at {path} does not match the expected "
                f"resnet18 architecture: {e}"
            ) from e
    else:
        raise ValueError(
            f"MRI 2D checkpoint at {path} has unexpected type {type(obj).__name__}; "
            "expected state_dict (dict) or a full nn.Module pickle."
        )
    model.eval()
    return model


def _model_device(model: nn.Module) -> torch.device:
    """Return the device holding the model's first parameter (CPU if it has none)."""
    first = next(model.parameters(), None)
    return first.device if first is not None else torch.device("cpu")


def predict_image(model: nn.Module, image_path: Path) -> dict[str, Any]:
    """Run inference on one image. Output mirrors mri_model.predict_with_proba.

    Args:
        model: Classifier whose output has one logit per entry of
            ``IDX_TO_CLASS``. It may live on any device; the input follows it.
        image_path: Path (``str`` or ``Path``) to the image file.

    Returns:
        Dict with the argmax ``label`` index, its ``label_text``, its
        ``confidence`` (softmax probability), and ``probabilities``: one
        ``{"label", "label_text", "probability"}`` entry per class.

    Raises:
        FileNotFoundError: ``image_path`` does not exist.
        ValueError: the model's output width differs from the label map size.
    """
    image_path = Path(image_path)
    if not image_path.exists():
        raise FileNotFoundError(f"MRI image not found: {image_path}")
    # The context manager releases the file handle deterministically; Pillow
    # can otherwise keep it open (e.g. for multi-frame formats).
    with Image.open(str(image_path)) as raw:
        img = raw.convert("RGB")
    # Move the input to the model's device so GPU-resident models also work.
    tensor = _TRANSFORM(img).unsqueeze(0).to(_model_device(model))

    with torch.inference_mode():
        logits = model(tensor)
        probs = torch.softmax(logits, dim=1).squeeze(0).cpu().numpy()

    # A width mismatch would make IDX_TO_CLASS lookups fail or mislabel classes.
    if probs.shape[0] != len(IDX_TO_CLASS):
        raise ValueError(
            f"Model produced {probs.shape[0]} class scores; "
            f"expected {len(IDX_TO_CLASS)}."
        )

    label_idx = int(np.argmax(probs))
    return {
        "label": label_idx,
        "label_text": IDX_TO_CLASS[label_idx],
        "confidence": float(probs[label_idx]),
        "probabilities": [
            {"label": i, "label_text": IDX_TO_CLASS[i], "probability": float(p)}
            for i, p in enumerate(probs)
        ],
    }