| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from functools import partial |
| from typing import List |
|
|
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
def inflate_array_like(array, target):
    """Append trailing singleton axes to ``array`` so it broadcasts with ``target``.

    Args:
        array: either a plain Python number (``int`` or ``float``), which is
            returned untouched, or an array/tensor of shape ``(B,)`` or
            ``(B, 1, ..., 1)`` whose leading axes line up with ``target``.
        target: array/tensor of shape ``(B, ...)``.

    Returns:
        ``array`` itself when it is a Python number or already has as many
        axes as ``target``; otherwise ``array`` indexed with
        ``target.ndim - array.ndim`` extra trailing ``None`` axes.

    Raises:
        AssertionError: if ``array`` has more axes than ``target``, or one of
            its axes is neither equal to the matching ``target`` axis nor 1.
    """
    # Plain Python numbers have no ``ndim`` and broadcast on their own.
    if isinstance(array, (int, float)):
        return array

    diff_dims = target.ndim - array.ndim
    assert diff_dims >= 0, f'Error: target.ndim {target.ndim} < array.ndim {array.ndim}'
    if diff_dims == 0:
        return array

    # An axis is compatible when the sizes agree or ``array`` holds a
    # singleton there (the documented ``(B, 1, ..., 1)`` input form).
    # ``zip`` stops after ``array.ndim`` axes because ``array`` is shorter.
    leading_ok = all(
        a_size == t_size or a_size == 1
        for a_size, t_size in zip(array.shape, target.shape)
    )
    assert leading_ok, f'Error: target.shape[:array.ndim] {target.shape[:array.ndim]} != array.shape[:array.ndim] {array.shape[:array.ndim]}'
    return array[(...,) + (None,) * diff_dims]
|
|
|
|
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
    """Permute the last ``len(inds)`` axes of ``tensor``; leading axes stay put.

    Args:
        tensor: tensor whose trailing axes are reordered.
        inds: new order of the trailing axes, each index relative to the
            first of those trailing axes.

    Returns:
        The result of ``tensor.permute`` with the leading axes in their
        original order followed by the reordered trailing axes.
    """
    offset = -len(inds)
    # Axes before the trailing group keep their positions.
    leading = [*range(len(tensor.shape[:offset]))]
    # Trailing axes are addressed with negative indices.
    trailing = [offset + i for i in inds]
    return tensor.permute(leading + trailing)
|
|
def flatten_final_dims(t: torch.Tensor, no_dims: int):
    """Merge the last ``no_dims`` axes of ``t`` into one trailing axis.

    Args:
        t: tensor to reshape.
        no_dims: number of trailing axes to collapse.

    Returns:
        ``t`` reshaped to ``t.shape[:-no_dims]`` followed by a single
        inferred axis.
    """
    kept_shape = t.shape[:-no_dims]
    return t.reshape(*kept_shape, -1)
|
|
def sum_except_batch(t: torch.Tensor, batch_dims: int=1):
    """Sum ``t`` over every axis after the first ``batch_dims`` axes.

    Args:
        t: tensor to reduce.
        batch_dims: number of leading axes to keep.

    Returns:
        Tensor of shape ``t.shape[:batch_dims]`` holding the per-entry sums.
    """
    # Collapse all non-batch axes into one, then reduce that axis.
    flattened = t.reshape(*t.shape[:batch_dims], -1)
    return torch.sum(flattened, dim=-1)
| |
def masked_mean(mask, value, dim, eps=1e-4):
    """Average ``value`` along ``dim``, weighting each entry by ``mask``.

    Args:
        mask: tensor expandable to ``value.shape``.
        value: tensor to average.
        dim: axis (or axes) passed to ``torch.sum``.
        eps: constant added to the denominator so an all-zero mask does not
            divide by zero.

    Returns:
        ``sum(mask * value) / (eps + sum(mask))`` taken along ``dim``.
    """
    weights = mask.expand(*value.shape)
    numerator = (weights * value).sum(dim=dim)
    denominator = eps + weights.sum(dim=dim)
    return numerator / denominator
|
|
|
|
def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64):
    """Bin all pairwise Euclidean distances between points.

    Args:
        pts: tensor of shape ``(..., N, D)``; distances are taken between
            every pair of the ``N`` points over the last axis.
        min_bin: value of the first bin boundary.
        max_bin: value of the last bin boundary.
        no_bins: number of bins; ``no_bins - 1`` evenly spaced boundaries are
            created between ``min_bin`` and ``max_bin``.

    Returns:
        Integer tensor of shape ``(..., N, N)`` with the bin index of each
        pairwise distance, as returned by ``torch.bucketize``.
    """
    deltas = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    dists = torch.sqrt(torch.sum(deltas ** 2, dim=-1))
    # torch.bucketize requires boundaries and input to share a dtype, so the
    # boundaries follow the distances instead of the global default dtype
    # (otherwise e.g. float64 points are rejected).
    boundaries = torch.linspace(
        min_bin, max_bin, no_bins - 1, device=pts.device, dtype=dists.dtype
    )
    return torch.bucketize(dists, boundaries)
|
|
|
|
def dict_multimap(fn, dicts):
    """Combine a list of same-structured dicts into one by applying ``fn`` per key.

    The keys of ``dicts[0]`` define the structure. For every key the values
    from all dicts are gathered into a list; if the first dict's value is
    exactly a ``dict`` the function recurses, otherwise ``fn`` is applied to
    the gathered list.

    Args:
        fn: callable taking a list of values (one per input dict).
        dicts: non-empty sequence of dicts sharing the keys of the first one.

    Returns:
        A new dict with the same (nested) keys as ``dicts[0]``.
    """
    template = dicts[0]
    merged = {}
    for key, sample in template.items():
        gathered = [d[key] for d in dicts]
        merged[key] = (
            dict_multimap(fn, gathered) if type(sample) is dict else fn(gathered)
        )
    return merged
|
|
|
|
def one_hot(x, v_bins):
    """One-hot encode each entry of ``x`` by its nearest value in ``v_bins``.

    Args:
        x: tensor of values to encode.
        v_bins: 1-D tensor of bin values.

    Returns:
        Float tensor of shape ``x.shape + (len(v_bins),)`` with a 1 at the
        index of the bin value closest to each entry of ``x``.
    """
    n_bins = len(v_bins)
    # Give the bins leading singleton axes so they broadcast against x[..., None].
    broadcast_shape = (1,) * len(x.shape) + (n_bins,)
    distances = torch.abs(x[..., None] - v_bins.view(broadcast_shape))
    nearest = torch.argmin(distances, dim=-1)
    return nn.functional.one_hot(nearest, num_classes=n_bins).float()
|
|
|
|
def batched_gather(data, inds, dim=0, no_batch_dims=0):
    """Index ``data`` with ``inds`` along ``dim``, independently per batch entry.

    Args:
        data: tensor to gather from; its first ``no_batch_dims`` axes are
            treated as batch axes.
        inds: integer index tensor used for axis ``dim`` of ``data``. It must
            broadcast against the per-batch index ranges built here, i.e. its
            leading axes correspond to the batch axes.
        dim: axis of ``data`` to index; negative values count from the end.
        no_batch_dims: number of leading batch axes shared by ``data`` and
            ``inds``.

    Returns:
        The result of advanced-indexing ``data`` with one ``arange`` per
        batch axis, ``inds`` on axis ``dim`` and full slices elsewhere.
    """
    ranges = []
    for i, s in enumerate(data.shape[:no_batch_dims]):
        # Built on data's device to avoid an implicit host-to-device copy
        # when indexing.
        r = torch.arange(s, device=data.device)
        # Shape (1, ..., s, ..., 1) so that it broadcasts against inds.
        r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
        ranges.append(r)

    remaining_dims = [
        slice(None) for _ in range(len(data.shape) - no_batch_dims)
    ]
    remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
    ranges.extend(remaining_dims)
    # Index with a tuple: a list of tensors/slices as a multidimensional
    # index is deprecated and is interpreted differently by newer PyTorch.
    return data[tuple(ranges)]
|
|
|
|
| |
def dict_map(fn, dic, leaf_type):
    """Apply ``fn`` to the leaves of a (possibly nested) dict.

    Values that are exactly ``dict`` are handled by recursing into
    ``dict_map``; every other value is delegated to ``tree_map``.

    Args:
        fn: callable applied to leaves.
        dic: dict to transform.
        leaf_type: type (or tuple of types) forwarded to ``tree_map`` to
            identify leaves.

    Returns:
        A new dict with the same keys and transformed values.
    """
    return {
        key: (
            dict_map(fn, val, leaf_type)
            if type(val) is dict
            else tree_map(fn, val, leaf_type)
        )
        for key, val in dic.items()
    }
|
|
|
|
def tree_map(fn, tree, leaf_type):
    """Apply ``fn`` to every leaf of a nested dict/list/tuple structure.

    Args:
        fn: callable applied to each leaf.
        tree: a dict, list, tuple, or an instance of ``leaf_type``.
        leaf_type: type (or tuple of types) identifying leaves.

    Returns:
        A structure of the same container types with ``fn`` applied to the
        leaves.

    Raises:
        ValueError: if ``tree`` is neither a supported container nor an
            instance of ``leaf_type``; its type is printed first.
    """
    # Containers are checked before leaf_type, in this order.
    if isinstance(tree, dict):
        return dict_map(fn, tree, leaf_type)
    if isinstance(tree, list):
        return [tree_map(fn, item, leaf_type) for item in tree]
    if isinstance(tree, tuple):
        return tuple(tree_map(fn, item, leaf_type) for item in tree)
    if isinstance(tree, leaf_type):
        return fn(tree)

    print(type(tree))
    raise ValueError("Not supported")
|
|
|
|
# tree_map with the leaf type fixed to torch.Tensor: call as
# tensor_tree_map(fn, tree) to apply fn to every tensor in a nested structure.
tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
|
|
| def _fetch_dims(tree): |
| shapes = [] |
| tree_type = type(tree) |
| if tree_type is dict: |
| for v in tree.values(): |
| shapes.extend(_fetch_dims(v)) |
| elif tree_type is list or tree_type is tuple: |
| for t in tree: |
| shapes.extend(_fetch_dims(t)) |
| elif tree_type is torch.Tensor: |
| shapes.append(tree.shape) |
| else: |
| raise ValueError("Not supported") |
|
|
| return shapes |
|
|