| from typing import Any, Dict, Union |
|
|
| import numpy as np |
| import torch |
| from PIL import Image |
| from torchvision.transforms.functional import convert_image_dtype |
| from transformers.image_processing_utils import BaseImageProcessor |
| from transformers.image_utils import ChannelDimension |
|
|
| |
| try: |
| from transformers.image_processing_utils import BatchFeature |
| except Exception: |
| from transformers.feature_extraction_utils import BatchFeature |
|
|
|
|
| _BICUBIC = Image.BICUBIC |
|
|
|
|
class CuriaImageProcessor(BaseImageProcessor):
    """
    1-channel medical preprocessor replicating:
    NumpyToTensor -> float32 -> Resize(crop_size, BICUBIC, antialias)
    -> optional ClipIntensity(min=-1000) -> NormalizeIntensity(channel_wise=True)

    Outputs ``pixel_values`` of shape (B, 1, crop_size, crop_size) for 2D
    inputs, or (B, D, 1, crop_size, crop_size) for (H, W, D) volumes.

    Images need to be in:

    - PL for axial
    - IL for coronal
    - IP for sagittal

    for CT, no windowing, just hounsfield or normalized image
    for MRI, similar, no windowing, just raw values or normalized image
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        crop_size: int = 512,
        clip_below_air: bool = False,
        eps: float = 1e-6,
        do_resize: bool = True,
        do_normalize: bool = True,
        **kwargs,
    ):
        """
        Args:
            crop_size: Side length of the square output image.
            clip_below_air: If True, clamp intensities below -1000 (air in
                Hounsfield units) before normalization.
            eps: Threshold below which the per-image std is treated as zero
                (the image is then only mean-centered, avoiding division blow-up).
            do_resize: Whether to resize to (crop_size, crop_size).
            do_normalize: Whether to apply per-image z-score normalization.
        """
        super().__init__(**kwargs)
        self.crop_size = int(crop_size)
        self.clip_below_air = bool(clip_below_air)
        self.eps = float(eps)
        self.do_resize = bool(do_resize)
        self.do_normalize = bool(do_normalize)

    def _to_tensor(self, image: Union[np.ndarray, torch.Tensor, Image.Image]) -> torch.Tensor:
        """Accepts (H,W), (1,H,W) or PIL; returns a float32 torch tensor of shape (H, W).

        Raises:
            ValueError: If an array/tensor is not 2D grayscale or (1, H, W).
            TypeError: If ``image`` is not a supported type.
        """
        if isinstance(image, Image.Image):
            # Force grayscale; mode "F" (32-bit float) is already single-channel.
            if image.mode != "L" and image.mode != "F":
                image = image.convert("L")
            return torch.from_numpy(np.array(image)).float()

        if isinstance(image, torch.Tensor):
            tensor = image.detach().cpu()
            # Squeeze a leading singleton channel: (1, H, W) -> (H, W).
            if tensor.ndim == 3 and tensor.shape[0] == 1:
                tensor = tensor[0]
            if tensor.ndim != 2:
                raise ValueError(f"Expected 2D grayscale tensor or (1,H,W); got shape {tensor.shape}")
            return tensor.float()

        if isinstance(image, np.ndarray):
            arr = image
            if arr.ndim == 3 and arr.shape[0] == 1:
                arr = arr[0]
            if arr.ndim != 2:
                raise ValueError(f"Expected 2D grayscale array or (1,H,W); got shape {arr.shape}")
            # BUGFIX: this branch returned `.to(torch.int16)`, which truncated
            # float inputs and made `convert_image_dtype` rescale integer
            # inputs by 1/int_max — silently defeating the -1000 HU clip.
            # All branches now agree with the docstring and return float32.
            return torch.from_numpy(arr).float()

        raise TypeError(f"Unsupported image type: {type(image)!r}")

    def _resize(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Resize a 2D torch.Tensor (H, W) to (crop_size, crop_size) using bicubic interpolation.
        If do_resize is False, returns the input tensor unchanged.
        """
        if not self.do_resize:
            return tensor
        if tensor.ndim != 2:
            raise ValueError(f"Expected 2D tensor (H, W), got shape {tensor.shape}")

        # interpolate() wants (N, C, H, W); add and strip the fake batch/channel dims.
        tensor = tensor.unsqueeze(0).unsqueeze(0)
        tensor = torch.nn.functional.interpolate(
            tensor,
            size=(self.crop_size, self.crop_size),
            mode="bicubic",
            align_corners=False,
            antialias=True,
        )
        return tensor[0, 0]

    def _clip_min(self, tensor: torch.Tensor) -> torch.Tensor:
        """Clamp values below -1000 (air in HU) in place, when clip_below_air is set."""
        if self.clip_below_air:
            tensor.clamp_min_(-1000.0)
        return tensor

    def _zscore_per_image(self, tensor: torch.Tensor) -> torch.Tensor:
        """Per-image z-score normalization ((x - mean) / std), honoring do_normalize."""
        # BUGFIX: `do_normalize` was accepted in __init__ but never consulted,
        # so normalization could not be disabled.
        if not self.do_normalize:
            return tensor
        mean = float(tensor.mean())
        std = float(tensor.std())
        if std < self.eps:
            # Near-constant image: only center it, to avoid dividing by ~0.
            return tensor - mean
        return (tensor - mean) / std

    def __call__(self, images, return_tensors="pt", data_format=ChannelDimension.FIRST, **kwargs):
        """Preprocess one image/volume or a list of them.

        2D inputs ((H, W), (1, H, W) or PIL) produce (1, crop, crop) entries;
        (H, W, D) volumes are sliced along the last axis and produce
        (D, 1, crop, crop) entries, normalized over the whole volume.
        ``data_format`` is accepted for API compatibility; output is always
        channel-first.
        """
        if not isinstance(images, (list, tuple)):
            images = [images]

        batch = []
        for img in images:
            # BUGFIX: `len(img.shape)` raised AttributeError for PIL inputs;
            # PIL images have no .ndim/.shape, so default them to 2D.
            ndim = getattr(img, "ndim", 2)
            # BUGFIX: a (1, H, W) channel-first 2D image (explicitly supported
            # by _to_tensor) used to fall into the volume branch and be sliced
            # along the wrong axis; only genuine (H, W, D) stacks are volumes.
            if ndim == 3 and img.shape[0] != 1:
                full_volume = []
                for i in range(img.shape[-1]):  # volume axes assumed (H, W, D)
                    x = self._to_tensor(img[:, :, i])
                    x = convert_image_dtype(x, torch.float32)  # no-op for float32; kept for pipeline parity
                    x = self._resize(x)
                    x = self._clip_min(x)
                    full_volume.append(x[None, ...])            # (1, H, W)
                x = torch.stack(full_volume, dim=0)             # (D, 1, H, W)
                # NOTE(review): normalization runs over the whole volume here,
                # not per slice — confirm this matches channel_wise intent.
                x = self._zscore_per_image(x)
            else:
                x = self._to_tensor(img)
                x = convert_image_dtype(x, torch.float32)
                x = self._resize(x)
                x = self._clip_min(x)
                x = self._zscore_per_image(x)
                x = x[None, ...]                                # (1, H, W)
            batch.append(x)

        pixel_values = np.stack(batch, axis=0)

        return BatchFeature(
            data={"pixel_values": pixel_values},
            tensor_type=return_tensors,
        )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the processor config, including the custom-code auto_map."""
        out = super().to_dict()
        out.update(
            dict(
                crop_size=self.crop_size,
                clip_below_air=self.clip_below_air,
                eps=self.eps,
                do_resize=self.do_resize,
                do_normalize=self.do_normalize,
            )
        )
        # Points AutoImageProcessor at this class when loading with trust_remote_code.
        out["auto_map"] = {"AutoImageProcessor": "curia_image_processor.CuriaImageProcessor"}
        return out
|
|