# jsflow / REG / dataset.py
# (uploaded by xiangzai via the upload-large-folder tool, commit b65e56d)
import os
import json
import torch
from torch.utils.data import Dataset
import numpy as np
from PIL import Image
import PIL.Image
try:
import pyspng
except ImportError:
pyspng = None
class CustomDataset(Dataset):
    """Precomputed VAE-latent dataset with optional precomputed semantic features.

    Directory layout under ``data_dir``:

    * VAE latents live in ``imagenet_256_vae/``.
    * Without preprocessed semantics: VAE statistic / paired feature files
      live in ``vae-sd/`` (same layout as the original REG repo), with labels
      taken from ``vae-sd/dataset.json``.
    * With ``semantic_features_dir`` (passed explicitly, or auto-detected at
      ``imagenet_256_features/dinov2-vit-b_tmp/gpu0``): samples are indexed
      from that directory's ``dataset.json``, and the matching latent file in
      ``imagenet_256_vae`` is derived from each semantic feature file name.
    """

    def __init__(self, data_dir, semantic_features_dir=None):
        """
        Args:
            data_dir: Root directory containing ``imagenet_256_vae`` and
                either ``vae-sd`` or ``imagenet_256_features``.
            semantic_features_dir: Optional explicit directory of preprocessed
                semantic features; when ``None`` the default location under
                ``data_dir`` is probed.

        Raises:
            FileNotFoundError: if the required ``dataset.json`` is missing.
            ValueError: if ``dataset.json`` lacks a ``labels`` field
                (semantic branch only).
        """
        PIL.Image.init()  # populate PIL.Image.EXTENSION with registered formats
        supported_ext = PIL.Image.EXTENSION.keys() | {'.npy'}
        self.images_dir = os.path.join(data_dir, 'imagenet_256_vae')
        if semantic_features_dir is None:
            # Probe the default preprocessed-semantic-features location.
            potential_semantic_dir = os.path.join(
                data_dir, 'imagenet_256_features', 'dinov2-vit-b_tmp', 'gpu0'
            )
            if os.path.exists(potential_semantic_dir):
                self.semantic_features_dir = potential_semantic_dir
                self.use_preprocessed_semantic = True
                print(f"Found preprocessed semantic features at: {self.semantic_features_dir}")
            else:
                self.semantic_features_dir = None
                self.use_preprocessed_semantic = False
        else:
            self.semantic_features_dir = semantic_features_dir
            self.use_preprocessed_semantic = True
            print(f"Using preprocessed semantic features from: {self.semantic_features_dir}")
        if self.use_preprocessed_semantic:
            self._build_semantic_index()
        else:
            self._build_paired_index(data_dir, supported_ext)

    def _build_semantic_index(self):
        # Index samples from dataset.json inside the semantic-features directory.
        label_fname = os.path.join(self.semantic_features_dir, 'dataset.json')
        if not os.path.exists(label_fname):
            raise FileNotFoundError(f"Label file not found: {label_fname}")
        print(f"Using {label_fname}.")
        with open(label_fname, 'rb') as f:
            data = json.load(f)
        labels_list = data.get('labels', None)
        if labels_list is None:
            raise ValueError(f"'labels' field is missing in {label_fname}")
        semantic_fnames = []
        labels = []
        for entry in labels_list:
            if entry is None:  # skip null placeholder entries
                continue
            fname, lab = entry
            semantic_fnames.append(fname)
            labels.append(0 if lab is None else lab)  # unlabeled samples map to class 0
        self.semantic_fnames = semantic_fnames
        self.labels = np.array(labels, dtype=np.int64)
        self.num_samples = len(self.semantic_fnames)
        print(f"Loaded {self.num_samples} semantic entries from dataset.json")

    def _build_paired_index(self, data_dir, supported_ext):
        # Original REG layout: latents in images_dir are paired 1:1 (by sorted
        # filename order) with feature files in vae-sd/.
        self.features_dir = os.path.join(data_dir, 'vae-sd')
        self._image_fnames = {
            os.path.relpath(os.path.join(root, fname), start=self.images_dir)
            for root, _dirs, files in os.walk(self.images_dir) for fname in files
        }
        self.image_fnames = sorted(
            fname for fname in self._image_fnames if self._file_ext(fname) in supported_ext
        )
        self._feature_fnames = {
            os.path.relpath(os.path.join(root, fname), start=self.features_dir)
            for root, _dirs, files in os.walk(self.features_dir) for fname in files
        }
        self.feature_fnames = sorted(
            fname for fname in self._feature_fnames if self._file_ext(fname) in supported_ext
        )
        fname = os.path.join(self.features_dir, 'dataset.json')
        if os.path.exists(fname):
            print(f"Using {fname}.")
        else:
            # Fixed: the old message ("Neither of the specified files exists.")
            # implied multiple candidates although only this one file is checked.
            raise FileNotFoundError(f"Label file not found: {fname}")
        with open(fname, 'rb') as f:
            labels = json.load(f)['labels']
        labels = dict(labels)
        # dataset.json stores POSIX-style paths; normalize Windows separators.
        labels = [labels[fname.replace('\\', '/')] for fname in self.feature_fnames]
        labels = np.array(labels)
        # 1-D array -> integer class ids; 2-D array -> float (soft/one-hot) labels.
        self.labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])

    def _file_ext(self, fname):
        """Return the lowercase extension of *fname* (including the dot)."""
        return os.path.splitext(fname)[1].lower()

    def __len__(self):
        if self.use_preprocessed_semantic:
            return self.num_samples
        assert len(self.image_fnames) == len(self.feature_fnames), \
            "Number of feature files and label files should be same"
        return len(self.feature_fnames)

    def __getitem__(self, idx):
        if self.use_preprocessed_semantic:
            semantic_fname = self.semantic_fnames[idx]
            # Feature files are assumed to be named "...-<index>.npy"; the
            # first five characters of <index> select the latent subdirectory.
            # TODO(review): confirm this naming convention against the
            # preprocessing script that produced the feature directory.
            basename = os.path.basename(semantic_fname)
            idx_str = basename.split('-')[-1].split('.')[0]
            subdir = idx_str[:5]
            vae_relpath = os.path.join(subdir, f"img-mean-std-{idx_str}.npy")
            vae_path = os.path.join(self.images_dir, vae_relpath)
            with open(vae_path, 'rb') as f:
                image = np.load(f)
            semantic_path = os.path.join(self.semantic_features_dir, semantic_fname)
            semantic_features = np.load(semantic_path)
            # NOTE(review): the latent is returned twice, presumably to keep a
            # fixed 4-tuple interface — confirm against the training loop.
            return (
                torch.from_numpy(image).float(),
                torch.from_numpy(image).float(),
                torch.from_numpy(semantic_features).float(),
                torch.tensor(self.labels[idx]),
            )
        image_fname = self.image_fnames[idx]
        feature_fname = self.feature_fnames[idx]
        image_ext = self._file_ext(image_fname)
        with open(os.path.join(self.images_dir, image_fname), 'rb') as f:
            if image_ext == '.npy':
                image = np.load(f)
                image = image.reshape(-1, *image.shape[-2:])  # collapse leading dims to channels
            elif image_ext == '.png' and pyspng is not None:
                image = pyspng.load(f.read())  # fast PNG decode when pyspng is available
                image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)  # HWC -> CHW
            else:
                image = np.array(PIL.Image.open(f))
                image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)  # HWC -> CHW
        features = np.load(os.path.join(self.features_dir, feature_fname))
        return torch.from_numpy(image), torch.from_numpy(features), torch.tensor(self.labels[idx])