| import torch |
| from torch import nn |
| import numpy as np |
| import os, json |
| from tqdm import tqdm |
| from argparse import ArgumentParser |
| from typing import Dict |
|
|
| import datasets |
|
|
|
|
class SumPool2d(nn.Module):
    """2D sum pooling over spatial windows.

    Implemented as :class:`nn.AvgPool2d` with ``divisor_override=1``, which
    turns the windowed mean into a windowed sum.

    Args:
        kernel_size: side length of the square pooling window.
        stride: step between consecutive windows.
    """

    def __init__(self, kernel_size: int, stride: int):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        # divisor_override=1 makes the average pool divide by 1, i.e. sum.
        self.sum_pool = nn.AvgPool2d(kernel_size, stride, divisor_override=1)

    def forward(self, x):
        """Return the per-window sum of ``x``."""
        pooled = self.sum_pool(x)
        return pooled
|
|
|
|
| def _update_dict(d: Dict, keys: np.ndarray, values: np.ndarray) -> Dict: |
| keys = keys.tolist() if isinstance(keys, np.ndarray) else keys |
| values = values.tolist() if isinstance(values, np.ndarray) else values |
| for k, v in zip(keys, values): |
| d[k] = d.get(k, 0) + v |
|
|
| return d |
|
|
|
|
def _get_counts(
    dataset_name: str,
    device: torch.device,
) -> None:
    """Compute local-count histograms for a crowd-counting dataset.

    For each image's density map, counts how often each integer head count
    occurs inside sliding windows of several sizes (sum-pooled with stride 1),
    and tracks the image/location with the maximum windowed count per size.

    Results are written to ``./counts/{dataset_name}.json`` (histograms keyed
    by window size; size 1 is the raw per-pixel histogram) and
    ``./counts/{dataset_name}_max.json`` (per-size max record with image name
    and the top-left (x, y) of the first maximal window).

    Args:
        dataset_name: standardized dataset identifier understood by
            ``datasets.Crowd``.
        device: device on which the sum-pooling filters run.
    """
    # All sliding-window sizes; size 1 (raw pixels) is handled separately
    # via np.unique and has no max-location record, matching the output files.
    window_sizes = (4, 7, 8, 14, 16, 28, 32, 56, 64)
    filters = {size: SumPool2d(size, 1).to(device) for size in window_sizes}

    counts = {size: {} for size in (1,) + window_sizes}
    max_counts = {
        size: {"max": 0., "name": None, "x": None, "y": None}
        for size in window_sizes
    }

    counts_dir = os.path.join(os.getcwd(), "counts")
    os.makedirs(counts_dir, exist_ok=True)

    dataset = datasets.Crowd(dataset=dataset_name, split="train", transforms=None, return_filename=True)
    print(f"Counting {dataset_name} dataset")

    for i in tqdm(range(len(dataset))):
        _, _, density, img_name = dataset[i]

        # Size-1 histogram: raw per-pixel integer counts.
        uniques_, counts_ = np.unique(density.cpu().numpy().astype(int), return_counts=True)
        counts[1] = _update_dict(counts[1], uniques_, counts_)

        density = density.to(device)
        for size, sum_filter in filters.items():
            # Round before casting so e.g. 2.9999 from float pooling becomes 3.
            window = torch.squeeze(torch.round(sum_filter(density)).int())

            window_max = window.max().item()
            if window_max > max_counts[size]["max"]:
                # First location achieving the maximum (row x, column y).
                x, y = torch.where(window == window.max())
                max_counts[size]["max"] = window_max
                max_counts[size]["name"] = img_name
                max_counts[size]["x"] = x[0].item()
                max_counts[size]["y"] = y[0].item()

            flat = window.view(-1).cpu().numpy().astype(int)
            uniques_, counts_ = np.unique(flat, return_counts=True)
            counts[size] = _update_dict(counts[size], uniques_, counts_)

    with open(os.path.join(counts_dir, f"{dataset_name}.json"), "w") as f:
        json.dump(counts, f)

    with open(os.path.join(counts_dir, f"{dataset_name}_max.json"), "w") as f:
        json.dump(max_counts, f)
|
|
|
|
def parse_args():
    """Build and evaluate the command-line interface.

    Returns:
        argparse.Namespace with ``dataset`` (required, one of the supported
        dataset names) and ``device`` (default ``"cuda"``).
    """
    parser = ArgumentParser(description="Get local counts of the dataset")
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        choices=["nwpu", "ucf_qnrf", "shanghaitech_a", "shanghaitech_b"],
        help="The dataset to use.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="The device to use.",
    )
    return parser.parse_args()
|
|
|
|
if __name__ == "__main__":
    # Script entry point: normalize CLI inputs, then run the counting pass.
    cli_args = parse_args()
    cli_args.dataset = datasets.standardize_dataset_name(cli_args.dataset)
    cli_args.device = torch.device(cli_args.device)
    _get_counts(cli_args.dataset, cli_args.device)
|
|