|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import warnings
|
| from typing import List, Optional
|
|
|
| import torch
|
| from monai.transforms import (
|
| Compose,
|
| DivisiblePadd,
|
| EnsureChannelFirstd,
|
| EnsureTyped,
|
| Lambdad,
|
| LoadImaged,
|
| Orientationd,
|
| RandAdjustContrastd,
|
| RandBiasFieldd,
|
| RandFlipd,
|
| RandGibbsNoised,
|
| RandHistogramShiftd,
|
| RandRotate90d,
|
| RandRotated,
|
| RandScaleIntensityd,
|
| RandShiftIntensityd,
|
| RandSpatialCropd,
|
| RandZoomd,
|
| ResizeWithPadOrCropd,
|
| ScaleIntensityRanged,
|
| ScaleIntensityRangePercentilesd,
|
| SelectItemsd,
|
| Spacingd,
|
| SpatialPadd,
|
| )
|
|
|
# Imaging modalities for which dedicated intensity transforms are defined;
# any other modality falls through with a warning and no intensity scaling.
SUPPORT_MODALITIES = ["ct", "mri"]
|
|
|
|
|
def define_fixed_intensity_transform(modality: str, image_keys: Optional[List[str]] = None) -> List:
    """
    Define the deterministic intensity normalization transform for a modality.

    Args:
        modality (str): The imaging modality, either 'ct' or 'mri' (case-insensitive).
        image_keys (Optional[List[str]], optional): List of image keys. Defaults to ["image"].

    Returns:
        List: A list of intensity transforms; empty if the modality is unsupported.
    """
    # Avoid a shared mutable default argument; None means the standard single "image" key.
    image_keys = ["image"] if image_keys is None else image_keys
    # Normalize case BEFORE validating, otherwise e.g. "CT" triggers a spurious
    # warning even though it is supported after lowering (matches the ordering
    # used by define_random_intensity_transform).
    modality = modality.lower()

    if modality == "mri":
        # Percentile-based scaling is robust to MRI's unbounded, scanner-dependent intensity range.
        return [
            ScaleIntensityRangePercentilesd(keys=image_keys, lower=0.0, upper=99.5, b_min=0.0, b_max=1, clip=False)
        ]
    if modality == "ct":
        # CT intensities are Hounsfield units; clip [-1000, 1000] and rescale to [0, 1].
        return [ScaleIntensityRanged(keys=image_keys, a_min=-1000, a_max=1000, b_min=0.0, b_max=1.0, clip=True)]

    warnings.warn(
        f"Intensity transform only support {SUPPORT_MODALITIES}. Got {modality}. Will not do any intensity transform and will use original intensities."
    )
    return []
|
|
|
|
|
def define_random_intensity_transform(modality: str, image_keys: Optional[List[str]] = None) -> List:
    """
    Define random intensity augmentation transforms based on the modality.

    Args:
        modality (str): The imaging modality, either 'ct' or 'mri' (case-insensitive).
        image_keys (Optional[List[str]], optional): List of image keys. Defaults to ["image"].

    Returns:
        List: A list of random intensity transforms; empty for 'ct' and for
        unsupported modalities (a warning is emitted for the latter).
    """
    # Avoid a shared mutable default argument; None means the standard single "image" key.
    image_keys = ["image"] if image_keys is None else image_keys
    modality = modality.lower()

    if modality == "ct":
        # CT intensities are calibrated (Hounsfield units), so no random intensity jitter.
        return []
    if modality == "mri":
        return [
            RandBiasFieldd(keys=image_keys, prob=0.3, coeff_range=(0.0, 0.3)),
            RandGibbsNoised(keys=image_keys, prob=0.3, alpha=(0.5, 1.0)),
            RandAdjustContrastd(keys=image_keys, prob=0.3, gamma=(0.5, 2.0)),
            RandHistogramShiftd(keys=image_keys, prob=0.05, num_control_points=10),
        ]

    warnings.warn(
        f"Intensity transform only support {SUPPORT_MODALITIES}. Got {modality}. Will not do any intensity transform and will use original intensities."
    )
    return []
|
|
|
|
|
def define_vae_transform(
    is_train: bool,
    modality: str,
    random_aug: bool,
    k: int = 4,
    patch_size: Optional[List[int]] = None,
    val_patch_size: Optional[List[int]] = None,
    output_dtype: torch.dtype = torch.float32,
    spacing_type: str = "original",
    spacing: Optional[List[float]] = None,
    image_keys: Optional[List[str]] = None,
    label_keys: Optional[List[str]] = None,
    additional_keys: Optional[List[str]] = None,
    select_channel: int = 0,
) -> "Compose":
    """
    Define the MAISI VAE transform pipeline for training or validation.

    Args:
        is_train (bool): Whether it's for training. If True, the output transform considers
            ``random_aug`` and uses ``patch_size`` for random cropping. If False, ``random_aug``
            is always treated as False and ``val_patch_size`` is used for central cropping.
        modality (str): The imaging modality, either 'ct' or 'mri' (case-insensitive).
        random_aug (bool): Whether to apply random data augmentation.
        k (int, optional): Patches should be divisible by k. Defaults to 4.
        patch_size (Optional[List[int]], optional): Size of training patches (random crop).
            Defaults to [128, 128, 128].
        val_patch_size (Optional[List[int]], optional): Size of validation patches. Defaults to None.
            If None, the whole volume is used for validation; otherwise a central crop is taken.
        output_dtype (torch.dtype, optional): Output data type. Defaults to torch.float32.
        spacing_type (str, optional): Type of spacing. Defaults to "original".
            Choose from ["original", "fixed", "rand_zoom"].
        spacing (Optional[List[float]], optional): Target pixdim, used when spacing_type == "fixed".
            Defaults to None.
        image_keys (Optional[List[str]], optional): List of image keys. Defaults to ["image"].
        label_keys (Optional[List[str]], optional): List of label keys. Defaults to [].
        additional_keys (Optional[List[str]], optional): List of additional keys. Defaults to [].
        select_channel (int, optional): Channel to select for multi-channel MRI. Defaults to 0.

    Returns:
        Compose: the composed train transforms (is_train=True) or validation
        transforms (is_train=False).

    Raises:
        ValueError: If spacing_type is not one of ["original", "fixed", "rand_zoom"].
    """
    # Resolve mutable defaults: a fresh list per call instead of shared default objects.
    patch_size = [128, 128, 128] if patch_size is None else patch_size
    image_keys = ["image"] if image_keys is None else image_keys
    label_keys = [] if label_keys is None else label_keys
    additional_keys = [] if additional_keys is None else additional_keys

    modality = modality.lower()
    if modality not in SUPPORT_MODALITIES:
        warnings.warn(
            f"Intensity transform only support {SUPPORT_MODALITIES}. Got {modality}. Will not do any intensity transform and will use original intensities."
        )

    if spacing_type not in ["original", "fixed", "rand_zoom"]:
        raise ValueError(f"spacing_type has to be chosen from ['original', 'fixed', 'rand_zoom']. Got {spacing_type}.")

    keys = image_keys + label_keys + additional_keys
    # Images are resampled bilinearly; labels use nearest-neighbor to keep discrete values intact.
    interp_mode = ["bilinear"] * len(image_keys) + ["nearest"] * len(label_keys)

    # Shared deterministic preprocessing applied to both training and validation.
    common_transform = [
        SelectItemsd(keys=keys, allow_missing_keys=True),
        LoadImaged(keys=keys, allow_missing_keys=True),
        EnsureChannelFirstd(keys=keys, allow_missing_keys=True),
        Orientationd(keys=keys, axcodes="RAS", allow_missing_keys=True),
    ]

    if modality == "mri":
        # Multi-channel MRI: keep only the selected channel (slicing preserves the channel dim).
        common_transform.append(Lambdad(keys=image_keys, func=lambda x: x[select_channel : select_channel + 1, ...]))

    common_transform.extend(define_fixed_intensity_transform(modality, image_keys=image_keys))

    if spacing_type == "fixed":
        common_transform.append(
            Spacingd(keys=image_keys + label_keys, allow_missing_keys=True, pixdim=spacing, mode=interp_mode)
        )

    random_transform = []
    if is_train and random_aug:
        random_transform.extend(define_random_intensity_transform(modality, image_keys=image_keys))
        random_transform.extend(
            [RandFlipd(keys=keys, allow_missing_keys=True, prob=0.5, spatial_axis=axis) for axis in range(3)]
            + [
                RandRotate90d(keys=keys, allow_missing_keys=True, prob=0.5, spatial_axes=axes)
                for axes in [(0, 1), (1, 2), (0, 2)]
            ]
            + [
                RandScaleIntensityd(keys=image_keys, allow_missing_keys=True, prob=0.3, factors=(0.9, 1.1)),
                RandShiftIntensityd(keys=image_keys, allow_missing_keys=True, prob=0.3, offsets=0.05),
            ]
        )

    if spacing_type == "rand_zoom":
        # NOTE(review): these are appended regardless of random_aug, but the final Compose
        # only includes random_transform when is_train and random_aug are both True, so they
        # are silently discarded otherwise — confirm this coupling is intentional.
        random_transform.extend(
            [
                RandZoomd(
                    keys=image_keys + label_keys,
                    allow_missing_keys=True,
                    prob=0.3,
                    min_zoom=0.5,
                    max_zoom=1.5,
                    keep_size=False,
                    mode=interp_mode,
                ),
                RandRotated(
                    keys=image_keys + label_keys,
                    allow_missing_keys=True,
                    prob=0.3,
                    range_x=0.1,
                    range_y=0.1,
                    range_z=0.1,
                    keep_size=True,
                    mode=interp_mode,
                ),
            ]
        )

    if is_train:
        # Pad first so volumes smaller than patch_size can still yield a full-size random crop.
        train_crop = [
            SpatialPadd(keys=keys, spatial_size=patch_size, allow_missing_keys=True),
            RandSpatialCropd(
                keys=keys, roi_size=patch_size, allow_missing_keys=True, random_size=False, random_center=True
            ),
        ]
    else:
        # Validation: pad to a multiple of k (whole volume), or central crop/pad to val_patch_size.
        val_crop = (
            [DivisiblePadd(keys=keys, allow_missing_keys=True, k=k)]
            if val_patch_size is None
            else [ResizeWithPadOrCropd(keys=keys, allow_missing_keys=True, spatial_size=val_patch_size)]
        )

    final_transform = [EnsureTyped(keys=keys, dtype=output_dtype, allow_missing_keys=True)]

    if is_train:
        train_transforms = Compose(
            common_transform + random_transform + train_crop + final_transform
            if random_aug
            else common_transform + train_crop + final_transform
        )
        return train_transforms
    else:
        val_transforms = Compose(common_transform + val_crop + final_transform)
        return val_transforms
|
|
|
|
|
class VAE_Transform:
    """
    A class to handle MAISI VAE transformations for different modalities.

    One transform pipeline is built per supported modality ("ct", "mri") at
    construction time; ``__call__`` dispatches each sample to the matching pipeline.
    """

    def __init__(
        self,
        is_train: bool,
        random_aug: bool,
        k: int = 4,
        patch_size: Optional[List[int]] = None,
        val_patch_size: Optional[List[int]] = None,
        output_dtype: torch.dtype = torch.float32,
        spacing_type: str = "original",
        spacing: Optional[List[float]] = None,
        image_keys: Optional[List[str]] = None,
        label_keys: Optional[List[str]] = None,
        additional_keys: Optional[List[str]] = None,
        select_channel: int = 0,
    ):
        """
        Initialize the VAE_Transform.

        Args:
            is_train (bool): Whether it's for training. If True, the transform considers
                ``random_aug`` and uses ``patch_size`` for random cropping. If False,
                ``random_aug`` is always treated as False and ``val_patch_size`` is used
                for central cropping.
            random_aug (bool): Whether to apply random data augmentation for training.
            k (int, optional): Patches should be divisible by k. Defaults to 4.
            patch_size (Optional[List[int]], optional): Size of training patches (random crop).
                Defaults to [128, 128, 128].
            val_patch_size (Optional[List[int]], optional): Size of validation patches.
                Defaults to None. If None, the whole volume is used for validation;
                otherwise a central crop is taken.
            output_dtype (torch.dtype, optional): Output data type. Defaults to torch.float32.
            spacing_type (str, optional): Type of spacing. Defaults to "original".
                Choose from ["original", "fixed", "rand_zoom"].
            spacing (Optional[List[float]], optional): Spacing values. Defaults to None.
            image_keys (Optional[List[str]], optional): List of image keys. Defaults to ["image"].
            label_keys (Optional[List[str]], optional): List of label keys. Defaults to [].
            additional_keys (Optional[List[str]], optional): List of additional keys. Defaults to [].
            select_channel (int, optional): Channel to select for multi-channel MRI. Defaults to 0.

        Raises:
            ValueError: If spacing_type is not one of ["original", "fixed", "rand_zoom"].
        """
        # Validate eagerly so bad configuration fails at construction, not first use.
        if spacing_type not in ["original", "fixed", "rand_zoom"]:
            raise ValueError(
                f"spacing_type has to be chosen from ['original', 'fixed', 'rand_zoom']. Got {spacing_type}."
            )

        # Resolve mutable defaults before forwarding, so each instance gets fresh lists.
        patch_size = [128, 128, 128] if patch_size is None else patch_size
        image_keys = ["image"] if image_keys is None else image_keys
        label_keys = [] if label_keys is None else label_keys
        additional_keys = [] if additional_keys is None else additional_keys

        self.is_train = is_train
        # One prebuilt pipeline per supported modality, keyed by lowercase name.
        self.transform_dict = {
            modality: define_vae_transform(
                is_train=is_train,
                modality=modality,
                random_aug=random_aug,
                k=k,
                patch_size=patch_size,
                val_patch_size=val_patch_size,
                output_dtype=output_dtype,
                spacing_type=spacing_type,
                spacing=spacing,
                image_keys=image_keys,
                label_keys=label_keys,
                additional_keys=additional_keys,
                select_channel=select_channel,
            )
            for modality in SUPPORT_MODALITIES
        }

    def __call__(self, img: dict, fixed_modality: Optional[str] = None) -> dict:
        """
        Apply the modality-appropriate transform to the input sample.

        Args:
            img (dict): Input sample dictionary; its "class" entry names the modality
                unless ``fixed_modality`` is given.
            fixed_modality (Optional[str], optional): Force this modality instead of
                reading img["class"]. Defaults to None.

        Returns:
            dict: The transformed sample.

        Raises:
            ValueError: If the modality is not 'ct' or 'mri'.
        """
        modality = (fixed_modality or img["class"]).lower()
        if modality not in SUPPORT_MODALITIES:
            # Previously this path only warned ("will use original intensities") and then
            # crashed with an opaque KeyError on the dict lookup below — there is no
            # fallback pipeline. Raise the documented ValueError with a clear message.
            raise ValueError(f"Modality has to be chosen from {SUPPORT_MODALITIES}. Got {modality}.")
        return self.transform_dict[modality](img)
|
|
|