| |
| import random |
|
|
| import cv2 |
| import numpy as np |
| import torch |
| from PIL import Image |
| from torchvision import transforms |
| from transformers import BatchEncoding, PreTrainedTokenizer |
|
|
"""
Mixins for all modalities. Each mixin provides:
- a preprocess function that takes a path or raw data and returns a tensor
- a construct_input function that takes a tensor and returns a dict with a
  batch dimension for model input
- a key string for the model input dict
"""
|
|
|
|
class ECHO_Mixin:
    """Preprocessing helpers that turn echocardiogram frames into tensors.

    Class attributes:
        LOWER_YELLOW / UPPER_YELLOW: HSV bounds handed to ``cv2.inRange``;
            pixels inside this range are blacked out before rescaling.
        IMAGE_SIZE: target size passed to ``transforms.Resize``.
        NORM_MEAN / NORM_STD: per-channel statistics for
            ``transforms.Normalize``.
        ECHO_TRANSFORMS: ToTensor -> Resize -> Normalize pipeline applied to
            a PIL image.
        ECHO_KEY: key string for this modality in the model input dict.
    """

    LOWER_YELLOW: list[int] = [20, 50, 50]
    UPPER_YELLOW: list[int] = [100, 255, 255]
    IMAGE_SIZE: tuple[int, int] = (224, 224)
    NORM_MEAN: tuple[float, float, float] = (0.48145466, 0.4578275, 0.40821073)
    NORM_STD: tuple[float, float, float] = (0.26862954, 0.26130258, 0.27577711)

    ECHO_TRANSFORMS = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize(IMAGE_SIZE),
            transforms.Normalize(
                mean=NORM_MEAN,
                std=NORM_STD,
            ),
        ]
    )
    ECHO_KEY: str = "echo"

    def grabimage(self, split: str, data: dict[str, np.ndarray]) -> np.ndarray:
        """Pick one frame from ``data`` and clean it with ``extract_echoframe``.

        A video is always drawn at random from ``data``. When ``split`` is
        ``"train"`` the frame index is drawn at random as well; for every
        other split frame 0 is used.

        Args:
            split: dataset split name; only ``"train"`` enables random frame
                selection.
            data: mapping from a case identifier to a video array whose first
                axis indexes frames.

        Returns:
            The selected frame as returned by ``extract_echoframe``.

        Raises:
            IndexError: if ``data`` is empty, or (train split) the chosen
                video has zero frames.
        """
        caseofinterest = random.choice(list(data))
        video = data[caseofinterest]
        if split == "train":
            imageindice = random.choice(range(video.shape[0]))
        else:
            imageindice = 0
        return self.extract_echoframe(imageindice, video)

    def extract_echoframe(self, imageindice: int, video: np.ndarray) -> np.ndarray:
        """Return frame ``imageindice`` of ``video``, masked and rescaled to 0-255.

        The frame is converted with ``cv2.COLOR_BGR2HSV``, every pixel whose
        HSV value lies within ``LOWER_YELLOW``..``UPPER_YELLOW`` is set to
        black, and the result is min-max rescaled to the 0-255 range.

        Args:
            imageindice: index of the frame along the first axis of ``video``.
            video: array of frames; each frame is assumed to be a 3-channel
                BGR image, as the colour conversion requires.

        Returns:
            A new ``uint8`` array; ``video`` itself is left untouched.
        """
        # Indexing yields a view, so copy first: the mask assignment below
        # must not write into the caller's video (and would fail outright on
        # a read-only array).
        image = np.array(video[imageindice])

        hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        mask = cv2.inRange(
            hsv_image,
            np.array(self.LOWER_YELLOW),
            np.array(self.UPPER_YELLOW),
        )
        image[mask > 0] = [0, 0, 0]

        image = image.astype(np.float32)
        image -= image.min()
        peak = image.max()
        # A constant frame has peak == 0 after the shift; dividing would fill
        # the frame with NaN, so leave it as all zeros instead.
        if peak > 0:
            image /= peak
            image *= 255
        return image.astype(np.uint8)

    def preprocess_echoseries(
        self, video_dict: dict[str, np.ndarray], split: str = "valid"
    ) -> torch.Tensor:
        """Select one frame from ``video_dict`` and run ``ECHO_TRANSFORMS`` on it.

        Defaults to inference behaviour (``split="valid"``), in which frame 0
        of a randomly chosen video is used.

        Args:
            video_dict: mapping from a case identifier to a video array.
            split: dataset split name forwarded to ``grabimage``.

        Returns:
            The transformed frame as a tensor.

        Raises:
            TypeError: if the selected frame is not a numpy array.
        """
        image = self.grabimage(split, video_dict)
        if not isinstance(image, np.ndarray):
            raise TypeError("Expected image to be a numpy ndarray")
        pil_image = Image.fromarray(image)
        transformed = self.ECHO_TRANSFORMS(pil_image)
        # Fallback kept from the original: if the pipeline did not yield a
        # tensor, fall back to a plain ToTensor conversion.
        if not isinstance(transformed, torch.Tensor):
            transformed = transforms.ToTensor()(pil_image)
        return transformed

    def preprocess_single_echo(self, avi_path: str) -> torch.Tensor:
        """Open an AVI file and preprocess its first frame (inference mode).

        Args:
            avi_path: path of the video file to read.

        Returns:
            image: torch.Tensor of shape (C, H, W).

        Raises:
            ValueError: if no frame could be read from ``avi_path``.
        """
        cap = cv2.VideoCapture(avi_path)
        try:
            success, frame = cap.read()
        finally:
            # Release the capture handle even if reading raises.
            cap.release()
        if not success or frame is None:
            raise ValueError(f"Could not read frame from AVI file: {avi_path}")
        image = self.extract_echoframe(0, np.array([frame]))
        image = self.ECHO_TRANSFORMS(Image.fromarray(image))
        if not isinstance(image, torch.Tensor):
            image = torch.from_numpy(image)
        return image
|
|
|
|
| |
class CXR_Mixin:
    """Preprocessing helpers that turn chest X-ray images into tensors.

    Class attributes:
        RESIZE: size passed to ``transforms.Resize``.
        IMAGE_SIZE: size passed to ``transforms.CenterCrop``.
        NORM_MEAN / NORM_STD: single-value statistics for
            ``transforms.Normalize``.
        VISION_KEY: key string for this modality in the model input dict.
        CXR_TRANSFORMS: ToTensor -> Resize -> CenterCrop -> Normalize
            pipeline applied to a PIL image.
    """

    RESIZE: tuple[int, int] = (256, 256)
    IMAGE_SIZE: tuple[int, int] = (224, 224)
    NORM_MEAN: list[float] = [0.5862785803043838]
    NORM_STD: list[float] = [0.27950088968644304]
    VISION_KEY: str = "vision"
    CXR_TRANSFORMS = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize(RESIZE),
            transforms.CenterCrop(IMAGE_SIZE),
            transforms.Normalize(
                mean=NORM_MEAN,
                std=NORM_STD,
            ),
        ]
    )

    @staticmethod
    def remove_border(pixel_array: np.ndarray) -> np.ndarray:
        """Crop a 2-D image to the bounding box of its pixels greater than zero.

        Args:
            pixel_array: 2-D image array.

        Returns:
            A slice of ``pixel_array`` covering every row and column that
            holds at least one value greater than zero.

        Raises:
            ValueError: if ``pixel_array`` is not 2-D or holds no value
                greater than zero.
        """
        rows, cols = np.where(pixel_array > 0)
        if rows.size == 0:
            raise ValueError(
                "Cannot remove border: image has no pixels greater than zero"
            )
        # Slice stops are exclusive, so add 1 to keep the last non-zero
        # row/column (otherwise a single bright pixel yields an empty crop).
        return pixel_array[rows.min() : rows.max() + 1, cols.min() : cols.max() + 1]

    def preprocess_loaded_cxr(self, img: np.ndarray) -> torch.Tensor:
        """Crop, replicate to three channels and run ``CXR_TRANSFORMS``.

        Args:
            img: 2-D image array; presumably ``uint8`` so that
                ``Image.fromarray`` accepts the 3-channel result — TODO confirm
                against callers.

        Returns:
            The transformed image as a tensor.
        """
        cxr = self.remove_border(img)
        # Stack the single channel three times along a new trailing axis.
        cxr = np.repeat(cxr[..., np.newaxis], 3, axis=-1)
        cxr = Image.fromarray(cxr)
        transformed = self.CXR_TRANSFORMS(cxr)
        # Fallback kept from the original: if the pipeline did not yield a
        # tensor, fall back to a plain ToTensor conversion.
        if not isinstance(transformed, torch.Tensor):
            transformed = transforms.ToTensor()(cxr)
        return transformed

    def preprocess_single_cxr(self, image_path: str) -> torch.Tensor:
        """Load an image file and preprocess it (inference mode).

        The file is decoded, converted to RGB, and only channel 0 is kept
        before being handed to ``preprocess_loaded_cxr``.

        Args:
            image_path: path of the image file to read.

        Returns:
            The transformed image as a tensor.
        """
        with open(image_path, "rb") as fopen:
            image = Image.open(fopen).convert("RGB")
        image = np.array(image)[:, :, 0]
        return self.preprocess_loaded_cxr(image)
|
|
|
|
class ECG_Mixin:
    """Preprocessing helpers that turn stored ECG arrays into tensors.

    Class attributes:
        LENGTH: number of samples kept per channel by ``check_ecg``.
        FREQUENCY: presumably the sampling frequency in Hz — TODO confirm;
            not used by the methods in this class.
        CHANNELS: number of channels expected on the first axis.
        NORM_MEAN / NORM_SCALE: constants used by ``manual_standardize``.
        NORM_VAR: equals ``NORM_SCALE`` squared; not used by the methods in
            this class.
        ECG_KEY: key string for this modality in the model input dict.
    """

    LENGTH: int = 1000
    FREQUENCY: int = 100
    CHANNELS: int = 12
    NORM_MEAN: float = 0.02547506
    NORM_SCALE: float = 0.16486814
    NORM_VAR: float = 0.0271815
    ECG_KEY: str = "ecg"

    def manual_standardize(self, x: np.ndarray) -> torch.Tensor:
        """
        Apply manual standardization to ECG or other data.
        Equivalent to sklearn's StandardScaler with given constants.

        Args:
            x (np.ndarray): Input array of shape (12, 1000)
        Returns:
            torch.Tensor: Scaled float32 tensor of the same shape
        """
        return torch.from_numpy((x - self.NORM_MEAN) / self.NORM_SCALE).float()

    def check_ecg(self, ecg: np.ndarray) -> np.ndarray:
        """Reject non-finite recordings and truncate to ``LENGTH`` samples.

        Args:
            ecg: array indexed as (channel, sample).

        Returns:
            The first ``LENGTH`` samples of every channel. A shorter
            recording is returned at its own length.

        Raises:
            ValueError: if any value in ``ecg`` is NaN or infinite.
        """
        if np.isnan(ecg).any() or np.isinf(ecg).any():
            raise ValueError("ECG contains NaN or Inf values")
        return ecg[:, : self.LENGTH]

    def preprocess_single_ecg(self, ecg_path: str) -> torch.Tensor:
        """Load an ECG from ``ecg_path`` and standardize it (inference mode).

        Args:
            ecg_path: path handed to ``np.load``.

        Returns:
            Standardized float32 tensor of shape (CHANNELS, <= LENGTH).

        Raises:
            ValueError: if the array is not 2-D, does not have ``CHANNELS``
                rows, or contains NaN/Inf values.
        """
        ecg = np.load(ecg_path)
        # Reject every rank other than 2 up front: otherwise a 3-D array
        # skips the channel check and a 1-D array fails later with an opaque
        # IndexError.
        if ecg.ndim != 2:
            raise ValueError(
                f"Expected 2-D ECG array of shape ({self.CHANNELS}, samples), "
                f"got shape {ecg.shape}"
            )
        if ecg.shape[0] != self.CHANNELS:
            raise ValueError(f"Expected ECG with {self.CHANNELS} channels, got {ecg.shape[0]}")

        ecg = self.check_ecg(ecg)
        return self.manual_standardize(ecg)
|
|
|
|
class Text_Mixin:
    """Helpers that build and tokenize the text prompt fed to the model.

    Class attributes:
        MODALITY_LIST: maps each accepted modality key to a longer name; only
            its keys are consulted by the methods in this class.
        MAX_LENGTH: token length used for padding and truncation.
        TEXT_LENGTH: not used by the methods in this class.
    """

    MODALITY_LIST: dict[str, str] = {"echo": "echocardiogram", "ecg": "ecg", "vision": "cxr"}
    MAX_LENGTH: int = 120
    TEXT_LENGTH: int = 100

    def get_first_n_words(self, text: str, n: int = 100) -> str:
        """Return the first ``n`` whitespace-separated words of ``text``.

        Words are re-joined with single spaces. Note carried over from the
        original docstring: 97.5 percentile of text is less than 35 words.
        """
        leading_words = text.split()[:n]
        return " ".join(leading_words)

    def createCaption(self, caption: str, modality: str = "") -> str:
        """Wrap ``caption`` in the prompt template for ``modality``.

        Args:
            caption: free text to embed in the prompt.
            modality: one of the ``MODALITY_LIST`` keys, or ``""``.

        Returns:
            The prompt string with ``caption`` and ``modality`` inserted.

        Raises:
            AssertionError: if ``modality`` is neither a known key nor empty.
        """
        allowed_keys = set(self.MODALITY_LIST.keys())
        is_known = modality in allowed_keys
        assert is_known or modality == "", (
            f"modality should be in {self.MODALITY_LIST} or empty"
        )
        return f"text : {caption}, {modality} looks like : "

    def createTokenizedCaption(self, caption: str, tokenizer: PreTrainedTokenizer) -> BatchEncoding:
        """Tokenize ``caption``, padded and truncated to ``MAX_LENGTH`` tokens.

        Args:
            caption: text to tokenize.
            tokenizer: callable tokenizer; invoked with PyTorch tensor output.

        Returns:
            Whatever ``tokenizer`` returns for the given options.
        """
        options = {
            "padding": "max_length",
            "truncation": True,
            "max_length": self.MAX_LENGTH,
            "return_tensors": "pt",
        }
        return tokenizer(caption, **options)

    def construct_caption(
        self, caption: str, tokenizer: PreTrainedTokenizer, modality: str = ""
    ) -> BatchEncoding:
        """given caption string, return tokenized caption dict for model input
        Output: dict with keys 'input_ids' and 'attention_mask', each of shape (1, L)
        """
        prompt = self.createCaption(caption, modality)
        return self.createTokenizedCaption(prompt, tokenizer)
|
|