| import torch |
| import torchvision |
| import numpy as np |
| from PIL import Image |
| import os |
| from torchvision import transforms |
| import pandas as pd |
|
|
|
|
| from torch.utils.data import Dataset |
|
|
|
|
|
|
| tfms_normal = transforms.Compose([ |
| transforms.CenterCrop(size=(256,256)), |
| transforms.ToTensor() |
| |
| ]) |
|
|
| tfms_target = transforms.CenterCrop(size = (256,256)) |
| |
|
|
| class Data_provider_SYSU(Dataset): |
|
|
|
|
| def __init__(self, path): |
| self.data_path = path |
| self.pre_path = os.path.join(path, "time1") |
| self.post_path = os.path.join(path, "time2") |
| self.target_path = os.path.join(path, "label") |
|
|
| def __len__(self): |
| return len(os.listdir(self.post_path)) |
| |
| def __getitem__(self, idx): |
|
|
| post_list = os.listdir(self.pre_path) |
| pre_list = os.listdir(self.post_path) |
| target_list = os.listdir(self.target_path) |
|
|
| post_list.sort() |
| pre_list.sort() |
| target_list.sort() |
|
|
| pre_image_path = os.path.join(self.pre_path, pre_list[idx]) |
| post_image_path = os.path.join(self.post_path, post_list[idx]) |
| target_path = os.path.join(self.target_path, target_list[idx]) |
|
|
| pre_image = Image.open(pre_image_path) |
| post_image = Image.open(post_image_path) |
| target_image = Image.open(target_path) |
|
|
| pre_image = tfms_normal(pre_image) |
| post_image = tfms_normal(post_image) |
|
|
|
|
| target_image = torch.tensor(np.array(tfms_target(target_image))/255).long() |
|
|
|
|
|
|
| return pre_image, post_image, target_image |
| |
|
|
| class Data_provider_levir(Dataset): |
|
|
|
|
| def __init__(self, path): |
| self.data_path = path |
| self.pre_path = os.path.join(path, "A") |
| self.post_path = os.path.join(path, "B") |
| self.target_path = os.path.join(path, "label") |
|
|
| def __len__(self): |
| return len(os.listdir(self.post_path)) |
| |
| def __getitem__(self, idx): |
|
|
| pre_list = os.listdir(self.pre_path) |
| post_list = os.listdir(self.post_path) |
| target_list = os.listdir(self.target_path) |
|
|
| post_list.sort() |
| pre_list.sort() |
| target_list.sort() |
|
|
| pre_image_path = os.path.join(self.pre_path, pre_list[idx]) |
| post_image_path = os.path.join(self.post_path, post_list[idx]) |
| target_path = os.path.join(self.target_path, target_list[idx]) |
| |
| |
| |
|
|
| pre_image = Image.open(pre_image_path) |
| post_image = Image.open(post_image_path) |
| target_image = Image.open(target_path) |
|
|
| pre_image = tfms_normal(pre_image) |
| post_image = tfms_normal(post_image) |
|
|
|
|
| target_image = torch.tensor(np.array(target_image)/ 255).long() |
|
|
|
|
|
|
| return pre_image, post_image, target_image |
| |
| |
|
|
|
|
|
|
|
|
| class Data_provider_WHU(Dataset): |
|
|
|
|
| def __init__(self, path, file): |
| self.data_path = path |
| self.pre_path = os.path.join(path, "A") |
| self.post_path = os.path.join(path, "B") |
| self.target_path = os.path.join(path, "label") |
| self.data_names = np.array(pd.read_csv(file, names=["tt"])) |
|
|
|
|
| def __len__(self): |
| return len(self.data_names) |
| |
| def __getitem__(self, idx): |
|
|
| name = self.data_names[idx].item() |
|
|
| |
| |
| |
|
|
|
|
| |
| |
| |
|
|
|
|
|
|
| pre_image_path = os.path.join(self.pre_path,name ) |
| post_image_path = os.path.join(self.post_path,name ) |
| target_path = os.path.join(self.target_path, name) |
|
|
|
|
|
|
| |
|
|
| pre_image = Image.open(pre_image_path) |
| post_image = Image.open(post_image_path) |
| target_image = Image.open(target_path) |
|
|
|
|
| pre_image = tfms_normal(pre_image) |
| post_image = tfms_normal(post_image) |
|
|
| target_image = torch.tensor(np.array(tfms_target(target_image)) / 255).long() |
| |
|
|
|
|
|
|
| return pre_image, post_image, target_image |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| |
|
|
|
|
|
|
|
|
|
|