File size: 1,563 Bytes
621cb25 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 | """Env-var-driven dispatch between volumetric ONNX and 2D resnet18 MRI models."""
from __future__ import annotations
import os
from pathlib import Path
from typing import Any
from src.core.logger import get_logger
from src.models import mri_dl_2d, mri_model
logger = get_logger(__name__)

# Model kinds accepted through the MRI_MODEL_KIND environment variable.
VALID_KINDS = ("volumetric_onnx", "resnet18_2d")
# Kind used when MRI_MODEL_KIND is not set (the volumetric ONNX model).
_DEFAULT_KIND = VALID_KINDS[0]
def current_kind() -> str:
    """Return the MRI model kind selected by the ``MRI_MODEL_KIND`` env var.

    Surrounding whitespace is ignored and matching is case-insensitive. This
    means values copied from ``.env`` files or shell exports (e.g.
    ``"resnet18_2d\\n"`` or ``"Resnet18_2D"``) resolve to a valid kind. A
    variable that is unset *or set to an empty string* falls back to
    ``_DEFAULT_KIND``.

    Returns:
        One of ``VALID_KINDS``.

    Raises:
        ValueError: if the variable names a kind that is not in ``VALID_KINDS``.
    """
    raw = os.environ.get("MRI_MODEL_KIND", "")
    # "MRI_MODEL_KIND=" (set but empty) is treated like unset instead of
    # being rejected as an unknown kind.
    kind = raw.strip().lower() or _DEFAULT_KIND
    if kind not in VALID_KINDS:
        # Report the raw value so the message matches what the user exported.
        raise ValueError(f"unknown MRI_MODEL_KIND={raw!r}; expected one of {VALID_KINDS}")
    return kind
def label_names_for_kind(kind: str) -> tuple[str, ...]:
    """Return the ordered class-label names for the model family *kind*.

    For ``"resnet18_2d"`` the names are read from ``mri_dl_2d.IDX_TO_CLASS``
    at indices ``0 .. len(mri_dl_2d.CLASS_TO_IDX) - 1``, in that order. Every
    other value, including strings that are not in ``VALID_KINDS``, yields
    ``mri_model.DEFAULT_LABEL_NAMES``.
    """
    if kind != "resnet18_2d":
        return mri_model.DEFAULT_LABEL_NAMES
    class_count = len(mri_dl_2d.CLASS_TO_IDX)
    names = [mri_dl_2d.IDX_TO_CLASS[index] for index in range(class_count)]
    return tuple(names)
def predict(
    input_path: Path,
    checkpoint_path: Path,
    target_shape: tuple[int, int, int] | None = None,
    label_names: tuple[str, ...] | None = None,
) -> dict[str, Any]:
    """Run the active MRI model on one input. Returns the unified prediction dict.

    The model family is chosen by :func:`current_kind` (``MRI_MODEL_KIND``).
    The checkpoint is loaded on every call; nothing is cached here.

    Args:
        input_path: Input file, passed to ``mri_dl_2d.predict_image`` or
            ``mri_model.predict_nifti``, depending on the active kind.
        checkpoint_path: Weights path, passed to the active module's ``load``.
        target_shape: Shape forwarded to ``mri_model.predict_nifti``. ``None``
            (or any falsy value) means ``mri_model.DEFAULT_TARGET_SHAPE`` is
            used. The ``resnet18_2d`` path ignores it and logs a warning.
        label_names: Forwarded unchanged to ``mri_model.predict_nifti``. The
            ``resnet18_2d`` path ignores it and logs a warning.

    Returns:
        The dict returned by the active module's predict function.

    Raises:
        ValueError: propagated from :func:`current_kind` when
            ``MRI_MODEL_KIND`` names an unknown kind.
    """
    kind = current_kind()
    logger.info("dispatching MRI prediction kind=%s input=%s", kind, input_path)
    if kind == "resnet18_2d":
        # The 2D model takes neither argument. Warn instead of silently dropping
        # caller intent, e.g. custom label names that will not be applied.
        ignored = [
            name
            for name, value in (("target_shape", target_shape), ("label_names", label_names))
            if value is not None
        ]
        if ignored:
            logger.warning(
                "MRI_MODEL_KIND=resnet18_2d ignores argument(s): %s", ", ".join(ignored)
            )
        model = mri_dl_2d.load(checkpoint_path)
        return mri_dl_2d.predict_image(model, input_path)
    model = mri_model.load(checkpoint_path)
    return mri_model.predict_nifti(
        model,
        input_path,
        target_shape=target_shape or mri_model.DEFAULT_TARGET_SHAPE,
        label_names=label_names,
    )
|