File size: 1,563 Bytes
621cb25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
"""Env-var-driven dispatch between volumetric ONNX and 2D resnet18 MRI models."""
from __future__ import annotations

import os
from pathlib import Path
from typing import Any

from src.core.logger import get_logger
from src.models import mri_dl_2d, mri_model

# Module-level logger; predict() uses it to record which backend each call dispatches to.
logger = get_logger(__name__)

# Accepted values for the MRI_MODEL_KIND environment variable (validated by current_kind()).
VALID_KINDS = ("volumetric_onnx", "resnet18_2d")
# Kind used when MRI_MODEL_KIND is not set in the environment.
_DEFAULT_KIND = "volumetric_onnx"


def current_kind() -> str:
    """Return the MRI model kind selected by the ``MRI_MODEL_KIND`` env var.

    Falls back to ``_DEFAULT_KIND`` when the variable is unset *or* blank
    (e.g. ``MRI_MODEL_KIND=`` in a ``.env`` file). Surrounding whitespace
    in the value is ignored.

    Returns:
        One of the strings in ``VALID_KINDS``.

    Raises:
        ValueError: if the (stripped) value is not one of ``VALID_KINDS``.
    """
    raw = os.environ.get("MRI_MODEL_KIND", "")
    # Treat an empty/whitespace-only value like an unset variable instead of
    # rejecting it, so a blank entry in deployment config selects the default.
    kind = raw.strip() or _DEFAULT_KIND
    if kind not in VALID_KINDS:
        raise ValueError(f"unknown MRI_MODEL_KIND={kind!r}; expected one of {VALID_KINDS}")
    return kind


def label_names_for_kind(kind: str) -> tuple[str, ...]:
    """Return the ordered class-label names for the given model ``kind``.

    For ``"resnet18_2d"`` the names are read from ``mri_dl_2d.IDX_TO_CLASS``
    at indices ``0 .. len(mri_dl_2d.CLASS_TO_IDX) - 1``. Any other value,
    including kinds not in ``VALID_KINDS`` (``kind`` is not validated here),
    yields ``mri_model.DEFAULT_LABEL_NAMES``.
    """
    if kind != "resnet18_2d":
        return mri_model.DEFAULT_LABEL_NAMES
    # NOTE(review): sized from CLASS_TO_IDX but indexed into IDX_TO_CLASS;
    # assumes the two maps are exact inverses with keys 0..n-1 — confirm.
    n_classes = len(mri_dl_2d.CLASS_TO_IDX)
    return tuple(mri_dl_2d.IDX_TO_CLASS[idx] for idx in range(n_classes))


def predict(
    input_path: Path,
    checkpoint_path: Path,
    target_shape: tuple[int, int, int] | None = None,
    label_names: tuple[str, ...] | None = None,
) -> dict[str, Any]:
    """Run the active MRI model on one input. Returns the unified prediction dict.

    The backend is chosen by ``current_kind()`` on every call. The checkpoint
    is loaded via the backend's ``load`` on every call; there is no caching here.

    Args:
        input_path: Input file, handed to ``mri_dl_2d.predict_image`` or
            ``mri_model.predict_nifti`` depending on the active kind.
        checkpoint_path: Checkpoint passed to the matching backend ``load``.
        target_shape: Forwarded to ``mri_model.predict_nifti``. A falsy value
            (``None``) falls back to ``mri_model.DEFAULT_TARGET_SHAPE``. Only
            used by ``volumetric_onnx``.
        label_names: Forwarded unchanged to ``mri_model.predict_nifti``. Only
            used by ``volumetric_onnx``.

    Raises:
        ValueError: propagated from ``current_kind()`` when ``MRI_MODEL_KIND``
            holds an unknown value.
    """
    kind = current_kind()
    logger.info("dispatching MRI prediction kind=%s input=%s", kind, input_path)
    if kind == "resnet18_2d":
        # The 2D path does not accept these options. Warn rather than drop
        # them silently, so a caller configured for the volumetric model notices.
        if target_shape is not None or label_names is not None:
            logger.warning(
                "MRI_MODEL_KIND=resnet18_2d ignores target_shape=%s and label_names=%s",
                target_shape,
                label_names,
            )
        model = mri_dl_2d.load(checkpoint_path)
        return mri_dl_2d.predict_image(model, input_path)
    model = mri_model.load(checkpoint_path)
    return mri_model.predict_nifti(
        model,
        input_path,
        target_shape=target_shape or mri_model.DEFAULT_TARGET_SHAPE,
        label_names=label_names,
    )