| import os |
| import json |
| import pydicom |
| import numpy as np |
| import torch |
|
|
| from typing import Callable, Optional, Tuple |
| from torch import Tensor |
| from torch.utils.data import Dataset |
|
|
| |
| |
| DTYPE = torch.float16 |
|
|
|
|
class SyntaxDataset(Dataset):
    """
    PyTorch Dataset for training a video backbone on the SYNTAX task.

    Functionality:
    - reads metadata from a JSON file (path relative to ``root``);
    - filters records by artery (0 = left, 1 = right);
    - optionally keeps only examples with a positive SYNTAX score
      (``validation=True``);
    - computes per-sample weights over SYNTAX bins (for
      ``WeightedRandomSampler``);
    - converts a DICOM video into a (T, H, W, 3) uint8 tensor in [0, 255];
    - returns:
      video, label_bin, target_log, weight, rel_path, original_label.
    """

    # Binarization threshold per artery: a SYNTAX score strictly above this
    # value is the positive class. Key: 0 = left artery, 1 = right artery.
    _BIN_THRESHOLD = {0: 15, 1: 5}

    # Bin edges per artery used for sample weighting. Four edges define five
    # bins: == thr0, (thr0, thr1], (thr1, thr2], (thr2, thr3], > thr3.
    _WEIGHT_BIN_EDGES = {
        0: [0, 5, 10, 15],
        1: [0, 2, 5, 8],
    }

    def __init__(
        self,
        root: str,
        meta: str,
        train: bool,
        length: int,
        label: str,
        artery_bin: int,
        validation: bool = False,
        transform: Optional[Callable] = None,
    ) -> None:
        """
        Args:
            root: dataset root directory; ``meta`` and all DICOM paths in the
                metadata are resolved relative to it.
            meta: metadata JSON file path, relative to ``root``.
            train: if True, a random temporal crop of ``length`` frames is
                taken in ``__getitem__``; otherwise the full video is kept.
            length: number of frames in a training clip.
            label: key of the SYNTAX score field in each metadata record.
            artery_bin: 0 (left) or 1 (right); selects the artery subset.
            validation: if True, keep only records with a positive score.
            transform: optional callable applied to the video tensor.

        Raises:
            ValueError: if ``artery_bin`` is not 0 or 1.
        """
        super().__init__()
        self.root = root
        self.train = train
        self.length = length
        self.label = label
        self.transform = transform
        self.validation = validation

        meta_path = os.path.join(root, meta)
        with open(meta_path, "r") as f:
            dataset = json.load(f)

        # Validate explicitly instead of `assert`: asserts are stripped
        # under `python -O`, silently disabling the check.
        if artery_bin not in (0, 1):
            raise ValueError("artery_bin must be explicitly set to 0 (left) or 1 (right).")
        self.artery_bin = artery_bin
        dataset = [rec for rec in dataset if rec["artery"] == artery_bin]

        # Validation mode keeps only examples with a positive SYNTAX score.
        if validation:
            dataset = [rec for rec in dataset if rec[self.label] > 0]

        # Default weight; refined by get_sample_weights() into per-bin
        # inverse-frequency weights (that method fills self.weights, it does
        # NOT rewrite rec["weight"] — same as the original behavior).
        for rec in dataset:
            rec["weight"] = 1.0

        self.dataset = dataset

    @staticmethod
    def _bin_index(score: float, edges: list) -> int:
        """
        Map a SYNTAX score to one of 5 bins given 4 edges.

        Bin 0: score == thr0; bins 1-3: half-open intervals (thr_{i-1}, thr_i];
        bin 4: everything else (score > thr3 — and, as in the original
        chained comparisons, any score below thr0).
        """
        thr0, thr1, thr2, thr3 = edges
        if score == thr0:
            return 0
        if thr0 < score <= thr1:
            return 1
        if thr1 < score <= thr2:
            return 2
        if thr2 < score <= thr3:
            return 3
        return 4

    def get_sample_weights(self) -> Tensor:
        """
        Compute per-sample weights from SYNTAX score bins.

        Each artery has its own set of bin edges; every example gets a weight
        inversely proportional to its bin's frequency (total / bin_count,
        or 0.0 for an empty bin so it is never sampled).

        Side effects (kept for backward compatibility): stores the per-bin
        record lists as ``self.dataset_0`` … ``self.dataset_4``, the per-bin
        weights as ``self.weights_0`` … ``self.weights_4``, the full weight
        vector as ``self.weights``, and prints weights and counts.

        Returns:
            Tensor of shape ``(len(self),)`` with one weight per example,
            suitable for ``torch.utils.data.WeightedRandomSampler``.
        """
        edges = self._WEIGHT_BIN_EDGES[self.artery_bin]

        # Partition the dataset into the 5 bins in a single pass instead of
        # five separate comprehensions with duplicated threshold logic.
        bins = [[] for _ in range(5)]
        for rec in self.dataset:
            bins[self._bin_index(rec[self.label], edges)].append(rec)

        (self.dataset_0, self.dataset_1, self.dataset_2,
         self.dataset_3, self.dataset_4) = bins

        # Every record lands in exactly one bin, so the bin sizes sum to the
        # dataset size.
        total = len(self.dataset)

        def safe_weight(count: int) -> float:
            # Empty bins get weight 0.0 so they are never sampled.
            return total / count if count > 0 else 0.0

        bin_weights = [safe_weight(len(b)) for b in bins]
        (self.weights_0, self.weights_1, self.weights_2,
         self.weights_3, self.weights_4) = bin_weights

        print("Weights: ", *bin_weights)
        print("Counts: ", *(len(b) for b in bins))

        # Reuse the same binning for the per-sample weight vector — the
        # original duplicated the threshold comparisons here.
        weights = [
            bin_weights[self._bin_index(rec[self.label], edges)]
            for rec in self.dataset
        ]

        self.weights = torch.tensor(weights, dtype=DTYPE)
        return self.weights

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor, Tensor, str, Tensor]:
        """
        Return one example:
        - video: Tensor (T, H, W, 3); after ``transform`` typically (C, T, H, W)
        - label: binary target thresholded per artery
        - target: log-transformed SYNTAX score (regression target)
        - weight: sample weight (for the sampler / loss)
        - path: relative path to the DICOM file
        - original_label: raw SYNTAX score
        """
        rec = self.dataset[idx]
        path = rec["path"]
        weight = rec["weight"]

        full_path = os.path.join(self.root, path)
        video = pydicom.dcmread(full_path).pixel_array

        # Rescale 16-bit videos into the uint8 range [0, 255]; explicit
        # raises replace `assert`s that vanish under `python -O`.
        if video.dtype == np.uint16:
            vmax = np.max(video)
            if vmax <= 0:
                raise ValueError(f"All-zero uint16 video: {path}")
            video = (video.astype(np.float32) * (255.0 / vmax)).astype(np.uint8)
        if video.dtype != np.uint8:
            raise ValueError(f"Unsupported pixel dtype {video.dtype} in {path}")

        syntax_value = rec[self.label]
        label = torch.tensor(
            [int(syntax_value > self._BIN_THRESHOLD[self.artery_bin])],
            dtype=DTYPE,
        )
        # log(1 + score) compresses the heavy right tail of SYNTAX scores.
        target = torch.tensor([np.log(1.0 + syntax_value)], dtype=DTYPE)
        original_label = torch.tensor([syntax_value], dtype=DTYPE)

        # Repeat the clip until it has at least `length` frames.
        while len(video) < self.length:
            video = np.concatenate([video, video])
        t = len(video)

        if self.train:
            # Random temporal crop of `length` consecutive frames.
            # Convert the 1-element tensor to int explicitly instead of
            # relying on Tensor.__index__ inside the slice.
            begin = int(torch.randint(low=0, high=t - self.length + 1, size=(1,)))
            video = video[begin:begin + self.length, :, :]
        # At evaluation time the full (possibly repeated) video is kept
        # uncropped — the caller must handle variable clip lengths.

        # Replicate the grayscale channel three times -> (T, H, W, 3).
        video = torch.tensor(np.stack([video, video, video], axis=-1))

        if self.transform is not None:
            video = self.transform(video)

        sample_weight = torch.tensor([weight], dtype=DTYPE)

        return video, label, target, sample_weight, path, original_label
|
|