import os
import json

import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
import PIL.Image

try:
    import pyspng  # optional fast PNG decoder; falls back to PIL when absent
except ImportError:
    pyspng = None


class CustomDataset(Dataset):
    """Dataset pairing VAE latents with labels and (optionally) semantic features.

    Expected layout under ``data_dir``:

    * VAE latents live in ``imagenet_256_vae/``.
    * Without preprocessed semantics: VAE statistics / paired feature files
      live in ``vae-sd/`` (matching the original REG setup), indexed by that
      directory's ``dataset.json``.
    * With semantic features (passed explicitly via ``semantic_features_dir``
      or auto-detected): behaves like the main repository's dataset — samples
      are indexed by ``dataset.json`` in the semantic-feature directory, and
      the matching ``.npy`` under ``imagenet_256_vae`` is inferred from each
      semantic feature's file name.

    Args:
        data_dir: Root directory containing the sub-directories above.
        semantic_features_dir: Optional directory of preprocessed semantic
            features. When ``None``, a conventional location under
            ``data_dir`` is probed and used if present.
    """

    def __init__(self, data_dir, semantic_features_dir=None):
        # Populate PIL.Image.EXTENSION so we can recognize every image format
        # PIL supports, plus raw .npy arrays.
        PIL.Image.init()
        supported_ext = PIL.Image.EXTENSION.keys() | {'.npy'}

        self.images_dir = os.path.join(data_dir, 'imagenet_256_vae')

        if semantic_features_dir is None:
            # Auto-detect semantic features at the extractor's conventional
            # output path; fall back to the plain vae-sd pairing otherwise.
            potential_semantic_dir = os.path.join(
                data_dir, 'imagenet_256_features', 'dinov2-vit-b_tmp', 'gpu0'
            )
            if os.path.exists(potential_semantic_dir):
                self.semantic_features_dir = potential_semantic_dir
                self.use_preprocessed_semantic = True
                print(f"Found preprocessed semantic features at: {self.semantic_features_dir}")
            else:
                self.semantic_features_dir = None
                self.use_preprocessed_semantic = False
        else:
            self.semantic_features_dir = semantic_features_dir
            self.use_preprocessed_semantic = True
            print(f"Using preprocessed semantic features from: {self.semantic_features_dir}")

        if self.use_preprocessed_semantic:
            # Index the dataset via the semantic directory's dataset.json:
            # a {"labels": [[fname, label] | null, ...]} manifest.
            label_fname = os.path.join(self.semantic_features_dir, 'dataset.json')
            if not os.path.exists(label_fname):
                raise FileNotFoundError(f"Label file not found: {label_fname}")
            print(f"Using {label_fname}.")
            with open(label_fname, 'rb') as f:
                data = json.load(f)
            labels_list = data.get('labels', None)
            if labels_list is None:
                raise ValueError(f"'labels' field is missing in {label_fname}")

            semantic_fnames = []
            labels = []
            for entry in labels_list:
                if entry is None:
                    # Null manifest entries are skipped entirely.
                    continue
                fname, lab = entry
                semantic_fnames.append(fname)
                # A null label means "unlabeled"; map it to class 0.
                labels.append(0 if lab is None else lab)

            self.semantic_fnames = semantic_fnames
            self.labels = np.array(labels, dtype=np.int64)
            self.num_samples = len(self.semantic_fnames)
            print(f"Loaded {self.num_samples} semantic entries from dataset.json")
        else:
            # Classic pairing: walk both trees and align them by sorted
            # relative path, keeping only recognized extensions.
            self.features_dir = os.path.join(data_dir, 'vae-sd')
            self._image_fnames = {
                os.path.relpath(os.path.join(root, fname), start=self.images_dir)
                for root, _dirs, files in os.walk(self.images_dir)
                for fname in files
            }
            self.image_fnames = sorted(
                fname for fname in self._image_fnames
                if self._file_ext(fname) in supported_ext
            )
            self._feature_fnames = {
                os.path.relpath(os.path.join(root, fname), start=self.features_dir)
                for root, _dirs, files in os.walk(self.features_dir)
                for fname in files
            }
            self.feature_fnames = sorted(
                fname for fname in self._feature_fnames
                if self._file_ext(fname) in supported_ext
            )

            fname = os.path.join(self.features_dir, 'dataset.json')
            if os.path.exists(fname):
                print(f"Using {fname}.")
            else:
                # Only this one path is probed, so say which file is missing
                # (consistent with the semantic branch's error above).
                raise FileNotFoundError(f"Label file not found: {fname}")
            with open(fname, 'rb') as f:
                labels = json.load(f)['labels']
            labels = dict(labels)
            # Manifest keys use forward slashes; normalize Windows separators
            # in our relative paths before the lookup.
            labels = [labels[fname.replace('\\', '/')] for fname in self.feature_fnames]
            labels = np.array(labels)
            # 1-D label arrays are class indices (int64); 2-D arrays are
            # dense (e.g. one-hot / soft) labels kept as float32.
            self.labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])

    def _file_ext(self, fname):
        """Return the lower-cased extension of *fname*, including the dot."""
        return os.path.splitext(fname)[1].lower()

    def __len__(self):
        if self.use_preprocessed_semantic:
            return self.num_samples
        assert len(self.image_fnames) == len(self.feature_fnames), \
            "Number of feature files and label files should be same"
        return len(self.feature_fnames)

    def __getitem__(self, idx):
        """Return one sample.

        Semantic mode returns a 4-tuple
        ``(latent, latent, semantic_features, label)``; the latent tensor is
        deliberately returned twice — presumably to match a consumer that
        expects separate input/target slots (TODO confirm against caller).
        Classic mode returns ``(image, features, label)``.
        """
        if self.use_preprocessed_semantic:
            semantic_fname = self.semantic_fnames[idx]
            # Feature files are named "<prefix>-<index>.<ext>"; the trailing
            # index selects the matching VAE latent, which is sharded into
            # sub-directories by the index's first five digits.
            basename = os.path.basename(semantic_fname)
            idx_str = basename.split('-')[-1].split('.')[0]
            subdir = idx_str[:5]
            vae_relpath = os.path.join(subdir, f"img-mean-std-{idx_str}.npy")
            vae_path = os.path.join(self.images_dir, vae_relpath)

            # NOTE(review): unlike the classic branch below, this latent is
            # not reshaped after np.load — assumed already in its final
            # layout; verify against the preprocessing pipeline.
            with open(vae_path, 'rb') as f:
                image = np.load(f)

            semantic_path = os.path.join(self.semantic_features_dir, semantic_fname)
            semantic_features = np.load(semantic_path)

            return (
                torch.from_numpy(image).float(),
                torch.from_numpy(image).float(),
                torch.from_numpy(semantic_features).float(),
                torch.tensor(self.labels[idx]),
            )

        image_fname = self.image_fnames[idx]
        feature_fname = self.feature_fnames[idx]
        image_ext = self._file_ext(image_fname)
        with open(os.path.join(self.images_dir, image_fname), 'rb') as f:
            if image_ext == '.npy':
                # Collapse any leading axes into one channel axis: (C, H, W).
                image = np.load(f)
                image = image.reshape(-1, *image.shape[-2:])
            elif image_ext == '.png' and pyspng is not None:
                # Fast path for PNGs; HWC -> CHW.
                image = pyspng.load(f.read())
                image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
            else:
                # Generic PIL fallback; HWC -> CHW (grayscale gains a
                # singleton channel via the reshape).
                image = np.array(PIL.Image.open(f))
                image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)

        features = np.load(os.path.join(self.features_dir, feature_fname))
        return torch.from_numpy(image), torch.from_numpy(features), torch.tensor(self.labels[idx])