| |
| from torch.utils.data import DataLoader |
| from torchvision.io import read_image |
| from torch.utils.data import Dataset |
| from torchvision.transforms import v2 |
| from torchvision import transforms |
| from torchvision import datasets |
| from PIL import Image |
| import pandas as pd |
| import idx2numpy, os |
| import torch |
|
|
| |
| |
# Side length, in pixels, of the square every image is resized to.
IMAGE_DIMS = 224

# Preprocessing / augmentation pipeline applied to each image tensor.
normal_transforms = v2.Compose([
    # Bring every image to a fixed IMAGE_DIMS x IMAGE_DIMS resolution.
    v2.Resize(size=(IMAGE_DIMS, IMAGE_DIMS)),
    # Cast to float32 AND rescale integer pixel values to [0.0, 1.0].
    # Without scale=True the cast keeps the raw 0..255 range, which does not
    # match the sub-1.0 mean/std passed to Normalize below.
    v2.ToDtype(torch.float32, scale=True),
    # Light augmentation: random rotation within +/-15 degrees.
    v2.RandomRotation(degrees=(-15, 15)),
    # Standardize with a single mean/std pair (broadcast over all channels).
    transforms.Normalize((0.13066047430038452,), (0.30810782313346863,)),
])
|
|
|
|
|
|
|
|
class CustomImageDataset(Dataset):
    """Image dataset backed by a CSV annotations file and an image directory.

    Inherits from ``torch.utils.data.Dataset`` and implements the required
    ``__init__``, ``__len__`` and ``__getitem__`` methods.

    Args:
        annotations_file: Path to a CSV file loaded with ``pandas.read_csv``.
            Column 0 is used as the image file name (joined onto ``img_dir``)
            and column 1 as the label.
        img_dir: Directory that contains the image files.
        transform: Optional callable applied to the decoded image tensor.
        target_transform: Optional callable applied to the label.
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        """Return the number of rows in the annotations table."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at row ``idx``.

        The image is decoded as 3-channel RGB and returned as a uint8 tensor
        (before ``transform`` is applied). The file on disk is only read,
        never rewritten.
        """
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        # Convert to RGB in memory. Re-saving the converted image over the
        # original on every access would repeatedly re-encode the dataset,
        # race between DataLoader workers and fail on read-only storage.
        with Image.open(img_path) as img:
            image = v2.functional.pil_to_tensor(img.convert("RGB"))
        label = self.img_labels.iloc[idx, 1]
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label
|
|
|
|
# Full labelled dataset; every sample goes through `normal_transforms`.
train_data = CustomImageDataset("./dataset/root/labels.csv", "./dataset/root/train/", transform=normal_transforms)

# 80/20 train/test split. The remainder goes to the test subset so the two
# sizes always add up to len(train_data).
train_size = int(0.8 * len(train_data))
test_size = len(train_data) - train_size
# Seed the split so subset membership is the same on every run/import;
# an unseeded split lets held-out samples drift into the training subset
# between processes (e.g. train in one run, evaluate in another).
train_dataset, test_dataset = torch.utils.data.random_split(
    train_data,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(42),
)
# NOTE(review): both subsets wrap the same `train_data` object, so test
# samples also receive the random rotation in `normal_transforms` --
# confirm whether evaluation should use a non-augmenting transform.

# Shuffle only the training batches; keep test order fixed.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

print("Data loader and Test Loaders are ready to be used.")
|
|
|
|
| |
| |
| |