| import os |
| import json |
| import pydicom |
| import numpy as np |
| import torch |
|
|
| from typing import Callable, Optional, Tuple |
|
|
| from torch import Tensor |
| from torch.utils.data import Dataset |
| from sklearn.preprocessing import RobustScaler |
|
|
# Default tensor dtype for labels/targets emitted by the dataset; half
# precision to keep the per-sample footprint small for mixed-precision training.
DTYPE = torch.float16
|
|
|
|
class SyntaxDataset(Dataset):
    """Dataset of coronary angiography DICOM videos with SYNTAX-score labels.

    Each record of the metadata JSON describes one study: it carries a
    ``study_uid``, a numeric score under the key named by ``label``, and a
    list of video entries (each with a ``path``) under ``videos_<artery>``.

    Modes:
        train:      4 videos sampled with replacement, random temporal crop.
        eval:       4 videos sampled with replacement, center temporal crop.
        inference:  every available video, center crop; studies without
                    videos return the sentinel ``0`` in place of the tensor.
    """

    # Per-artery score thresholds used to stratify records into 5 weighting
    # bins: [<=t0], (t0,t1], (t1,t2], (t2,t3], (t3,inf).
    _BIN_THRESHOLDS = {0: [0, 5, 10, 15], 1: [0, 2, 5, 8]}
    # Per-artery cutoff above which the binary classification label is 1.
    _BINARY_CUTOFF = {0: 15, 1: 5}

    def __init__(
        self,
        root: str,
        meta: str,
        train: bool,
        length: int,
        label: str,
        artery: str,
        inference: bool = False,
        validation: bool = False,
        transform: Optional[Callable] = None,
    ) -> None:
        """Load and filter the metadata JSON.

        Args:
            root: Dataset root; relative ``meta`` and video paths resolve
                against it.
            meta: Metadata JSON path (absolute, or relative to ``root``).
            train: If True, use random video/temporal sampling.
            length: Number of frames per clip.
            label: Record key holding the numeric SYNTAX score.
            artery: ``"left"`` or ``"right"`` (case-insensitive).
            inference: Keep records without videos and return all videos.
            validation: Additionally drop records with score <= 0.
            transform: Optional per-clip transform applied to the
                (T, H, W, 3) uint8 tensor.

        Raises:
            ValueError: If ``artery`` is not ``"left"``/``"right"``.
        """
        self.root = root
        self.train = train
        self.length = length
        self.label = label
        self.artery = artery
        self.inference = inference
        self.transform = transform
        self.validation = validation

        # Validate the artery name before touching the filesystem.
        artery_key = artery.lower()
        if artery_key not in ("left", "right"):
            raise ValueError(f"Unknown artery '{artery}'")
        self.artery_bin = 0 if artery_key == "left" else 1

        meta_path = meta if os.path.isabs(meta) else os.path.join(root, meta)
        with open(meta_path) as f:
            dataset = json.load(f)

        # Outside inference, a record is only usable if it has videos for
        # the requested artery.
        if not self.inference:
            dataset = [rec for rec in dataset if len(rec[f"videos_{artery}"]) > 0]

        # Validation keeps only records with a positive score.
        if validation:
            dataset = [rec for rec in dataset if rec[self.label] > 0]

        self.dataset = dataset

    def __len__(self) -> int:
        """Number of study records after filtering."""
        return len(self.dataset)

    def _bin_index(self, score) -> int:
        """Map a score to its weighting bin (0-4) for the configured artery.

        Bin 0 is ``score <= t0`` (in practice score == 0), bins 1-3 are the
        half-open intervals ``(t[i-1], t[i]]``, bin 4 is everything above
        the last threshold.  Using one function for both counting and
        weighting keeps the two consistent (the original code sent
        out-of-range low scores to bin 4 in one place and to no bin in the
        other).
        """
        thresholds = self._BIN_THRESHOLDS[self.artery_bin]
        if score <= thresholds[0]:
            return 0
        for i in range(1, 4):
            if score <= thresholds[i]:
                return i
        return 4

    def get_sample_weights(self) -> Tensor:
        """Return one inverse-bin-frequency weight per record.

        Records are stratified into 5 score bins; each record's weight is
        ``total / bin_count`` so a weighted sampler draws bins uniformly.
        Side effects: populates ``dataset_0..4`` and ``weights_0..4``
        attributes and prints the per-bin counts.

        Returns:
            1-D float32 tensor of length ``len(self)``.
        """
        # Partition records by bin; kept as attributes for external inspection.
        buckets = [[] for _ in range(5)]
        for rec in self.dataset:
            buckets[self._bin_index(rec[self.label])].append(rec)
        (self.dataset_0, self.dataset_1, self.dataset_2,
         self.dataset_3, self.dataset_4) = buckets

        counts = [len(b) for b in buckets]
        total = sum(counts)

        # Inverse-frequency weight; empty bins get 0.0 instead of dividing
        # by zero (no record ever lands in an empty bin anyway).
        bin_weights = [total / c if c > 0 else 0.0 for c in counts]
        (self.weights_0, self.weights_1, self.weights_2,
         self.weights_3, self.weights_4) = bin_weights

        print("Counts: ", *counts)

        weights = [bin_weights[self._bin_index(rec[self.label])]
                   for rec in self.dataset]
        # float32, not float16: total/count overflows float16 above ~6.5e4
        # and half precision needlessly degrades the sampling ratios.
        self.weights = torch.tensor(weights, dtype=torch.float32)
        return self.weights

    def _load_clip(self, video_rec) -> Tensor:
        """Read one DICOM video and return a (length, H, W, 3) uint8 tensor.

        Steps: read pixel data, rescale uint16 to uint8, loop-pad short
        clips, temporally crop to ``self.length`` (random when training,
        centered otherwise), replicate grayscale to 3 channels, and apply
        the optional transform.

        Raises:
            ValueError: On all-zero uint16 data or an unexpected dtype.
        """
        path = video_rec["path"]
        full_path = path if os.path.isabs(path) else os.path.join(self.root, path)
        video = pydicom.dcmread(full_path).pixel_array

        # Rescale 16-bit data into the uint8 range expected downstream.
        # Explicit raises instead of asserts: asserts vanish under `-O`.
        if video.dtype == np.uint16:
            vmax = int(np.max(video))
            if vmax <= 0:
                raise ValueError(f"All-zero uint16 video: {full_path}")
            video = (video.astype(np.float32) * (255.0 / vmax)).astype(np.uint8)
        if video.dtype != np.uint8:
            raise ValueError(f"Unexpected pixel dtype {video.dtype}: {full_path}")

        # Loop-pad short clips by repetition until at least `length` frames.
        while len(video) < self.length:
            video = np.concatenate([video, video])

        # Temporal window: random start for training, centered otherwise.
        # torch.randint is kept (vs random.randrange) to preserve the
        # original RNG stream.
        t = len(video)
        if self.train:
            begin = int(torch.randint(low=0, high=t - self.length + 1, size=(1,)))
        else:
            begin = (t - self.length) // 2
        video = video[begin:begin + self.length]

        # Grayscale -> 3 channels: (T, H, W) -> (T, H, W, 3).
        clip = torch.tensor(np.stack([video, video, video], axis=-1))
        if self.transform is not None:
            clip = self.transform(clip)
        return clip

    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor, str]:
        """Return ``(videos, label, target, study_uid)`` for record ``idx``.

        ``videos`` is (num_clips, length, H, W, 3) before any transform;
        ``label`` is the binary above-cutoff indicator and ``target`` the
        log1p-compressed score, both shape-(1,) ``DTYPE`` tensors.  In
        inference mode a record with no videos returns the int sentinel
        ``0`` in place of the tensor (kept for backward compatibility).
        """
        rec = self.dataset[idx]
        suid = rec["study_uid"]

        if self.label:
            cutoff = self._BINARY_CUTOFF[self.artery_bin]
            label = torch.tensor([int(rec[self.label] > cutoff)], dtype=DTYPE)
            # log1p compresses the heavy right tail of SYNTAX scores.
            target = torch.tensor([np.log1p(rec[self.label])], dtype=DTYPE)
        else:
            label = torch.tensor([0], dtype=DTYPE)
            target = torch.tensor([0], dtype=DTYPE)

        video_recs = rec[f"videos_{self.artery}"]
        nv = len(video_recs)
        if self.inference:
            if nv == 0:
                return 0, label, target, suid
            # Inference uses every available video exactly once.
            seq = range(nv)
        else:
            # Train/eval: sample 4 videos uniformly with replacement.
            seq = torch.randint(low=0, high=nv, size=(4,)).tolist()

        videos = torch.stack([self._load_clip(video_recs[vi]) for vi in seq], dim=0)
        return videos, label, target, suid
|
|