import re
import contextlib
import functools
import numpy as np
import torch
import warnings
import dnnlib

#----------------------------------------------------------------------------
# Set the seed of all relevant pseudo-random number generators based on the
# given arguments.

def set_random_seed(*args):
    # Note: hash() of str/bytes values is salted per process, so passing
    # strings yields different seeds across runs unless PYTHONHASHSEED is set.
    seed = hash(args) % (1 << 31)
    torch.manual_seed(seed)
    np.random.seed(seed)
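
# Illustrative usage (not from the original file): derive a distinct but
# reproducible seed for each replica, assuming integer `seed` and `rank`:
#
#   set_random_seed(seed, rank)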

#----------------------------------------------------------------------------
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
# same constant is used multiple times.

_constant_cache = dict()

def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    value = np.asarray(value)
    if shape is not None:
        shape = tuple(shape)
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device('cpu')
    if memory_format is None:
        memory_format = torch.contiguous_format

    # Look up the tensor by value and configuration; construct it on a miss.
    key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
    tensor = _constant_cache.get(key, None)
    if tensor is None:
        tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
        if shape is not None:
            tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
        tensor = tensor.contiguous(memory_format=memory_format)
        _constant_cache[key] = tensor
    return tensor
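
# Illustrative usage (not from the original file): repeated calls with the
# same arguments return the same cached tensor object:
#
#   a = constant([0.0, 1.0], shape=[2], dtype=torch.float32)
#   b = constant([0.0, 1.0], shape=[2], dtype=torch.float32)
#   assert a is b  # second call hits the cache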

#----------------------------------------------------------------------------
# Variant of constant() that inherits dtype and device from the given
# reference tensor by default.

def const_like(ref, value, shape=None, dtype=None, device=None, memory_format=None):
    if dtype is None:
        dtype = ref.dtype
    if device is None:
        device = ref.device
    return constant(value, shape=shape, dtype=dtype, device=device, memory_format=memory_format)
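
# Illustrative usage (not from the original file; assumes a CUDA device):
#
#   x = torch.randn(8, 8, device='cuda', dtype=torch.float16)
#   half = const_like(x, 0.5)  # float16 scalar on the same device as x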

#----------------------------------------------------------------------------
# Cached construction of pinned CPU buffers for fast CPU<=>GPU transfers.

@functools.lru_cache(None)
def pinned_buf(shape, dtype):
    # Note: `shape` must be hashable (e.g., a tuple). The returned buffer is
    # cached and shared between all callers with the same shape and dtype.
    return torch.empty(shape, dtype=dtype).pin_memory()
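
# Illustrative usage (not from the original file; assumes a CUDA device):
# stage a device-to-host copy through pinned memory so it can be asynchronous.
#
#   gpu_tensor = torch.randn(16, 16, device='cuda')
#   buf = pinned_buf(tuple(gpu_tensor.shape), gpu_tensor.dtype)
#   buf.copy_(gpu_tensor, non_blocking=True)  # async copy into pinned memory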

#----------------------------------------------------------------------------
# Symbolic assert, compatible with both newer and older versions of PyTorch.

try:
    symbolic_assert = torch._assert  # PyTorch >= 1.8
except AttributeError:
    symbolic_assert = torch.Assert  # Fallback for older PyTorch versions.

#----------------------------------------------------------------------------
# Context manager to temporarily suppress known warnings in torch.jit.trace().
# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672

@contextlib.contextmanager
def suppress_tracer_warnings():
    flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
    warnings.filters.insert(0, flt)
    try:
        yield
    finally:
        # Remove the filter even if the body raises, so it cannot leak.
        warnings.filters.remove(flt)

#----------------------------------------------------------------------------
# Assert that the shape of a tensor matches the given list of integers.
# None indicates that the size of a dimension is allowed to vary.
# Performs symbolic assertion when used in torch.jit.trace().

def assert_shape(tensor, ref_shape):
    if tensor.ndim != len(ref_shape):
        raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
    for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
        if ref_size is None:
            pass
        elif isinstance(ref_size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
        elif isinstance(size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
        elif size != ref_size:
            raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
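
# Illustrative usage (not from the original file):
#
#   x = torch.zeros(4, 3, 32, 32)
#   assert_shape(x, [None, 3, 32, 32])  # batch size is allowed to vary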

#----------------------------------------------------------------------------
# Function decorator that calls torch.autograd.profiler.record_function().

def profiled_function(fn):
    @functools.wraps(fn)  # preserves __name__, __doc__, etc. of the wrapped function
    def decorator(*args, **kwargs):
        with torch.autograd.profiler.record_function(fn.__name__):
            return fn(*args, **kwargs)
    return decorator
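
# Illustrative usage (not from the original file; heavy_op is hypothetical):
#
#   @profiled_function
#   def heavy_op(x):
#       return x @ x.t()
#
# Calls then appear as 'heavy_op' ranges in profiler traces, e.g. when run
# under torch.profiler.profile().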

#----------------------------------------------------------------------------
# Sampler for torch.utils.data.DataLoader that loops over the dataset
# indefinitely, shuffling items as it goes.

class InfiniteSampler(torch.utils.data.Sampler):
    def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, start_idx=0):
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        # Silence deprecation warning emitted by newer versions of PyTorch.
        warnings.filterwarnings('ignore', '`data_source` argument is not used and will be removed')
        super().__init__(dataset)
        self.dataset_size = len(dataset)
        self.start_idx = start_idx + rank
        self.stride = num_replicas
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self):
        idx = self.start_idx
        epoch = None
        while True:
            if epoch != idx // self.dataset_size:
                # Reshuffle the index order deterministically at each epoch boundary.
                epoch = idx // self.dataset_size
                order = np.arange(self.dataset_size)
                if self.shuffle:
                    np.random.RandomState(hash((self.seed, epoch)) % (1 << 31)).shuffle(order)
            yield int(order[idx % self.dataset_size])
            idx += self.stride
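
# Illustrative usage (not from the original file; `dataset`, `rank`, and
# `world_size` are assumed to be defined by the caller):
#
#   sampler = InfiniteSampler(dataset, rank=rank, num_replicas=world_size)
#   loader = iter(torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=32))
#   batch = next(loader)  # never raises StopIteration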

#----------------------------------------------------------------------------
# Utilities for operating with torch.nn.Module parameters and buffers.

def params_and_buffers(module):
    assert isinstance(module, torch.nn.Module)
    return list(module.parameters()) + list(module.buffers())

def named_params_and_buffers(module):
    assert isinstance(module, torch.nn.Module)
    return list(module.named_parameters()) + list(module.named_buffers())

@torch.no_grad()
def copy_params_and_buffers(src_module, dst_module, require_all=False):
    assert isinstance(src_module, torch.nn.Module)
    assert isinstance(dst_module, torch.nn.Module)
    src_tensors = dict(named_params_and_buffers(src_module))
    for name, tensor in named_params_and_buffers(dst_module):
        assert (name in src_tensors) or (not require_all)
        if name in src_tensors:
            tensor.copy_(src_tensors[name])
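
# Illustrative usage (not from the original file): initialize an EMA copy of
# a model, assuming `net` is an existing torch.nn.Module:
#
#   import copy
#   ema = copy.deepcopy(net).eval().requires_grad_(False)
#   copy_params_and_buffers(src_module=net, dst_module=ema, require_all=True)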

#----------------------------------------------------------------------------
# Context manager for easily enabling/disabling DistributedDataParallel
# synchronization.

@contextlib.contextmanager
def ddp_sync(module, sync):
    assert isinstance(module, torch.nn.Module)
    if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
        yield
    else:
        with module.no_sync():
            yield
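
# Illustrative usage (not from the original file): skip gradient
# synchronization on all but the last accumulation round; `compute_loss` is
# a hypothetical helper:
#
#   for round_idx in range(num_accumulation_rounds):
#       with ddp_sync(ddp_module, sync=(round_idx == num_accumulation_rounds - 1)):
#           loss = compute_loss(ddp_module, batch)
#           loss.backward()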

#----------------------------------------------------------------------------
# Check DistributedDataParallel consistency across processes.

def check_ddp_consistency(module, ignore_regex=None):
    assert isinstance(module, torch.nn.Module)
    for name, tensor in named_params_and_buffers(module):
        fullname = type(module).__name__ + '.' + name
        if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
            continue
        tensor = tensor.detach()
        if tensor.is_floating_point():
            tensor = torch.nan_to_num(tensor)  # NaN != NaN, so map NaNs to comparable values.
        other = tensor.clone()
        torch.distributed.broadcast(tensor=other, src=0)
        assert (tensor == other).all(), fullname
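
# Illustrative usage (not from the original file; requires an initialized
# torch.distributed process group, and the regex is a hypothetical example):
#
#   check_ddp_consistency(net, ignore_regex=r'.*\.magnitude_ema')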

#----------------------------------------------------------------------------
# Print summary table of module hierarchy.

@torch.no_grad()
def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
    assert isinstance(module, torch.nn.Module)
    assert not isinstance(module, torch.jit.ScriptModule)
    assert isinstance(inputs, (tuple, list))

    # Register hooks.
    entries = []
    nesting = [0]
    def pre_hook(_mod, _inputs):
        nesting[0] += 1
    def post_hook(mod, _inputs, outputs):
        nesting[0] -= 1
        if nesting[0] <= max_nesting:
            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
            outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
            entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
    hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
    hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]

    # Run module.
    outputs = module(*inputs)
    for hook in hooks:
        hook.remove()

    # Identify unique outputs, parameters, and buffers.
    tensors_seen = set()
    for e in entries:
        e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
        e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
        e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
        tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}

    # Filter out redundant entries.
    if skip_redundant:
        entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]

    # Construct table.
    rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
    rows += [['---'] * len(rows[0])]
    param_total = 0
    buffer_total = 0
    submodule_names = {mod: name for name, mod in module.named_modules()}
    for e in entries:
        name = '<top-level>' if e.mod is module else submodule_names[e.mod]
        param_size = sum(t.numel() for t in e.unique_params)
        buffer_size = sum(t.numel() for t in e.unique_buffers)
        output_shapes = [str(list(t.shape)) for t in e.outputs]
        output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
        rows += [[
            name + (':0' if len(e.outputs) >= 2 else ''),
            str(param_size) if param_size else '-',
            str(buffer_size) if buffer_size else '-',
            (output_shapes + ['-'])[0],
            (output_dtypes + ['-'])[0],
        ]]
        for idx in range(1, len(e.outputs)):
            rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
        param_total += param_size
        buffer_total += buffer_size
    rows += [['---'] * len(rows[0])]
    rows += [['Total', str(param_total), str(buffer_total), '-', '-']]

    # Print table and return the module outputs.
    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
    print()
    for row in rows:
        print('  '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
    print()
    return outputs
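
# Illustrative usage (not from the original file):
#
#   net = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
#   print_module_summary(net, [torch.zeros(1, 3, 32, 32)])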

#----------------------------------------------------------------------------
# Tile a minibatch of images [N, C, H, W] into a single image grid
# [C, h*H, w*W]. Requires N == w * h.

def tile_images(x, w, h):
    assert x.ndim == 4 # [N, C, H, W]
    assert x.shape[0] == w * h
    return x.reshape(h, w, *x.shape[1:]).permute(2, 0, 3, 1, 4).reshape(x.shape[1], h * x.shape[2], w * x.shape[3])
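
# Illustrative usage (not from the original file):
#
#   batch = torch.rand(6, 3, 64, 64)
#   grid = tile_images(batch, w=3, h=2)  # => shape [3, 128, 192]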

#----------------------------------------------------------------------------