| from torch.utils.data import Dataset |
| import os |
| import pathlib |
| import torch |
|
|
| from PIL import Image |
| from torch.utils.data import Dataset |
| from torchvision import transforms |
| from typing import Tuple, Dict, List |
|
|
| import torch.utils.data as data |
| import numpy as np |
| |
| import random |
| |
| |
|
|
| |
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Find the class folder names in a target directory.

    Every immediate sub-directory of ``directory`` is treated as one class
    (standard image classification layout). Plain files are ignored.

    Label assignment:
        * If every folder name parses as an integer (e.g. ``"0"``, ``"7"``),
          the label is that integer, i.e. the on-disk name *is* the label.
        * Otherwise the label is the folder's position in the sorted list of
          names, as in the example below.

    Args:
        directory (str): target directory to load class names from.

    Returns:
        Tuple[List[str], Dict[str, int]]:
            (sorted_class_names, {class_name: label, ...})

    Raises:
        FileNotFoundError: if ``directory`` contains no sub-directories.

    Example:
        find_classes("food_images/train")
        >>> (["class_1", "class_2"], {"class_1": 0, "class_2": 1})
    """
    # The context manager releases the directory handle deterministically.
    with os.scandir(directory) as entries:
        # NOTE: names are sorted as strings, so "10" sorts before "5".
        classes = sorted(entry.name for entry in entries if entry.is_dir())

    if not classes:
        raise FileNotFoundError(f"Couldn't find any classes in {directory}.")

    try:
        # Numeric folder names carry their own label.
        class_to_idx = {cls_name: int(cls_name) for cls_name in classes}
    except ValueError:
        # At least one name is not an integer: fall back to positional labels
        # instead of crashing on names such as "class_1".
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx
|
|
|
|
class SamData(Dataset):
    """Dataset of ``.jpg`` images stored as ``<targ_dir>/<folder>/<file>.jpg``.

    Each sample is ``(image, image_id)``. The id is parsed from the file name
    and a "fold" string is parsed from the containing folder name, both by
    dropping a 3-character prefix, e.g. ``sa_000020/sa_223750.jpg`` gives
    fold ``"000020"`` and id ``223750``.

    NOTE(review): the naming looks like the SA-1B layout (``sa_`` prefix);
    confirm against the data this is pointed at.

    Attributes set by ``__init__``:
        paths: sorted list of ``pathlib.Path`` objects, one per image.
        indexes: integer image id for each entry of ``paths``.
        folds: fold string for each entry of ``paths``.
        transform: optional callable applied to every loaded image.
    """

    # Length of the "sa_" prefix carried by folder and file names.
    _PREFIX_LEN = 3
    # Folder names contribute characters [3:9] (six characters) as the fold.
    _FOLD_END = 9

    def __init__(self, targ_dir: str, transform=None) -> None:
        """Index every ``*/*.jpg`` under ``targ_dir``.

        Args:
            targ_dir: root directory that holds one sub-folder per fold.
            transform: optional callable applied to each image in
                ``__getitem__``.

        Raises:
            ValueError: if a file name does not contain a numeric id after
                its 3-character prefix.
        """
        self.paths = sorted(pathlib.Path(targ_dir).glob("*/*.jpg"))

        # Parse from the path *components*, not from the full path string:
        # searching the whole string for "sa_" / ".jpg" breaks as soon as an
        # ancestor directory happens to contain one of those substrings.
        self.indexes = [self._parse_index(path) for path in self.paths]
        self.folds = [
            path.parent.name[self._PREFIX_LEN:self._FOLD_END]
            for path in self.paths
        ]

        self.transform = transform

    @staticmethod
    def _parse_index(path: pathlib.Path) -> int:
        """Return the integer id encoded in ``path``'s file name."""
        digits = path.stem[SamData._PREFIX_LEN:]
        try:
            return int(digits)
        except ValueError as err:
            raise ValueError(
                f"Cannot parse a numeric image id from file name "
                f"{path.name!r}; expected something like 'sa_<number>.jpg'."
            ) from err

    def load_image(self, index: int) -> Image.Image:
        """Open the image at position ``index`` of ``paths`` and return it."""
        image_path = self.paths[index]
        return Image.open(image_path)

    def __len__(self) -> int:
        """Return the total number of samples."""
        return len(self.paths)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        """Return one sample as ``(image, image_id)``.

        The image is passed through ``transform`` when one was supplied;
        otherwise the PIL image is returned as loaded.
        """
        img = self.load_image(index)
        image_id = self.indexes[index]

        # Test against None so that a callable that happens to be falsy
        # (e.g. an empty container module) is still applied.
        if self.transform is not None:
            return self.transform(img), image_id
        return img, image_id