# TCube_Merging / data / fewshot_datasets.py
# razaimam45's picture
# Upload 108 files
# a96891a verified
import math
import os
import json
import random
import numpy as np
import torch
from torch.utils.data import Dataset
import PIL
from PIL import Image
class BaseJsonDataset(Dataset):
    """Image-classification dataset described by a JSON split file.

    The JSON file is loaded as a mapping from split name to a list of
    samples. Element 0 of each sample is the image path relative to
    ``image_path`` and element 1 is its class label.

    Args:
        image_path: Directory that each sample's image path is joined onto.
        json_path: Path of the JSON split file.
        mode: Key of the split to read from the JSON file.
        n_shot: If not ``None``, keep at most this many samples per class.
            The draw uses a fixed seed, so the subset is reproducible.
        transform: Optional callable applied to every loaded RGB image.
    """

    def __init__(self, image_path, json_path, mode='train', n_shot=None, transform=None):
        self.transform = transform
        self.image_path = image_path
        self.split_json = json_path
        self.mode = mode

        with open(self.split_json) as fp:
            samples = json.load(fp)[self.mode]
        self.image_list = [s[0] for s in samples]
        self.label_list = [s[1] for s in samples]

        if n_shot is not None:
            keep = self._few_shot_indices(self.label_list, n_shot)
            self.image_list = [self.image_list[i] for i in keep]
            self.label_list = [self.label_list[i] for i in keep]

    @staticmethod
    def _few_shot_indices(labels, n_shot):
        """Return indices of up to ``n_shot`` samples per class, by ascending label.

        Raises:
            ValueError: If ``n_shot`` is negative.
        """
        if not labels:
            return []

        # Group sample indices by label in a single pass.
        by_class = {}
        for idx, label in enumerate(labels):
            by_class.setdefault(label, []).append(idx)

        chosen = []
        for c in range(max(labels) + 1):
            members = by_class.get(c, [])
            # A private generator seeded with 0 draws the same samples as
            # seeding the module-level generator would, without disturbing
            # the global random state used by the rest of the program.
            rng = random.Random(0)
            # Classes smaller than n_shot (or unused label ids) contribute
            # everything they have instead of raising.
            chosen.extend(rng.sample(members, min(n_shot, len(members))))
        return chosen

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        """Return ``(image, label)`` where ``label`` is a scalar int64 tensor."""
        image_path = os.path.join(self.image_path, self.image_list[idx])
        # convert() returns a fully loaded copy, so the file can be closed
        # as soon as the conversion is done.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        label = self.label_list[idx]
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(label).long()
# Dataset names accepted (case-insensitively) by build_fewshot_dataset.
fewshot_datasets = [
    'DTD',
    'Flower102',
    'Food101',
    'Cars',
    'SUN397',
    'Aircraft',
    'Pets',
    'Caltech101',
    'UCF101',
    'eurosat',
]

# Lower-cased dataset name -> [image sub-directory, JSON split file].
# Each row below is (name, image sub-directory, JSON split file).
path_dict = {
    name: [image_dir, split_file]
    for name, image_dir, split_file in (
        ("flower102", "jpg", "data/data_splits/split_zhou_OxfordFlowers.json"),
        ("food101", "images", "data/data_splits/split_zhou_Food101.json"),
        ("dtd", "images", "data/data_splits/split_zhou_DescribableTextures.json"),
        ("pets", "", "data/data_splits/split_zhou_OxfordPets.json"),
        ("sun397", "", "data/data_splits/split_zhou_SUN397.json"),
        ("caltech101", "", "data/data_splits/split_zhou_Caltech101.json"),
        ("ucf101", "", "data/data_splits/split_zhou_UCF101.json"),
        ("cars", "", "data/data_splits/split_zhou_StanfordCars.json"),
        ("eurosat", "", "data/data_splits/split_zhou_EuroSAT.json"),
    )
}
def build_fewshot_dataset(set_id, root, transform, mode='train', n_shot=None):
    """Construct the dataset object for ``set_id``.

    ``set_id`` is matched case-insensitively. ``'aircraft'`` is served by
    ``Aircraft`` rooted at ``root``; every other name is looked up in
    ``path_dict`` and served by ``BaseJsonDataset``, with the image
    directory formed by joining ``root`` and the configured sub-directory.

    Raises:
        KeyError: If ``set_id`` is neither ``'aircraft'`` nor a key of
            ``path_dict``.
    """
    dataset_key = set_id.lower()
    if dataset_key != 'aircraft':
        image_subdir, split_file = path_dict[dataset_key]
        image_dir = os.path.join(root, image_subdir)
        return BaseJsonDataset(image_dir, split_file, mode, n_shot, transform)
    return Aircraft(root, mode, n_shot, transform)
class Aircraft(Dataset):
    """FGVC Aircraft dataset, labelled by aircraft variant.

    Class names are read from ``<root>/variants.txt`` (one per line) and a
    class's label is its line index in that file. Samples are read from
    ``<root>/images_variant_<mode>.txt``, where each line holds an image id,
    a single space, and the variant name (which may itself contain spaces).
    Images are loaded from ``<root>/images/<image id>.jpg``.

    Args:
        root: Dataset root directory.
        mode: Split name substituted into the annotation file name.
        n_shot: If not ``None``, keep at most this many samples per class.
            The draw uses a fixed seed, so the subset is reproducible.
        transform: Optional callable applied to every loaded RGB image.

    Raises:
        ValueError: If an annotation line names a variant that is not
            listed in ``variants.txt``.
    """

    def __init__(self, root, mode='train', n_shot=None, transform=None):
        self.transform = transform
        self.path = root
        self.mode = mode

        with open(os.path.join(self.path, "variants.txt"), 'r') as fp:
            self.cname = [line.replace("\n", "") for line in fp]

        # Name -> label lookup; the first occurrence wins, matching what
        # list.index would return for a duplicated name.
        class_index = {}
        for i, name in enumerate(self.cname):
            class_index.setdefault(name, i)

        self.image_list = []
        self.label_list = []
        split_file = os.path.join(self.path, 'images_variant_{:s}.txt'.format(self.mode))
        with open(split_file, 'r') as fp:
            for raw in fp:
                line = raw.replace("\n", "")
                if not line.strip():
                    # Blank lines (e.g. a trailing newline) carry no sample.
                    continue
                # Everything after the first space is the variant name.
                img, _, label = line.partition(" ")
                try:
                    label_idx = class_index[label]
                except KeyError:
                    raise ValueError(
                        "unknown variant {!r} in {}".format(label, split_file)
                    ) from None
                self.image_list.append("{}.jpg".format(img))
                self.label_list.append(label_idx)

        if n_shot is not None:
            keep = self._few_shot_indices(self.label_list, n_shot)
            self.image_list = [self.image_list[i] for i in keep]
            self.label_list = [self.label_list[i] for i in keep]

    @staticmethod
    def _few_shot_indices(labels, n_shot):
        """Return indices of up to ``n_shot`` samples per class, by ascending label.

        Raises:
            ValueError: If ``n_shot`` is negative.
        """
        if not labels:
            return []

        # Group sample indices by label in a single pass.
        by_class = {}
        for idx, label in enumerate(labels):
            by_class.setdefault(label, []).append(idx)

        chosen = []
        for c in range(max(labels) + 1):
            members = by_class.get(c, [])
            # A private generator seeded with 0 draws the same samples as
            # seeding the module-level generator would, without disturbing
            # the global random state used by the rest of the program.
            rng = random.Random(0)
            # Classes smaller than n_shot (or with no samples in this
            # split) contribute everything they have instead of raising.
            chosen.extend(rng.sample(members, min(n_shot, len(members))))
        return chosen

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        """Return ``(image, label)`` where ``label`` is a scalar int64 tensor."""
        image_path = os.path.join(self.path, 'images', self.image_list[idx])
        # convert() returns a fully loaded copy, so the file can be closed
        # as soon as the conversion is done.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        label = self.label_list[idx]
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(label).long()