import json
import os.path

from PIL import Image
from torch.utils.data import DataLoader, Dataset

from transformers import CLIPProcessor
from torchvision import transforms

import pytorch_lightning as pl
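

# Map-style dataset over WikiArt images. Each item is a dict with:
#   'jpg'   - a 512x512 RGB image tensor in [0, 1],
#   'style' - the same image preprocessed by the CLIP image processor,
#   'txt'   - a caption string recovered from the file name.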
class WikiArtDataset(Dataset):
    def __init__(self, meta_file):
        super().__init__()

        # meta_file is a JSON list of image paths relative to datasets/wikiart.
        self.files = []
        with open(meta_file, 'r') as f:
            js = json.load(f)
        for img_path in js:
            # The caption is the part of the file name after the last '_',
            # split on '-', with trailing purely numeric tokens (e.g. a year)
            # stripped off the end.
            img_name = os.path.splitext(os.path.basename(img_path))[0]
            caption = img_name.split('_')[-1]
            caption = caption.split('-')
            j = len(caption) - 1
            while j >= 0:
                if not caption[j].isdigit():
                    break
                j -= 1
            if j < 0:
                # Nothing but digits: no usable caption, skip this image.
                continue
            sentence = ' '.join(caption[:j + 1])
            self.files.append({'img_path': os.path.join('datasets/wikiart', img_path),
                               'sentence': sentence})
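        # Hypothetical example of the parsing above:
        # 'claude-monet_water-lilies-1916.jpg' -> tokens ['water', 'lilies', '1916']
        # -> sentence 'water lilies'.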

        # CLIP image processor: resizes/normalizes to the model's input format.
        version = 'openai/clip-vit-large-patch14'
        self.processor = CLIPProcessor.from_pretrained(version)

        # Augmentation for the raw image branch: shortest side to 512, then a
        # random 512x512 crop.
        self.jpg_transform = transforms.Compose([
            transforms.Resize(512),
            transforms.RandomCrop(512),
            transforms.ToTensor(),
        ])

    def __getitem__(self, idx):
        file = self.files[idx]

        # Convert to RGB so grayscale/palette images batch consistently.
        im = Image.open(file['img_path']).convert('RGB')

        im_tensor = self.jpg_transform(im)

        # CLIP pixel values for the same image, shape (3, 224, 224).
        clip_im = self.processor(images=im, return_tensors="pt")['pixel_values'][0]

        return {'jpg': im_tensor, 'style': clip_im, 'txt': file['sentence']}

    def __len__(self):
        return len(self.files)
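

# Note: with PyTorch's default collate, 'jpg' and 'style' stack into
# (B, 3, 512, 512) and (B, 3, 224, 224) float tensors (224 being the
# CLIP ViT-L/14 input size), while 'txt' collates into a list of B strings.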
class WikiArtDataModule(pl.LightningDataModule):
    def __init__(self, meta_file, batch_size, num_workers):
        super().__init__()
        self.train_dataset = WikiArtDataset(meta_file)
        self.batch_size = batch_size
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers, pin_memory=True)
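

if __name__ == '__main__':
    # Minimal smoke test: a sketch assuming a (hypothetical) meta file at
    # datasets/wikiart/meta.json containing a JSON list of image paths.
    dm = WikiArtDataModule('datasets/wikiart/meta.json', batch_size=4, num_workers=0)
    batch = next(iter(dm.train_dataloader()))
    print(batch['jpg'].shape, batch['style'].shape, batch['txt'])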