import json
import os

import cv2
import numpy as np
from torch.utils.data import Dataset

from utils.config import *
|
|
class MyDataset(Dataset):
    """Paired (hint image, target image, caption) samples listed in a JSON index.

    The index is loaded with ``json.load`` and indexed by integer position;
    each entry is read for the keys ``'pose'`` (hint image path), ``'gt'``
    (target image path) and ``'caption'`` (text prompt). Image paths are
    resolved relative to the dataset root directory.
    """

    def __init__(self, root=None, json_name="train.json"):
        """Load the sample index from ``<root>/<json_name>``.

        Args:
            root: Directory holding the index file and the images. Defaults to
                ``dataset_root`` (expected to come from ``utils.config``).
            json_name: File name of the JSON index inside ``root``.
        """
        self.root = dataset_root if root is None else root
        json_path = os.path.join(self.root, json_name)
        # Explicit encoding: captions may contain non-ASCII text, and the
        # platform default encoding is not guaranteed to be UTF-8.
        with open(json_path, "rt", encoding="utf-8") as f:
            self.data = json.load(f)

    def __len__(self):
        """Return the number of entries in the index."""
        return len(self.data)

    def _read_rgb(self, filename):
        """Read ``filename`` (relative to the root) as an RGB image array.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        path = os.path.join(self.root, filename)
        image = cv2.imread(path)
        # cv2.imread reports failure by returning None instead of raising;
        # without this check cvtColor fails later with an opaque assertion.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        # OpenCV loads channels in BGR order; convert to RGB.
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict with keys ``jpg``, ``txt`` and ``hint``.

        ``hint`` is the ``'pose'`` image scaled to [0, 1] and ``jpg`` is the
        ``'gt'`` image scaled to [-1, 1] (both assuming 8-bit input), as float32
        RGB arrays. ``txt`` is the caption string.

        Raises:
            FileNotFoundError: if either image cannot be read.
        """
        item = self.data[idx]

        source_filename = item["pose"]
        target_filename = item["gt"]
        prompt = item["caption"]

        source = self._read_rgb(source_filename)
        target = self._read_rgb(target_filename)

        # Conditioning image normalised to [0, 1].
        source = source.astype(np.float32) / 255.0
        # Target image normalised to [-1, 1].
        target = (target.astype(np.float32) / 127.5) - 1.0

        return dict(jpg=target, txt=prompt, hint=source)
|
|