| import json |
| import os |
| from sklearn.model_selection import train_test_split |
|
|
| from monai.data import DataLoader, Dataset |
| from monai import transforms |
|
|
def datafold_read(datalist, basedir, fold=0, key="training"):
    """Read a JSON datalist and split its entries by fold.

    The JSON document must contain, under ``key``, a list of dict entries.
    Every non-empty string value, and every non-empty string found inside a
    list value, is prefixed with ``basedir`` using ``os.path.join``. Empty
    strings and non-string values are left untouched, both at the top level
    of an entry and inside list values.

    Args:
        datalist: Path of the JSON file to read.
        basedir: Directory prepended to every non-empty string value.
        fold: Entries whose ``"fold"`` field equals this value are returned
            as the validation list.
        key: Top-level key of the JSON document holding the entry list.

    Returns:
        A ``(tr, val)`` tuple of lists. ``val`` holds the entries whose
        ``"fold"`` equals ``fold``; ``tr`` holds every other entry,
        including entries that have no ``"fold"`` field at all.

    Raises:
        OSError: If ``datalist`` cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If ``key`` is missing from the JSON document.
    """
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(datalist, encoding="utf-8") as f:
        json_data = json.load(f)

    json_data = json_data[key]

    def _resolve(value):
        # Only non-empty strings are prefixed; an empty string stays empty
        # rather than turning into "<basedir>/", and non-string values
        # (numbers, nested objects, null) pass through unchanged.
        if isinstance(value, str) and value:
            return os.path.join(basedir, value)
        return value

    for d in json_data:
        # Only existing keys are reassigned, so the dict size never changes
        # while it is being iterated.
        for k, v in d.items():
            if isinstance(v, list):
                d[k] = [_resolve(iv) for iv in v]
            else:
                d[k] = _resolve(v)

    tr = []
    val = []
    for d in json_data:
        if "fold" in d and d["fold"] == fold:
            val.append(d)
        else:
            tr.append(d)

    return tr, val
|
|
|
|
def split_train_test(datalist, basedir, fold, test_size=0.2, volume: float = None):
    """Split the entries of a JSON datalist into train, validation and test lists.

    The entries are read with ``datafold_read``; only its first return value
    is used, so entries whose ``"fold"`` equals ``fold`` are dropped here.
    All splits use ``train_test_split`` with ``random_state=42``.

    Args:
        datalist: Path of the JSON datalist file.
        basedir: Base directory forwarded to ``datafold_read``.
        fold: Fold index forwarded to ``datafold_read``.
        test_size: ``test_size`` used for both splits: first to carve the
            held-out part out of the training entries, then to carve the
            test list out of that held-out part.
        volume: When not ``None``, passed as ``test_size`` to a preliminary
            ``train_test_split`` whose second part is discarded before the
            other splits are made.

    Returns:
        A ``(train_files, validation_files, test_files)`` tuple of lists.
    """
    # NOTE(review): the validation list returned for `fold` is ignored, so
    # those entries end up in none of the three outputs - confirm intended.
    train_files, _ = datafold_read(datalist=datalist, basedir=basedir, fold=fold)

    if volume is not None:
        # NOTE(review): `volume` is the share that gets *discarded* (it is the
        # test_size of this split); the name suggests the share to keep -
        # confirm against callers.
        train_files, _ = train_test_split(
            train_files, test_size=volume, random_state=42
        )

    train_files, validation_files = train_test_split(
        train_files, test_size=test_size, random_state=42
    )

    # The held-out part is split again with the same ratio to obtain the test list.
    validation_files, test_files = train_test_split(
        validation_files, test_size=test_size, random_state=42
    )
    return train_files, validation_files, test_files
|
|
|
|
def get_loader(batch_size, data_dir, json_list, fold, roi, volume: float = None, test_size=0.2):
    """Build the training, validation and test ``DataLoader`` objects.

    The file lists come from ``split_train_test``. The training pipeline
    loads ``"image"`` and ``"label"``, applies
    ``ConvertToMultiChannelBasedOnBratsClassesd`` to the label, crops the
    foreground, takes a fixed-size random spatial crop of ``roi``, flips each
    of the axes 0, 1 and 2 with probability 0.5, normalizes the image
    intensity and finally applies a random intensity scale and shift. The
    validation and test pipelines only load, convert the label and normalize.

    Args:
        batch_size: Batch size of the training loader (validation and test
            loaders always use a batch size of 1).
        data_dir: Base directory forwarded to ``split_train_test``.
        json_list: Path of the JSON datalist forwarded to ``split_train_test``.
        fold: Fold index forwarded to ``split_train_test``.
        roi: Indexable whose first three items give the crop size.
        volume: Forwarded unchanged to ``split_train_test``.
        test_size: Forwarded unchanged to ``split_train_test``.

    Returns:
        A ``(train_loader, val_loader, test_loader)`` tuple.
    """
    train_files, validation_files, test_files = split_train_test(
        datalist=json_list,
        basedir=data_dir,
        test_size=test_size,
        fold=fold,
        volume=volume,
    )

    both_keys = ["image", "label"]
    crop_size = [roi[0], roi[1], roi[2]]

    # One independent random flip per spatial axis.
    axis_flips = [
        transforms.RandFlipd(keys=both_keys, prob=0.5, spatial_axis=axis)
        for axis in (0, 1, 2)
    ]

    train_transform = transforms.Compose(
        [
            transforms.LoadImaged(keys=both_keys),
            transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
            transforms.CropForegroundd(
                keys=both_keys,
                source_key="image",
                k_divisible=crop_size,
            ),
            transforms.RandSpatialCropd(
                keys=both_keys,
                roi_size=crop_size,
                random_size=False,
            ),
            *axis_flips,
            transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            transforms.RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
            transforms.RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
        ]
    )

    # Evaluation data gets no cropping or random augmentation.
    val_transform = transforms.Compose(
        [
            transforms.LoadImaged(keys=both_keys),
            transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
            transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        ]
    )

    def _make_loader(files, pipeline, batch, shuffled):
        # Wrap a file list in a Dataset and a DataLoader with the shared settings.
        return DataLoader(
            Dataset(data=files, transform=pipeline),
            batch_size=batch,
            shuffle=shuffled,
            num_workers=2,
            pin_memory=True,
        )

    train_loader = _make_loader(train_files, train_transform, batch_size, True)
    val_loader = _make_loader(validation_files, val_transform, 1, False)
    test_loader = _make_loader(test_files, val_transform, 1, False)
    return train_loader, val_loader, test_loader