import torch
import numpy as np
import numpy.random as npr
import torch.distributed as dist
|
|
from ...log_service import print_log
from ... import sync
|
|
def singleton(class_):
    # cache a single instance per decorated class and reuse it on every call
    instances = {}
    def getinstance(*args, **kwargs):
        if class_ not in instances:
            instances[class_] = class_(*args, **kwargs)
        return instances[class_]
    return getinstance
|
|
@singleton
class get_sampler(object):
    def __init__(self):
        self.sampler = {}

    def register(self, sampler):
        self.sampler[sampler.__name__] = sampler

    def __call__(self, dataset, cfg):
        if cfg == 'default_train':
            return GlobalDistributedSampler(dataset, shuffle=True, extend=False)
        elif cfg == 'default_eval':
            return GlobalDistributedSampler(dataset, shuffle=False, extend=True)
        else:
            t = cfg.type
            return self.sampler[t](dataset=dataset, **cfg.args)
|
|
def register():
    def wrapper(class_):
        get_sampler().register(class_)
        return class_
    return wrapper
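
# How the registry is typically wired, shown as a hedged sketch (the cfg used
# in the non-default branch is assumed to be a config node exposing .type and
# .args, matching how get_sampler.__call__ reads it):
#
#     sampler = get_sampler()(dataset, 'default_train')
#
#     # or, for a registered sampler such as GroupSampler:
#     #     cfg.type = 'GroupSampler'
#     #     cfg.args = {'group_size': 3, 'mode': 'train'}
#     sampler = get_sampler()(dataset, cfg)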
|
|
|
|
@register()
class GlobalDistributedSampler(torch.utils.data.Sampler):
    """
    A distributed sampler whose index order is synchronized across all GPUs
    and nodes (the global process group).
    """
    def __init__(self,
                 dataset,
                 shuffle=True,
                 extend=False,):
        """
        Arguments:
            dataset: Dataset used for sampling.
            shuffle: If True, the sampler shuffles the indices.
            extend: If True, the sampler extends the indices so they can be
                evenly distributed across ranks; otherwise it truncates the
                indices to make them even.
        """
        self.ddp = sync.is_ddp()
        self.rank = sync.get_rank('global')
        self.world_size = sync.get_world_size('global')
        self.dataset = dataset
        self.shuffle = shuffle
        self.extend = extend
|
|
        num_samples = len(dataset) // self.world_size
        if extend and (len(dataset) % self.world_size != 0):
            num_samples += 1
        self.num_samples = num_samples
        self.total_size = num_samples * self.world_size
|
|
    def __iter__(self):
        indices = self.get_sync_order()
        if self.extend:
            # pad with the leading indices so the total divides evenly across ranks
            indices = indices + indices[0:self.total_size-len(indices)]
        else:
            # truncate so the total divides evenly across ranks
            indices = indices[0:self.total_size]
        # interleaved split: this rank takes every world_size-th index
        indices = indices[self.rank : len(indices) : self.world_size]
        return iter(indices)
|
|
    def __len__(self):
        return self.num_samples
|
|
    def get_sync_order(self):
        if self.shuffle:
            # rank 0's permutation is broadcast so every rank iterates in the
            # same order
            indices = torch.randperm(len(self.dataset))
            if self.ddp:
                indices = indices.to(self.rank)
                dist.broadcast(indices, src=0)
            indices = indices.to('cpu').tolist()
        else:
            indices = list(range(len(self.dataset)))
        print_log('Sampler : {}'.format(str(indices[0:5])))
        return indices
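
# Example wiring into a DataLoader, as a sketch (assumes the sync helpers are
# initialized and, for multi-process runs, torch.distributed as well):
#
#     sampler = GlobalDistributedSampler(dataset, shuffle=True, extend=False)
#     loader = torch.utils.data.DataLoader(
#         dataset, batch_size=8, sampler=sampler, num_workers=4)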
|
|
@register()
class LocalDistributedSampler(GlobalDistributedSampler):
    """
    A distributed sampler that is synchronized across the GPUs within a node,
    but not across nodes.
    """
    def __init__(self,
                 dataset,
                 shuffle=True,
                 extend=False,):
        super().__init__(dataset, shuffle, extend)
        self.rank = sync.get_rank('local')
        self.world_size = sync.get_world_size('local')
        # recompute the per-rank split with the local world size; the values
        # set by super().__init__() were based on the global world size
        num_samples = len(dataset) // self.world_size
        if extend and (len(dataset) % self.world_size != 0):
            num_samples += 1
        self.num_samples = num_samples
        self.total_size = num_samples * self.world_size
|
|
    def get_sync_order(self):
        if self.shuffle:
            # local rank 0 draws the permutation and shares it with the other
            # ranks on the same node through the node-wise broadcast helper
            if self.rank == 0:
                indices = list(npr.permutation(len(self.dataset)))
                sync.nodewise_sync().broadcast_r0(indices)
            else:
                indices = sync.nodewise_sync().broadcast_r0(None)
        else:
            indices = list(range(len(self.dataset)))
        print_log('Sampler : {}'.format(str(indices[0:5])))
        return indices
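
# Usage mirrors GlobalDistributedSampler (a sketch; the 'local' rank and world
# size are assumed to describe the GPUs of the current node). Each node draws
# its own permutation, so the order is only consistent within a node:
#
#     sampler = LocalDistributedSampler(dataset, shuffle=True)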
|
|
|
|
@register()
class GroupSampler(torch.utils.data.Sampler):
    """
    A distributed sampler that assigns indices to processes group by group.
    i.e.
    if group_size=3, num_replicas=2, train mode:
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
        ==> (group)          [0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10]
        ==> (distribute)     process0: [3, 4, 5], (leftover [6, 7, 8, 9, 10])
                             process1: [0, 1, 2]
        ==> (group leftover) process0: [3, 4, 5], (leftover [6, 7], [8, 9], 10)
                             process1: [0, 1, 2]
        ==> (distribute)     process0: [3, 4, 5], [6, 7] (remove 10)
                             process1: [0, 1, 2], [8, 9]

    It also avoids a batch size of 1:
        0, 1, 2, 3, 4, 5, 6, 7, 8,
        ==> (group)          [0, 1, 2], [3, 4, 5], [6, 7, 8]
        ==> (distribute)     process0: [3, 4, 5], (leftover [6, 7, 8])
                             process1: [0, 1, 2]
        ==> (group leftover) process0: [3, 4, 5], (leftover [6], [7], [8])
                             process1: [0, 1, 2]
        ==> (distribute)     process0: [3, 4, 5] (remove 6, 7, 8, since distributing
                             process1: [0, 1, 2]  them would give a batch size of 1)

    if group_size=3, num_replicas=2, eval mode:
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
        ==> (extend)         0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10
        ==> (group)          [0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 10]
        ==> (distribute)     process0: [0, 1, 2], [6, 7, 8]
                             process1: [3, 4, 5], [9, 10, 10]
    """
|
|
    def __init__(self,
                 dataset,
                 group_size,
                 num_replicas=None,
                 rank=None,
                 mode='train',):
        if num_replicas is None:
            if not dist.is_available():
                raise ValueError('torch.distributed is required to infer num_replicas')
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise ValueError('torch.distributed is required to infer rank')
            rank = dist.get_rank()

        self.dataset = dataset
        self.len_dataset = len(dataset)
        self.group_size = group_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.mode = mode
        len_dataset = self.len_dataset
|
|
        if (len_dataset % num_replicas != 0) and (mode == 'train'):
            # train: truncate so the length divides evenly across replicas
            aligned_indices = np.arange(len_dataset)[:-(len_dataset % num_replicas)]
            aligned_len_dataset = aligned_indices.shape[0]
        elif (len_dataset % num_replicas != 0) and (mode == 'eval'):
            # eval: pad by repeating the last index so no sample is dropped
            extend = np.array([len_dataset-1 for _ in range(num_replicas - len_dataset % num_replicas)])
            aligned_indices = np.concatenate([range(len_dataset), extend])
            aligned_len_dataset = aligned_indices.shape[0]
        else:
            aligned_indices = np.arange(len_dataset)
            aligned_len_dataset = len_dataset
|
|
        num_even_distributed_groups = aligned_len_dataset // (group_size * num_replicas)
        num_even = num_even_distributed_groups * group_size * num_replicas

        self.regular_groups = aligned_indices[0:num_even].reshape(-1, group_size)
        self.leftover_groups = aligned_indices[num_even:].reshape(num_replicas, -1)
|
|
        if self.leftover_groups.size == 0:
            self.leftover_groups = None
        elif (self.leftover_groups.shape[-1] == 1) and (mode == 'train'):
            # drop the leftover in train mode if it would yield batches of size 1
            self.leftover_groups = None

        # number of indices each replica receives from __iter__
        self.num_samples = num_even // num_replicas
        if self.leftover_groups is not None:
            self.num_samples += self.leftover_groups.shape[-1]
|
|
        # every item in a group shares the same reference size, taken from the
        # item at the group's middle index
        for groupi in self.regular_groups:
            idx_lowerbd = groupi[0]
            idx_upperbd = groupi[-1]
            idx_reference = (idx_lowerbd + idx_upperbd) // 2
            for idx in groupi:
                dataset.load_info[idx]['ref_size'] = dataset.load_info[idx_reference]['image_size']
        if self.leftover_groups is not None:
            for groupi in self.leftover_groups:
                idx_lowerbd = groupi[0]
                idx_upperbd = groupi[-1]
                idx_reference = (idx_lowerbd + idx_upperbd) // 2
                for idx in groupi:
                    dataset.load_info[idx]['ref_size'] = dataset.load_info[idx_reference]['image_size']
|
|
    def concat(self, nparrays, axis=0):
        # drop empty arrays before concatenation to avoid shape errors
        nparrays = [i for i in nparrays if i.size > 0]
        return np.concatenate(nparrays, axis=axis)
|
|
    def __iter__(self):
        indices = self.get_sync_order()
        return iter(indices)
|
|
    def __len__(self):
        return self.num_samples
|
|
    def get_sync_order(self):
        # Build this rank's index order. Regular groups are shuffled with a
        # permutation broadcast from rank 0 (train) or split deterministically
        # (eval); leftover groups, if any, are appended afterwards.
        mode = self.mode
        rank = self.rank
        num_replicas = self.num_replicas
        group_size = self.group_size
        num_groups = len(self.regular_groups)

        if mode == 'train':
            g_indices = torch.randperm(num_groups).to(rank)
            dist.broadcast(g_indices, src=0)
            g_indices = g_indices.to('cpu').tolist()
            num_groups_per_rank = num_groups // num_replicas
            groups = self.regular_groups[g_indices][num_groups_per_rank*rank : num_groups_per_rank*(rank+1)]
            indices = groups.flatten()
|
|
            if self.leftover_groups is not None:
                leftg_indices = torch.randperm(len(self.leftover_groups)).to(rank)
                dist.broadcast(leftg_indices, src=0)
                leftg_indices = leftg_indices.to('cpu').tolist()
                last = self.leftover_groups[leftg_indices][rank]
                indices = np.concatenate([indices, last], axis=0)
        elif mode == 'eval':
            groups = self.regular_groups.reshape(-1, num_replicas, group_size)[:, rank, :]
            indices = groups.flatten()
            if self.leftover_groups is not None:
                last = self.leftover_groups[rank]
                indices = np.concatenate([indices, last], axis=0)
        else:
            raise ValueError("mode must be either 'train' or 'eval'")

        print_log('Sampler RANK {} : {}'.format(rank, str(indices[0:group_size+1])))
        return indices
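

# A minimal, self-contained sketch (hypothetical; not used anywhere in this
# module) that replays the eval-mode grouping arithmetic from the GroupSampler
# docstring with plain numpy, so the layout can be checked without setting up
# torch.distributed.
def _demo_group_layout(num_samples=11, group_size=3, num_replicas=2):
    """
    Return the per-rank regular-group layout for the eval-mode example.

    >>> _demo_group_layout(11, 3, 2)
    [[[0, 1, 2], [6, 7, 8]], [[3, 4, 5], [9, 10, 10]]]
    """
    # pad by repeating the last index so the length divides by num_replicas
    pad = (-num_samples) % num_replicas
    aligned = np.concatenate([np.arange(num_samples),
                              np.full(pad, num_samples - 1, dtype=int)])
    # keep only the part that splits evenly into groups across all replicas
    num_even = (aligned.shape[0] // (group_size * num_replicas)) * group_size * num_replicas
    regular = aligned[:num_even].reshape(-1, num_replicas, group_size)
    return [regular[:, rank, :].tolist() for rank in range(num_replicas)]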
|
|