| import numpy as np |
| import torch |
| from torch.utils.data import Dataset |
| from torchvision.datasets import DatasetFolder |
|
|
|
|
class FBanksTripletDataset(Dataset):
    """Map-style dataset yielding ``(anchor, positive, negative)`` triplets.

    Samples are ``.npy`` files discovered by torchvision's ``DatasetFolder``
    under ``root``. Each class is a sub-directory, and labels are the integer
    class indices that ``DatasetFolder`` assigns. Every file must hold an
    array of shape ``(64, 64, 1)``. It is returned as a float32 tensor of
    shape ``(1, 64, 64)``.

    For item ``index``:

    * the anchor is that sample;
    * the positive is a random *other* sample of the same class (the anchor
      itself only when its class has a single sample);
    * the negative is a random sample from a different, non-empty class.

    Each element of the returned triplet is a ``(tensor, label)`` pair.

    Random draws use NumPy's global RNG (``np.random``). NOTE(review): with a
    multi-worker ``DataLoader``, make sure each worker's NumPy RNG is seeded
    differently (recent PyTorch versions do this automatically). Otherwise,
    workers may produce identical triplets.
    """

    def __init__(self, root):
        """Index all ``.npy`` samples under ``root``.

        Args:
            root: Directory containing one sub-directory per class.

        Raises:
            ValueError: If the samples are not grouped contiguously by class,
                which the per-class index ranges rely on.
        """
        self.dataset_folder = DatasetFolder(root=root, loader=FBanksTripletDataset._npy_loader, extensions='.npy')
        self.len_ = len(self.dataset_folder.samples)
        self.num_classes = len(self.dataset_folder.classes)

        targets = np.asarray(self.dataset_folder.targets, dtype=np.int64)
        # Each class is addressed as one contiguous [start, end) slice of the
        # sample list, so the targets must be non-decreasing.
        if np.any(np.diff(targets) < 0):
            raise ValueError('DatasetFolder samples are not grouped by class; '
                             'per-class index ranges would be wrong')
        # minlength gives trailing classes with no samples an (empty) entry
        # instead of shortening the array and causing an IndexError.
        bin_counts = np.bincount(targets, minlength=self.num_classes)
        ends = np.cumsum(bin_counts)
        # label -> (start, end): dataset indices of that class, end exclusive.
        self.label_to_index_range = {
            label: (int(end - count), int(end))
            for label, (count, end) in enumerate(zip(bin_counts, ends))
        }

    @staticmethod
    def _npy_loader(path):
        """Load one sample file as a float32 tensor of shape ``(1, 64, 64)``.

        Args:
            path: Path to a ``.npy`` file.

        Raises:
            ValueError: If the stored array's shape is not ``(64, 64, 1)``.
        """
        sample = np.load(path)
        # Use an explicit check rather than ``assert``. This keeps validation
        # active under ``python -O`` and gives wrong-rank arrays a clear error
        # instead of an IndexError.
        if sample.shape != (64, 64, 1):
            raise ValueError(f'{path}: expected an array of shape (64, 64, 1), got {sample.shape}')
        # Move the trailing singleton axis to the front: (64, 64, 1) -> (1, 64, 64).
        sample = np.moveaxis(sample, 2, 0)
        return torch.from_numpy(sample).float()

    def _sample_index(self, label, exclude=None):
        """Return a uniformly random dataset index whose class is ``label``.

        Args:
            label: Class index. Its range must be non-empty.
            exclude: Optional index inside ``label``'s range that must not be
                returned, unless it is the class's only sample.
        """
        start, end = self.label_to_index_range[label]
        if exclude is not None and end - start > 1:
            # Draw from one fewer slot and shift past ``exclude``, so the
            # result is uniform over the class's remaining samples.
            candidate = int(np.random.randint(low=start, high=end - 1))
            return candidate + 1 if candidate >= exclude else candidate
        return int(np.random.randint(low=start, high=end))

    def __getitem__(self, index):
        """Return ``((anchor_x, anchor_y), (pos_x, pos_y), (neg_x, neg_y))``.

        Raises:
            IndexError: If ``index`` is out of range.
            ValueError: If no other non-empty class exists to draw a negative from.
        """
        # Normalise negative indices (and reject out-of-range ones) so that
        # ``index`` can be compared against the class's [start, end) range.
        index = range(len(self))[index]
        anchor_x, anchor_y = self.dataset_folder[index]

        # Positive: same class, but never the anchor itself when avoidable.
        positive_x, positive_y = self.dataset_folder[self._sample_index(anchor_y, exclude=index)]

        # Negative: any other class that actually has samples.
        negative_labels = [
            label for label, (start, end) in self.label_to_index_range.items()
            if label != anchor_y and end > start
        ]
        if not negative_labels:
            raise ValueError(f'no class other than {anchor_y} has samples; cannot draw a negative')
        negative_label = negative_labels[np.random.randint(len(negative_labels))]
        negative_x, negative_y = self.dataset_folder[self._sample_index(negative_label)]

        return (anchor_x, anchor_y), (positive_x, positive_y), (negative_x, negative_y)

    def __len__(self):
        """Return the number of samples (one triplet per anchor)."""
        return self.len_