| """Env-var-driven dispatch between volumetric ONNX and 2D resnet18 MRI models.""" |
| from __future__ import annotations |
|
|
| import os |
| from pathlib import Path |
| from typing import Any |
|
|
| from src.core.logger import get_logger |
| from src.models import mri_dl_2d, mri_model |
|
|
# Module-level logger from the project's logging helper; used to trace dispatch decisions.
logger = get_logger(__name__)
|
|
# Accepted values for the MRI_MODEL_KIND environment variable.
VALID_KINDS = ("volumetric_onnx", "resnet18_2d")
# Kind used when MRI_MODEL_KIND is unset or blank.
_DEFAULT_KIND = "volumetric_onnx"


def current_kind() -> str:
    """Return the MRI model kind selected by the ``MRI_MODEL_KIND`` env var.

    Surrounding whitespace is stripped from the variable's value. An unset
    variable and one exported with an empty/blank value (e.g. ``MRI_MODEL_KIND=``
    in an env file) both fall back to ``_DEFAULT_KIND``.

    Returns:
        One of ``VALID_KINDS``.

    Raises:
        ValueError: if the variable holds any other non-blank value.
    """
    raw = os.environ.get("MRI_MODEL_KIND", "")
    kind = raw.strip() or _DEFAULT_KIND
    if kind not in VALID_KINDS:
        # Report the raw value so stray characters are visible in the error.
        raise ValueError(f"unknown MRI_MODEL_KIND={raw!r}; expected one of {VALID_KINDS}")
    return kind
|
|
|
|
def label_names_for_kind(kind: str) -> tuple[str, ...]:
    """Return the ordered class-label names for the model ``kind``.

    Args:
        kind: One of ``VALID_KINDS``.

    Returns:
        For ``"resnet18_2d"``, the entries of ``mri_dl_2d.IDX_TO_CLASS`` for
        indices ``0 .. len(mri_dl_2d.CLASS_TO_IDX) - 1`` in order; for
        ``"volumetric_onnx"``, ``mri_model.DEFAULT_LABEL_NAMES``.

    Raises:
        ValueError: if ``kind`` is not in ``VALID_KINDS``.
    """
    if kind not in VALID_KINDS:
        # Fail loudly, like current_kind(), instead of silently handing back the
        # volumetric label set for a typo'd or unsupported kind.
        raise ValueError(f"unknown MRI model kind {kind!r}; expected one of {VALID_KINDS}")
    if kind == "resnet18_2d":
        # Assumes IDX_TO_CLASS is the inverse of CLASS_TO_IDX with contiguous
        # indices starting at 0 — TODO confirm in mri_dl_2d.
        return tuple(mri_dl_2d.IDX_TO_CLASS[i] for i in range(len(mri_dl_2d.CLASS_TO_IDX)))
    return mri_model.DEFAULT_LABEL_NAMES
|
|
|
|
def predict(
    input_path: Path,
    checkpoint_path: Path,
    target_shape: tuple[int, int, int] | None = None,
    label_names: tuple[str, ...] | None = None,
) -> dict[str, Any]:
    """Run the active MRI model on one input. Returns the unified prediction dict.

    The backend is chosen by ``current_kind()`` (the ``MRI_MODEL_KIND``
    environment variable), and its checkpoint is loaded on every call.

    Args:
        input_path: Input to run the model on. It is passed to
            ``mri_dl_2d.predict_image`` for ``"resnet18_2d"`` and to
            ``mri_model.predict_nifti`` otherwise.
        checkpoint_path: Checkpoint passed to the selected backend's ``load``.
        target_shape: Volumetric backend only. Falsy values fall back to
            ``mri_model.DEFAULT_TARGET_SHAPE``.
        label_names: Volumetric backend only. Forwarded unchanged to
            ``mri_model.predict_nifti``.

    Returns:
        The prediction dict returned by the selected backend.

    Raises:
        ValueError: if ``MRI_MODEL_KIND`` holds an unsupported value.
    """
    kind = current_kind()
    logger.info("dispatching MRI prediction kind=%s input=%s", kind, input_path)
    # NOTE(review): the checkpoint is re-loaded on each call; consider caching
    # the loaded model if predict() sits on a hot path.
    if kind == "resnet18_2d":
        # The 2D backend accepts neither of these arguments; surface that instead
        # of dropping caller-supplied values without a trace.
        ignored = [
            name
            for name, value in (("target_shape", target_shape), ("label_names", label_names))
            if value is not None
        ]
        if ignored:
            logger.warning(
                "MRI_MODEL_KIND=resnet18_2d does not use %s; ignoring", ", ".join(ignored)
            )
        model = mri_dl_2d.load(checkpoint_path)
        return mri_dl_2d.predict_image(model, input_path)
    model = mri_model.load(checkpoint_path)
    return mri_model.predict_nifti(
        model,
        input_path,
        target_shape=target_shape or mri_model.DEFAULT_TARGET_SHAPE,
        label_names=label_names,
    )
|
|