| import os |
| import pickle |
| import random |
| from pathlib import Path |
| from typing import Dict |
| from typing import List |
|
|
| import torch |
| from loguru import logger |
| from PIL import Image |
| from torch.utils.data import Dataset |
| from torchvision import transforms |
|
|
| from configs.mode import FaceSwapMode |
| from configs.train_config import TrainConfig |
|
|
|
|
class ManyToManyTrainDataset(Dataset):
    """Training dataset for many-to-many face swapping.

    One item is produced per identity in the index. The source image is drawn
    at random from that identity; the target image is drawn from the same
    identity with probability ``same_rate`` and from a different, uniformly
    chosen identity otherwise.
    """

    def __init__(self, dataset_root: str, dataset_index: str, same_rate: float = 0.5):
        """
        Build the many-to-many training dataset.

        Parameters:
        -----------
        dataset_root: str, root directory of the dataset
        dataset_index: str, path of the pickled dataset index file
        same_rate: float, probability that source and target share an identity
        """
        super().__init__()
        # Resize already yields 256x256, so the CenterCrop is a no-op; it is
        # kept so the preprocessing pipeline stays exactly as before.
        self.transform = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.CenterCrop((256, 256)),
                transforms.ToTensor(),
            ]
        )
        self.data_root = Path(dataset_root)
        # SECURITY: unpickling can execute arbitrary code. Only load index
        # files that were produced by a trusted pipeline.
        with open(dataset_index, "rb") as f:
            # Presumably maps identity name -> sequence of image paths relative
            # to dataset_root (that is how the entries are used below).
            self.file_index = pickle.load(f, encoding="bytes")

        self.same_rate = same_rate
        self.id_list: List[str] = list(self.file_index.keys())

        # The dataset is indexed by identity, so its length is the id count.
        self.length = len(self.id_list)
        self.image_num = sum(len(files) for files in self.file_index.values())

        self.mask_dir = "mask" if TrainConfig().mouth_mask else "mask_no_mouth"
        logger.info(f"dataset contains {self.length} ids and {self.image_num} images")
        logger.info(f"will use mask mode: {self.mask_dir}")

    def __len__(self):
        """Return the number of identities (one sample is drawn per identity)."""
        return self.length

    def _pick_other_id_index(self, source_id_index: int) -> int:
        """Return a uniformly random identity index different from ``source_id_index``."""
        # Draw from length - 1 slots and skip over the source slot: O(1)
        # instead of materialising every other index on each call.
        other = random.randrange(self.length - 1)
        return other + 1 if other >= source_id_index else other

    def _sample_pair(self, index):
        """Choose the files for one item.

        Returns a ``(source_file, target_file, same)`` tuple where ``same`` is
        a 1-element tensor: 1 when both files belong to the same identity,
        0 otherwise.
        """
        # Normalise negative indices and raise IndexError when out of range.
        source_id_index = range(self.length)[index]
        source_file = random.choice(self.file_index[self.id_list[source_id_index]])

        # With a single identity there is no "other" identity to draw from, so
        # fall back to a same-identity pair rather than failing.
        if self.length < 2 or random.random() < self.same_rate:
            target_id_index = source_id_index
        else:
            target_id_index = self._pick_other_id_index(source_id_index)

        target_file = random.choice(self.file_index[self.id_list[target_id_index]])
        same = torch.ones(1) if target_id_index == source_id_index else torch.zeros(1)
        return source_file, target_file, same

    def _load_rgb(self, path: Path):
        """Open ``path`` with PIL and return a fully loaded RGB image."""
        # The context manager closes the file handle deterministically;
        # convert() returns an independent, already-loaded copy.
        with Image.open(path.as_posix()) as img:
            return img.convert("RGB")

    def __getitem__(self, index):
        """Return one training item.

        The result is a dict with the keys ``source_image``, ``target_image``,
        ``target_mask`` and ``same``.
        """
        source_file, target_file, same = self._sample_pair(index)

        source_path = self.data_root / Path(source_file)
        target_path = self.data_root / Path(target_file)
        # The mask sits in a sibling tree of the image tree:
        # <grandparent>/<mask_dir>/<identity dir>/<file name>.
        # NOTE(review): ``.stem`` drops anything after a dot in the identity
        # directory name - confirm that directory names never contain dots.
        mask_path = target_path.parent.parent.parent / self.mask_dir / target_path.parent.stem / target_path.name

        target_img = self.transform(self._load_rgb(target_path))
        source_img = self.transform(self._load_rgb(source_path))
        # Keep only the first channel of the mask, with a leading channel axis.
        target_mask = self.transform(self._load_rgb(mask_path))[0, :, :].unsqueeze(0)

        return {
            "source_image": source_img,
            "target_image": target_img,
            "target_mask": target_mask,
            "same": same,
        }
|
|
|
|
class OneToManyTrainDataset(Dataset):
    """Training dataset for one-to-many face swapping.

    One item is produced per identity in the index, which acts as the target.
    The source image is drawn from the target's own identity with probability
    ``same_rate`` and from the fixed ``source_name`` identity otherwise.
    """

    def __init__(self, dataset_root: str, dataset_index: str, source_name: str, same_rate: float = 0.5):
        """
        Build the one-to-many training dataset.

        Parameters:
        -----------
        dataset_root: str, root directory of the dataset
        dataset_index: str, path of the pickled dataset index file
        source_name: str, name of the source face id (the "one" in one-to-many)
        same_rate: float, probability that the source is drawn from the target's identity

        Raises:
        -------
        ValueError: if source_name is not an identity of the dataset index
        """
        super().__init__()
        # Resize already yields 256x256, so the CenterCrop is a no-op; it is
        # kept so the preprocessing pipeline stays exactly as before.
        self.transform = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.CenterCrop((256, 256)),
                transforms.ToTensor(),
            ]
        )
        self.data_root = Path(dataset_root)
        # SECURITY: unpickling can execute arbitrary code. Only load index
        # files that were produced by a trusted pipeline.
        with open(dataset_index, "rb") as f:
            # Presumably maps identity name -> sequence of image paths relative
            # to dataset_root (that is how the entries are used below).
            self.file_index = pickle.load(f, encoding="bytes")
        self.same_rate = same_rate
        self.source_name = source_name

        self.id_list: List[str] = list(self.file_index.keys())

        try:
            self.source_id_index: int = self.id_list.index(self.source_name)
        except ValueError as err:
            # list.index only raises ValueError; keep the original message and
            # chain the cause. ValueError is still caught by `except Exception`.
            raise ValueError(f"{self.source_name} not in dataset dir") from err

        # The dataset is indexed by identity, so its length is the id count.
        self.length = len(self.id_list)
        self.image_num = sum(len(files) for files in self.file_index.values())
        self.mask_dir = "mask" if TrainConfig().mouth_mask else "mask_no_mouth"
        logger.info(f"dataset contains {self.length} ids and {self.image_num} images")
        logger.info(f"will use mask mode: {self.mask_dir}")

    def __len__(self):
        """Return the number of identities (one sample is drawn per identity)."""
        return self.length

    def _sample_pair(self, index):
        """Choose the files for one item.

        Returns a ``(source_file, target_file, same)`` tuple where ``same`` is
        a 1-element tensor: 1 when both files belong to the same identity,
        0 otherwise.
        """
        # Normalise negative indices (and raise IndexError when out of range)
        # so the comparison with source_id_index below is reliable.
        target_id_index = range(self.length)[index]
        target_file = random.choice(self.file_index[self.id_list[target_id_index]])

        if random.random() < self.same_rate:
            # Same-identity pair: the source comes from the target's identity.
            source_id_index = target_id_index
        else:
            # Otherwise use the fixed source identity. The pair is still a
            # "same" pair when the target happens to be that identity.
            source_id_index = self.source_id_index

        source_file = random.choice(self.file_index[self.id_list[source_id_index]])
        same = torch.ones(1) if source_id_index == target_id_index else torch.zeros(1)
        return source_file, target_file, same

    def _load_rgb(self, path: Path):
        """Open ``path`` with PIL and return a fully loaded RGB image."""
        # The context manager closes the file handle deterministically;
        # convert() returns an independent, already-loaded copy.
        with Image.open(path.as_posix()) as img:
            return img.convert("RGB")

    def __getitem__(self, index):
        """Return one training item.

        The result is a dict with the keys ``source_image``, ``target_image``,
        ``target_mask`` and ``same``.
        """
        source_file, target_file, same = self._sample_pair(index)

        source_path = self.data_root / Path(source_file)
        target_path = self.data_root / Path(target_file)
        # The mask sits in a sibling tree of the image tree:
        # <grandparent>/<mask_dir>/<identity dir>/<file name>.
        # NOTE(review): ``.stem`` drops anything after a dot in the identity
        # directory name - confirm that directory names never contain dots.
        mask_path = target_path.parent.parent.parent / self.mask_dir / target_path.parent.stem / target_path.name

        target_img = self.transform(self._load_rgb(target_path))
        source_img = self.transform(self._load_rgb(source_path))
        # Keep only the first channel of the mask, with a leading channel axis.
        target_mask = self.transform(self._load_rgb(mask_path))[0, :, :].unsqueeze(0)

        return {
            "source_image": source_img,
            "target_image": target_img,
            "target_mask": target_mask,
            "same": same,
        }
|
|
|
|
class TrainDatasetDataLoader:
    """Wrapper around the training Dataset and its torch DataLoader."""

    def __init__(self):
        """Create the dataset selected by ``TrainConfig().mode`` and its DataLoader.

        Raises:
        -------
        NotImplementedError: if the configured mode is neither MANY_TO_MANY
            nor ONE_TO_MANY
        """
        opt = TrainConfig()
        if opt.mode is FaceSwapMode.MANY_TO_MANY:
            self.dataset = ManyToManyTrainDataset(opt.dataset_root, opt.dataset_index, opt.same_rate)
        elif opt.mode is FaceSwapMode.ONE_TO_MANY:
            logger.info(f"In one-to-many mode, source face is {opt.source_name}")
            self.dataset = OneToManyTrainDataset(opt.dataset_root, opt.dataset_index, opt.source_name, opt.same_rate)
        else:
            raise NotImplementedError
        logger.info(f"dataset {type(self.dataset).__name__} created")

        # Always define train_sampler so callers can test it for None instead
        # of hitting AttributeError when DDP is disabled.
        self.train_sampler = None
        if opt.use_ddp:
            self.train_sampler = torch.utils.data.distributed.DistributedSampler(self.dataset, shuffle=True)

        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size,
            # DataLoader rejects shuffle=True together with a sampler; under
            # DDP the shuffling is done by the DistributedSampler instead.
            shuffle=self.train_sampler is None,
            sampler=self.train_sampler,
            num_workers=int(opt.num_threads),
            drop_last=True,
            pin_memory=True,
        )

    def load_data(self):
        """Return this wrapper itself, which is iterable over batches."""
        return self

    def __len__(self):
        """Return the number of items in the dataset (not the number of batches)."""
        return len(self.dataset)

    def __iter__(self):
        """Yield batches from the underlying DataLoader."""
        yield from self.dataloader
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: build the loader from TrainConfig and print the
    # "same" labels of every batch.
    loader = TrainDatasetDataLoader()
    for batch in loader:
        print(batch["same"])
|
|