| r""" Dataloader builder for few-shot semantic segmentation dataset """ |
| from torch.utils.data.distributed import DistributedSampler as Sampler |
| from torch.utils.data import DataLoader |
| from torchvision import transforms |
|
|
| from data.pascal import DatasetPASCAL |
| from data.coco import DatasetCOCO |
| from data.fss import DatasetFSS |
| from data.deepglobe import DatasetDeepglobe |
| from data.isic import DatasetISIC |
| from data.lung import DatasetLung |
| from data.fss import DatasetFSS |
| from data.suim import DatasetSUIM |
|
|
|
|
class FSSDataset:
    """Factory for few-shot semantic segmentation datasets and dataloaders.

    `initialize` must be called once to register the benchmark classes and
    the shared image transform; `build_dataloader` then constructs a
    DataLoader for a given benchmark/fold/split.
    """

    @classmethod
    def initialize(cls, img_size, datapath):
        """Register the available benchmarks and build the image transform.

        Args:
            img_size: target square edge length for resized images.
            datapath: root directory holding the benchmark data.
        """
        # Benchmark name -> dataset class lookup table.
        cls.datasets = {
            'pascal': DatasetPASCAL,
            'coco': DatasetCOCO,
            'fss': DatasetFSS,
            'deepglobe': DatasetDeepglobe,
            'isic': DatasetISIC,
            'lung': DatasetLung,
            'suim': DatasetSUIM,
        }

        # ImageNet channel statistics used for input normalization.
        cls.img_mean = [0.485, 0.456, 0.406]
        cls.img_std = [0.229, 0.224, 0.225]
        cls.datapath = datapath

        cls.transform = transforms.Compose([
            transforms.Resize(size=(img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(cls.img_mean, cls.img_std),
        ])

    @classmethod
    def build_dataloader(cls, benchmark, bsz, nworker, fold, split, shot=1):
        """Construct a DataLoader for the requested benchmark episode stream.

        Args:
            benchmark: key into the registered dataset table
                ('pascal', 'coco', 'fss', 'deepglobe', 'isic', 'lung', 'suim').
            bsz: batch size.
            nworker: number of worker processes (training split only).
            fold: cross-validation fold index forwarded to the dataset.
            split: 'trn' for training, anything else for evaluation.
            shot: number of support examples per episode (default 1).

        Returns:
            A torch DataLoader over the selected few-shot dataset.
        """
        is_training = split == 'trn'
        # Evaluation runs single-process (num_workers=0); worker processes
        # are only enabled for the training split.
        num_workers = nworker if is_training else 0

        dataset = cls.datasets[benchmark](
            cls.datapath,
            fold=fold,
            transform=cls.transform,
            split=split,
            shot=shot,
        )

        return DataLoader(
            dataset,
            batch_size=bsz,
            shuffle=is_training,
            num_workers=num_workers,
            pin_memory=True,
        )
|
|