| import os |
| import json |
|
|
| import torch |
| from torch.utils.data import Dataset |
| import numpy as np |
|
|
| from PIL import Image |
| import PIL.Image |
| try: |
| import pyspng |
| except ImportError: |
| pyspng = None |
|
|
|
|
class CustomDataset(Dataset):
    """ImageNet-256 dataset pairing VAE latents with labels and, optionally,
    precomputed semantic features.

    Layout under ``data_dir``:
      * VAE latents live in ``imagenet_256_vae/``.
      * Without semantic features: VAE statistics / pairing files live in
        ``vae-sd/`` (matching the original REG setup).
      * With ``semantic_features_dir``: behaves like the main repository's
        dataset — samples are indexed from that directory's ``dataset.json``
        and the matching ``.npy`` in ``imagenet_256_vae`` is inferred from
        each semantic feature file name.
    """

    def __init__(self, data_dir, semantic_features_dir=None):
        """
        Args:
            data_dir: Root directory containing ``imagenet_256_vae`` and,
                depending on mode, ``vae-sd`` or preprocessed features.
            semantic_features_dir: Optional directory of precomputed semantic
                features holding a ``dataset.json`` index. When ``None``, a
                default location under ``data_dir`` is probed; if it is absent
                the dataset falls back to the ``vae-sd`` pairing mode.

        Raises:
            FileNotFoundError: If the required ``dataset.json`` is missing.
            ValueError: If ``dataset.json`` lacks a ``labels`` field.
        """
        PIL.Image.init()  # populate PIL.Image.EXTENSION with registered formats
        supported_ext = PIL.Image.EXTENSION.keys() | {'.npy'}

        self.images_dir = os.path.join(data_dir, 'imagenet_256_vae')

        if semantic_features_dir is None:
            # Probe the default preprocessed-feature location before falling
            # back to the vae-sd pairing mode.
            potential_semantic_dir = os.path.join(
                data_dir, 'imagenet_256_features', 'dinov2-vit-b_tmp', 'gpu0'
            )
            if os.path.exists(potential_semantic_dir):
                self.semantic_features_dir = potential_semantic_dir
                self.use_preprocessed_semantic = True
                print(f"Found preprocessed semantic features at: {self.semantic_features_dir}")
            else:
                self.semantic_features_dir = None
                self.use_preprocessed_semantic = False
        else:
            self.semantic_features_dir = semantic_features_dir
            self.use_preprocessed_semantic = True
            print(f"Using preprocessed semantic features from: {self.semantic_features_dir}")

        if self.use_preprocessed_semantic:
            label_fname = os.path.join(self.semantic_features_dir, 'dataset.json')
            if not os.path.exists(label_fname):
                raise FileNotFoundError(f"Label file not found: {label_fname}")

            print(f"Using {label_fname}.")
            # JSON is text; open in text mode with an explicit encoding.
            with open(label_fname, 'r', encoding='utf-8') as f:
                data = json.load(f)
            labels_list = data.get('labels', None)
            if labels_list is None:
                raise ValueError(f"'labels' field is missing in {label_fname}")

            semantic_fnames = []
            labels = []
            for entry in labels_list:
                if entry is None:  # skip null placeholder entries
                    continue
                fname, lab = entry
                semantic_fnames.append(fname)
                labels.append(0 if lab is None else lab)

            self.semantic_fnames = semantic_fnames
            self.labels = np.array(labels, dtype=np.int64)
            self.num_samples = len(self.semantic_fnames)
            print(f"Loaded {self.num_samples} semantic entries from dataset.json")
        else:
            self.features_dir = os.path.join(data_dir, 'vae-sd')

            # Collect relative paths of every supported file under each root;
            # sorting keeps image/feature lists aligned by name.
            self._image_fnames = {
                os.path.relpath(os.path.join(root, fname), start=self.images_dir)
                for root, _dirs, files in os.walk(self.images_dir) for fname in files
            }
            self.image_fnames = sorted(
                fname for fname in self._image_fnames if self._file_ext(fname) in supported_ext
            )
            self._feature_fnames = {
                os.path.relpath(os.path.join(root, fname), start=self.features_dir)
                for root, _dirs, files in os.walk(self.features_dir) for fname in files
            }
            self.feature_fnames = sorted(
                fname for fname in self._feature_fnames if self._file_ext(fname) in supported_ext
            )

            label_path = os.path.join(self.features_dir, 'dataset.json')
            if not os.path.exists(label_path):
                raise FileNotFoundError(f"Label file not found: {label_path}")
            print(f"Using {label_path}.")

            with open(label_path, 'r', encoding='utf-8') as f:
                label_map = dict(json.load(f)['labels'])
            # Normalize path separators so Windows-style keys still match.
            label_rows = np.array(
                [label_map[name.replace('\\', '/')] for name in self.feature_fnames]
            )
            # 1-D -> integer class ids; 2-D -> float label rows (e.g. one-hot).
            self.labels = label_rows.astype({1: np.int64, 2: np.float32}[label_rows.ndim])

    def _file_ext(self, fname):
        """Return the lowercased extension of *fname*, including the dot."""
        return os.path.splitext(fname)[1].lower()

    def __len__(self):
        """Number of samples in the dataset."""
        if self.use_preprocessed_semantic:
            return self.num_samples
        assert len(self.image_fnames) == len(self.feature_fnames), \
            "Number of feature files and label files should be same"
        return len(self.feature_fnames)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            Preprocessed-semantic mode: ``(image, image, semantic_features,
            label)`` — the VAE latent appears twice; presumably consumers
            expect separate source/target tensors (TODO confirm with caller).
            Fallback mode: ``(image, features, label)``.
        """
        if self.use_preprocessed_semantic:
            semantic_fname = self.semantic_fnames[idx]
            basename = os.path.basename(semantic_fname)
            # Feature files are named "...-<idx>.npy"; the first 5 characters
            # of <idx> select the shard subdirectory of the VAE latents.
            idx_str = basename.split('-')[-1].split('.')[0]
            subdir = idx_str[:5]
            vae_relpath = os.path.join(subdir, f"img-mean-std-{idx_str}.npy")
            vae_path = os.path.join(self.images_dir, vae_relpath)

            with open(vae_path, 'rb') as f:
                image = np.load(f)

            semantic_path = os.path.join(self.semantic_features_dir, semantic_fname)
            semantic_features = np.load(semantic_path)

            # Two independent .float() copies, so in-place edits to one
            # element of the tuple cannot leak into the other.
            return (
                torch.from_numpy(image).float(),
                torch.from_numpy(image).float(),
                torch.from_numpy(semantic_features).float(),
                torch.tensor(self.labels[idx]),
            )

        image_fname = self.image_fnames[idx]
        feature_fname = self.feature_fnames[idx]
        image_ext = self._file_ext(image_fname)
        with open(os.path.join(self.images_dir, image_fname), 'rb') as f:
            if image_ext == '.npy':
                image = np.load(f)
                image = image.reshape(-1, *image.shape[-2:])
            elif image_ext == '.png' and pyspng is not None:
                # pyspng decodes PNGs faster than PIL when it is installed.
                image = pyspng.load(f.read())
                image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)  # HWC -> CHW
            else:
                image = np.array(PIL.Image.open(f))
                image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)  # HWC -> CHW

        features = np.load(os.path.join(self.features_dir, feature_fname))
        return torch.from_numpy(image), torch.from_numpy(features), torch.tensor(self.labels[idx])
|
|