import os
import random
import torch
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
import cv2
import glob
import scipy.io as io
import re
class SHHA(Dataset):
    """ShanghaiTech part-A crowd-counting dataset.

    Yields ``(image, target)`` pairs where ``target`` is a list of dicts,
    one per cropped patch during patch-training (or a single dict for
    evaluation), each holding the head point coordinates (``'point'``),
    a numeric ``'image_id'`` and per-point ``'labels'``.
    """

    def __init__(
        self,
        data_root,
        transform=None,
        train=False,
        patch=False,
        flip=False,
        train_list="shanghai_tech_part_a_train.list",
        eval_list="shanghai_tech_part_a_test.list",
    ):
        """Index the image / ground-truth pairs named in the split file(s).

        data_root:  directory containing the ``.list`` files and (relative) data
        transform:  optional torchvision-style transform applied to each image
        train:      select the training split and enable augmentation
        patch:      crop 4 random 128x128 patches per training sample
        flip:       randomly mirror training patches horizontally
        train_list / eval_list: comma-separated list-file names for each split
        """
        self.root_path = data_root
        self.train_lists = train_list
        self.eval_list = eval_list
        # there may exist multiple comma-separated list files per split
        if train:
            self.img_list_file = self.train_lists.split(',')
        else:
            self.img_list_file = self.eval_list.split(',')
        self.img_map = {}
        self.img_list = []
        # build the image-path -> ground-truth-path mapping
        for list_name in self.img_list_file:
            list_name = list_name.strip()
            with open(os.path.join(self.root_path, list_name)) as fin:
                for line in fin:
                    if len(line) < 2:
                        continue
                    line = line.strip()
                    # entries are either "img\tgt" or whitespace separated
                    if "\t" in line:
                        img_path, gt_path = line.split("\t", 1)
                    else:
                        parts = line.split()
                        if len(parts) < 2:
                            continue
                        img_path = parts[0]
                        gt_path = parts[1]
                    img_path = img_path.strip()
                    gt_path = gt_path.strip()
                    # paths in the list file may be relative to data_root
                    if not os.path.isabs(img_path):
                        img_path = os.path.join(self.root_path, img_path)
                    if not os.path.isabs(gt_path):
                        gt_path = os.path.join(self.root_path, gt_path)
                    self.img_map[img_path] = gt_path
        self.img_list = sorted(self.img_map.keys())
        # number of samples
        self.nSamples = len(self.img_list)
        self.transform = transform
        self.train = train
        self.patch = patch
        self.flip = flip

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        # was `index <= len(self)`, which wrongly accepted index == len(self)
        assert index < len(self), 'index range error'
        img_path = self.img_list[index]
        gt_path = self.img_map[img_path]
        # load image and ground-truth head points
        img, point = load_data((img_path, gt_path), self.train)
        # apply preprocessing / augmentation transform
        if self.transform is not None:
            img = self.transform(img)
        if self.train:
            # data augmentation -> random scale
            scale_range = [0.7, 1.3]
            min_size = min(img.shape[1:])
            scale = random.uniform(*scale_range)
            # only rescale when the result stays larger than one 128px crop
            if scale * min_size > 128:
                # interpolate(..., align_corners=True) is the modern
                # equivalent of the deprecated upsample_bilinear
                img = torch.nn.functional.interpolate(
                    img.unsqueeze(0), scale_factor=scale,
                    mode='bilinear', align_corners=True).squeeze(0)
                point *= scale
        # random crop augmentation
        if self.train and self.patch:
            img, point = random_crop(img, point)
            for i, _ in enumerate(point):
                point[i] = torch.Tensor(point[i])
        # random horizontal flipping (only meaningful after patch cropping,
        # where img is a 4-D stack of patches)
        if random.random() > 0.5 and self.train and self.flip:
            img = torch.Tensor(img[:, :, :, ::-1].copy())
            width = img.shape[-1]  # patch width (128), no longer hard-coded
            for i, _ in enumerate(point):
                point[i][:, 0] = width - point[i][:, 0]
        if not self.train:
            point = [point]
        img = torch.Tensor(img)
        # pack up related infos
        target = [{} for i in range(len(point))]
        for i, _ in enumerate(point):
            target[i]['point'] = torch.Tensor(point[i])
            # derive a numeric image id from the trailing digits of the file
            # name (e.g. "IMG_42.jpg" -> 42); fall back to the dataset index
            image_stem = os.path.splitext(os.path.basename(img_path))[0]
            digits = re.findall(r'\d+', image_stem)
            image_id = int(digits[-1]) if digits else index
            image_id = torch.Tensor([image_id]).long()
            target[i]['image_id'] = image_id
            target[i]['labels'] = torch.ones([point[i].shape[0]]).long()
        return img, target
def load_data(img_gt_path, train):
    """Load one image (as an RGB ``PIL.Image``) and its head annotations.

    img_gt_path: ``(image_path, gt_path)`` pair; the ground-truth file holds
                 one ``x y`` coordinate pair per line.
    train:       unused, kept for interface compatibility with callers.

    Returns ``(img, points)`` where ``points`` is an ``(N, 2)`` float array.
    """
    img_path, gt_path = img_gt_path
    # read with OpenCV (BGR) and convert to an RGB PIL image
    img = cv2.imread(img_path)
    img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    # load ground truth points; split() instead of split(' ') so repeated
    # spaces or tabs don't break parsing, and each line is parsed only once
    points = []
    with open(gt_path) as f_label:
        for line in f_label:
            parts = line.split()
            if len(parts) < 2:
                continue  # skip blank or malformed lines instead of crashing
            points.append([float(parts[0]), float(parts[1])])
    return img, np.array(points)
# random crop augmentation
def random_crop(img, den, num_patch=4, patch_h=128, patch_w=128):
    """Crop ``num_patch`` random patches from an image with point labels.

    img:      CHW tensor (must be at least ``patch_h`` x ``patch_w``)
    den:      (N, 2) array of ``x, y`` head coordinates
    num_patch: number of random patches to extract
    patch_h / patch_w: patch size (previously hard-coded to 128)

    Returns ``(patches, points)``: a ``(num_patch, C, patch_h, patch_w)``
    array and a list of per-patch coordinate arrays shifted into patch space.
    """
    result_img = np.zeros([num_patch, img.shape[0], patch_h, patch_w])
    result_den = []
    # crop num_patch windows for each image
    for i in range(num_patch):
        start_h = random.randint(0, img.size(1) - patch_h)
        start_w = random.randint(0, img.size(2) - patch_w)
        end_h = start_h + patch_h
        end_w = start_w + patch_w
        # copy the cropped rect
        result_img[i] = img[:, start_h:end_h, start_w:end_w]
        # keep only the points falling inside this window
        idx = (den[:, 0] >= start_w) & (den[:, 0] <= end_w) & (den[:, 1] >= start_h) & (den[:, 1] <= end_h)
        # shift the coordinates into the patch frame (fancy indexing copies,
        # so the original annotations are untouched)
        record_den = den[idx]
        record_den[:, 0] -= start_w
        record_den[:, 1] -= start_h
        result_den.append(record_den)
    return result_img, result_den