| import math |
| import os |
| from typing import Any, Callable, Optional, Tuple |
| from monai import data, transforms as med |
| from monai.data import load_decathlon_datalist |
| import PIL.Image as PImage |
| from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
| from torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS |
| from torchvision.transforms import transforms |
| from torch.utils.data import Dataset |
| import torch |
| import numpy as np |
| import cv2 |
| try: |
| from torchvision.transforms import InterpolationMode |
| interpolation = InterpolationMode.BICUBIC |
| except: |
| import PIL |
| interpolation = PIL.Image.BICUBIC |
| from monai.transforms.transform import LazyTransform, MapTransform, RandomizableTransform |
| import random |
|
|
|
|
def pil_loader(path):
    """Read the image file at ``path`` and return it as an RGB ``PIL.Image``.

    The file is opened in binary mode inside a ``with`` block. The RGB
    conversion happens inside that block, so the pixels are decoded while the
    handle is still open and the handle is closed before returning.
    """
    with open(path, 'rb') as handle:
        rgb_image: PImage.Image = PImage.open(handle).convert('RGB')
    return rgb_image
|
|
|
|
class ImageNetDataset(DatasetFolder):
    """Label-free view of an ImageFolder-style ``train`` / ``val`` split.

    Files are discovered by ``DatasetFolder`` under ``<imagenet_folder>/train``
    or ``<imagenet_folder>/val``. The class labels are then discarded:
    ``samples`` keeps only file paths, ``targets`` is ``None``, and an item is
    just the transformed image.

    :param imagenet_folder: root directory that contains the split folders
    :param train: pick the ``train`` split when true, ``val`` otherwise
    :param transform: callable applied to every loaded image
    :param is_valid_file: optional file filter; when given, the extension filter is not used
    """

    def __init__(
        self,
        imagenet_folder: str,
        train: bool,
        transform: Callable,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        split_root = os.path.join(imagenet_folder, 'train' if train else 'val')
        # Only one of ``extensions`` / ``is_valid_file`` is handed to DatasetFolder.
        if is_valid_file is None:
            allowed_extensions = IMG_EXTENSIONS
        else:
            allowed_extensions = None
        super().__init__(
            split_root,
            loader=pil_loader,
            extensions=allowed_extensions,
            transform=transform,
            target_transform=None,
            is_valid_file=is_valid_file,
        )

        # Keep just the paths: each DatasetFolder sample is a (path, class_index) pair.
        self.samples = tuple(path for path, _class_index in self.samples)
        self.targets = None

    def __getitem__(self, index: int) -> Any:
        image = self.loader(self.samples[index])
        return self.transform(image)
|
|
|
|
def build_dataset_to_pretrain(dataset_path, input_size) -> Dataset:
    """
    Build the ImageNet-style dataset used for pretraining.

    You may need to modify this function to return your own dataset.
    Define a new class, a subclass of `Dataset`, to replace our ImageNetDataset.
    Use dataset_path to build your image file path list.
    Use input_size to create the transformation function for your images; you can
    refer to `trans_train` below.

    :param dataset_path: the folder of the dataset. It may be the root that contains
        the ``train``/``val`` sub-folders, or one of those sub-folders itself.
    :param input_size: the input size (image resolution), passed to RandomResizedCrop
    :return: the dataset used for pretraining (always the ``train`` split)
    """
    trans_train = transforms.Compose([
        transforms.RandomResizedCrop(input_size, scale=(0.67, 1.0), interpolation=interpolation),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ])

    # abspath also strips any trailing separator, so basename/dirname below are reliable.
    dataset_path = os.path.abspath(dataset_path)
    # If the caller pointed at ".../train" or ".../val", step up to the split root.
    # Compare the whole last path component: a plain suffix test would wrongly cut
    # a root such as ".../myval" down to ".../my".
    if os.path.basename(dataset_path) in ('train', 'val'):
        dataset_path = os.path.dirname(dataset_path)

    dataset_train = ImageNetDataset(imagenet_folder=dataset_path, transform=trans_train, train=True)
    print_transform(trans_train, '[pre-train]')
    return dataset_train
|
|
|
|
def build_meddataset_to_pretrain(dataset_path, input_size) -> Dataset:
    """
    Build the flat-folder image dataset (``MedicalDataSets``) used for pretraining.

    Every file directly inside ``dataset_path`` is one sample. Each image gets a
    random resized crop, a random horizontal flip, tensor conversion and
    ImageNet mean/std normalisation. Adapt this function (or ``MedicalDataSets``)
    if your data is organised differently.

    :param dataset_path: folder holding the image files
    :param input_size: the input size (image resolution), passed to RandomResizedCrop
    :return: the dataset used for pretraining
    """
    augmentation_steps = [
        transforms.RandomResizedCrop(input_size, scale=(0.67, 1.0), interpolation=interpolation),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ]
    trans_train = transforms.Compose(augmentation_steps)

    dataset_train = MedicalDataSets(base_dir=os.path.abspath(dataset_path), transform=trans_train)
    print_transform(trans_train, '[pre-train]')
    return dataset_train
|
|
|
|
|
|
class MedicalDataSets(Dataset):
    """Unlabelled dataset over the image files that sit directly inside one folder.

    Every regular, non-hidden file in ``base_dir`` is one sample. Sub-directories
    and dot-files (e.g. ``.DS_Store``) are skipped because they cannot be opened
    as images. File names are sorted because ``os.listdir`` returns them in
    arbitrary order. Sorting makes the index -> file mapping reproducible across
    runs and machines, which a distributed sampler relies on when sharding indices.

    :param base_dir: folder containing the images (``None`` means the current directory)
    :param transform: callable applied to each RGB ``PIL.Image``; if ``None`` the
        image itself is returned
    """

    def __init__(
        self,
        base_dir=None,
        transform=None,
    ):
        self._base_dir = base_dir
        # Resolve once so a ``None`` base_dir behaves consistently in listdir and join.
        self._root = os.curdir if base_dir is None else base_dir
        self.sample_list = sorted(
            name
            for name in os.listdir(self._root)
            if not name.startswith('.') and os.path.isfile(os.path.join(self._root, name))
        )
        self.transform = transform
        print("total {}".format(len(self.sample_list)))

    def __len__(self):
        return len(self.sample_list)

    def __getitem__(self, idx):
        path = os.path.join(self._root, self.sample_list[idx])
        # The ``with`` block guarantees the file handle Pillow opened is closed;
        # convert() decodes the pixels while it is still open.
        with PImage.open(path) as raw:
            img = raw.convert('RGB')
        if self.transform is None:
            return img
        return self.transform(img)
|
|
def print_transform(transform, s):
    """Print a heading tagged with ``s``, then each step of ``transform.transforms``
    on its own line, then a dashed separator followed by a blank line."""
    heading = f'Transform {s} = '
    print(heading)
    steps = transform.transforms
    for step in steps:
        print(step)
    print('---------------------------\n')
|
|
|
|
class Sampler(torch.utils.data.Sampler):
    """Shard a dataset's indices across distributed replicas.

    Replica ``rank`` receives every ``num_replicas``-th index starting at
    ``rank``. All randomness (the shuffle and the ``make_even`` padding) comes
    from one generator seeded with ``epoch``. Every replica therefore rebuilds
    the identical global index list before taking its own slice.

    :param dataset: any object supporting ``len()``
    :param num_replicas: number of replicas; defaults to the distributed world size
    :param rank: this replica's rank; defaults to the distributed rank
    :param shuffle: permute indices each epoch (call :meth:`set_epoch` to vary it)
    :param make_even: pad the index list to ``num_samples * num_replicas`` so every
        replica yields the same number of indices
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, make_even=True):
        if num_replicas is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = torch.distributed.get_world_size()
        if rank is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = torch.distributed.get_rank()
        self.shuffle = shuffle
        self.make_even = make_even
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        indices = list(range(len(self.dataset)))
        # Number of real (un-padded) indices this rank owns.
        self.valid_length = len(indices[self.rank : self.total_size : self.num_replicas])

    def __iter__(self):
        # A single epoch-seeded generator drives both the shuffle and the padding
        # draw. Padding used to come from the unseeded global NumPy RNG, so each
        # replica built a different global list and samples could be duplicated
        # or dropped across ranks.
        g = torch.Generator()
        g.manual_seed(self.epoch)
        if self.shuffle:
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))
        if self.make_even:
            missing = self.total_size - len(indices)
            if missing > 0:
                if missing < len(indices):
                    # Enough real indices: wrap around to the start of the list.
                    indices += indices[:missing]
                else:
                    # Tiny dataset relative to replicas: sample padding with replacement.
                    extra_ids = torch.randint(
                        low=0, high=len(indices), size=(missing,), generator=g
                    ).tolist()
                    indices += [indices[ids] for ids in extra_ids]
            assert len(indices) == self.total_size
        indices = indices[self.rank : self.total_size : self.num_replicas]
        self.num_samples = len(indices)
        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        """Set the seed used by the next ``__iter__`` (call once per epoch)."""
        self.epoch = epoch
|
|
class RandScaleCropdPlusScaleByMidDimSampled(MapTransform):
    """Produce ``num_samples`` random scale-crops of a dict sample, each rescaled
    so its middle spatial dimension equals ``max_size``.

    For every output sample a crop scale ``r`` is drawn uniformly from
    ``[min_radio, max_radio]`` (rounded to two decimals). Each key is cropped
    with ``RandScaleCropd(roi_scale=r)`` and then resized with ``Resized`` so that
    its second-smallest spatial dimension becomes ``max_size``. The other
    dimensions are scaled by the same factor and truncated to ``int``.

    Assumes channel-first arrays (``shape[0]`` is channels) with at least two
    spatial dims; TODO confirm against callers.

    :param keys: key or keys to transform; a single string is accepted
    :param mode: interpolation mode passed to ``Resized``
    :param max_size: target length of the middle spatial dimension
    :param allow_missing_keys: skip keys absent from the input instead of raising
    :param num_samples: number of crops returned per call
    :param max_radio: upper bound of the crop scale
    :param min_radio: lower bound of the crop scale
    :return: list of ``num_samples`` shallow copies of the input dict
    """

    def __init__(self, keys, mode='area', max_size=128, allow_missing_keys=False, num_samples=4, max_radio=0.8, min_radio=0.5):
        # MapTransform normalises ``keys`` to a tuple (so "image" is not iterated
        # character by character) and stores ``allow_missing_keys``.
        super().__init__(keys, allow_missing_keys)
        self.mode = mode
        self.max_size = max_size
        self.num_samples = num_samples
        self.max_radio = max_radio
        self.min_radio = min_radio

    def __call__(self, data):
        outputs = []
        for _ in range(self.num_samples):
            roi_scale = round(random.uniform(self.min_radio, self.max_radio), 2)
            sample = dict(data)
            # key_iterator skips absent keys when allow_missing_keys is set, else raises.
            for key in self.key_iterator(sample):
                cropper = med.RandScaleCropd(keys=[key], roi_scale=roi_scale)
                sample[key] = cropper(sample)[key]
                spatial_shape = sample[key].shape[1:]
                # sorted(...)[1] is the second-smallest (middle, for 3-D) spatial size.
                scale_factor = self.max_size / sorted(spatial_shape)[1]
                new_size = [int(d * scale_factor) for d in spatial_shape]
                resizer = med.Resized(
                    keys=[key],
                    spatial_size=new_size,
                    mode=self.mode,
                    allow_missing_keys=self.allow_missing_keys,
                )
                sample[key] = resizer(sample)[key]
            outputs.append(sample)
        return outputs
|
|
|
|
|
|
|
|
def get_loader(data_dir, size, cache_dir="/fenghetang/3d/pretrain/MM/cache_dataset", num_samples=4):
    """Build the 3-D pretraining dataset described by ``<data_dir>/dataset.json``.

    The ``"training"`` list of the decathlon-style JSON is loaded with paths
    resolved against ``data_dir``. Each image then passes through:

    - load, add a channel axis, reorient to RAS, resample to 1.5 mm isotropic;
    - clip intensities from [-175, 250] to [0, 1], crop to foreground;
    - pad to ``size``^3 and take ``num_samples`` random ``size``^3 crops;
    - random flips along each axis, then conversion to tensors.

    :param data_dir: directory containing ``dataset.json`` and the images it lists
    :param size: edge length of the cubic crops
    :param cache_dir: on-disk cache for the deterministic prefix of the pipeline
        (the default keeps the original hard-coded location)
    :param num_samples: crops drawn per volume by ``RandCropByPosNegLabeld``
    :return: a ``CacheNTransDataset``; with ``num_samples`` crops per item, MONAI
        presumably returns a list of crop dicts per index (verify with the collate fn)
    """
    datalist_json = os.path.join(data_dir, "dataset.json")
    train_transform = med.Compose(
        [
            med.LoadImaged(keys=["image"], allow_missing_keys=True),
            med.AddChanneld(keys=["image"], allow_missing_keys=True),
            med.Orientationd(keys=["image"], axcodes="RAS", allow_missing_keys=True),
            med.Spacingd(keys=["image"], pixdim=(1.5, 1.5, 1.5), mode="bilinear", allow_missing_keys=True),
            med.ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
            med.CropForegroundd(keys=["image"], source_key="image", allow_missing_keys=True),
            med.SpatialPadd(keys=["image"], spatial_size=(size, size, size), mode='constant'),
            # The image doubles as its own "label", presumably so crop centres are
            # drawn only from non-zero (foreground) voxels (pos=1, neg=0).
            med.RandCropByPosNegLabeld(
                spatial_size=(size, size, size),
                keys=["image"],
                label_key="image",
                pos=1,
                neg=0,
                num_samples=num_samples,
            ),
            med.RandFlipd(keys=["image"], prob=0.2, spatial_axis=0),
            med.RandFlipd(keys=["image"], prob=0.2, spatial_axis=1),
            med.RandFlipd(keys=["image"], prob=0.1, spatial_axis=2),
            med.ToTensord(keys=["image"]),
        ]
    )

    datalist = load_decathlon_datalist(datalist_json, True, "training", base_dir=data_dir)

    # cache_n_trans=6 caches LoadImaged..CropForegroundd, the first six transforms
    # above (all non-random); the random crop/flip steps run fresh every epoch.
    train_ds = data.CacheNTransDataset(
        data=datalist,
        transform=train_transform,
        cache_n_trans=6,
        cache_dir=cache_dir,
    )
    return train_ds
|
|
|
|
|
|
|
|