# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from collections.abc import Sequence
from typing import Any, List, Mapping, Optional, Type, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from sapiens.registry import MODELS
CastData = Union[torch.Tensor, Mapping, Sequence, str, bytes, None]
# -------------------------------------------------------------------------------
@MODELS.register_module()
class BasePreprocessor(nn.Module):
    def __init__(self, non_blocking: Optional[bool] = False):
        super().__init__()
        self._non_blocking = non_blocking
        self._device = torch.device("cpu")

    def is_seq_of(
        self,
        seq: Any,
        expected_type: Union[Type, tuple],
        seq_type: Optional[Type] = None,
    ) -> bool:
        """Return True if ``seq`` is a sequence whose items are all of ``expected_type``."""
        if seq_type is None:
            exp_seq_type = Sequence
        else:
            assert isinstance(seq_type, type)
            exp_seq_type = seq_type
        if not isinstance(seq, exp_seq_type):
            return False
        for item in seq:
            if not isinstance(item, expected_type):
                return False
        return True

    def stack_batch(
        self,
        tensor_list: List[torch.Tensor],
        pad_size_divisor: int = 1,
        pad_value: Union[int, float] = 0,
    ) -> torch.Tensor:
        """Pad tensors to a common shape divisible by ``pad_size_divisor`` and stack them."""
        if not tensor_list:
            raise ValueError("tensor_list cannot be empty")
        if len({t.ndim for t in tensor_list}) != 1:
            raise ValueError("All tensors must have the same number of dimensions")
        ndim = tensor_list[0].ndim
        shapes = torch.tensor([list(t.shape) for t in tensor_list])
        max_dims = (
            torch.ceil(torch.max(shapes, dim=0)[0] / pad_size_divisor)
            * pad_size_divisor
        )
        # Don't pad the channel dimension.
        pad_amounts = max_dims - shapes
        pad_amounts[:, 0] = 0
        if pad_amounts.sum() == 0:
            return torch.stack(tensor_list)
        # Build per-tensor padding tuples; F.pad expects dims in reverse order.
        padded = []
        for i, tensor in enumerate(tensor_list):
            pad_tuple = []
            for j in range(ndim - 1, -1, -1):
                pad_tuple.extend([0, int(pad_amounts[i, j])])
            padded.append(F.pad(tensor, pad_tuple, value=pad_value))
        return torch.stack(padded)

    def cast_data(self, data: CastData, device: torch.device) -> CastData:
        """Recursively move all tensors nested in ``data`` onto ``device``."""
        if isinstance(data, Mapping):
            return {key: self.cast_data(data[key], device) for key in data}
        elif isinstance(data, (str, bytes)) or data is None:
            return data
        elif isinstance(data, Sequence):
            return type(data)(self.cast_data(sample, device) for sample in data)
        elif isinstance(data, torch.Tensor):
            return data.to(device, non_blocking=self._non_blocking)
        else:
            return data

    def forward(self, data: dict, training: bool = False) -> Union[dict, list]:
        raise NotImplementedError(
            "The forward method must be implemented by a subclass."
        )
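

# -------------------------------------------------------------------------------
# Illustrative sketch (an assumption, not part of the original module): a minimal
# subclass showing how `forward` is typically implemented on top of `cast_data`
# and `stack_batch`, plus a small usage example. `SimpleImagePreprocessor` and
# the `{"inputs": [...]}` batch layout are hypothetical names chosen here.
class SimpleImagePreprocessor(BasePreprocessor):
    def forward(self, data: dict, training: bool = False) -> dict:
        # Move every tensor in the nested structure onto this module's device.
        data = self.cast_data(data, self._device)
        # Pad the per-sample images to a shared, divisible shape and stack them.
        data["inputs"] = self.stack_batch(data["inputs"], pad_size_divisor=32)
        return data


if __name__ == "__main__":
    preprocessor = SimpleImagePreprocessor()
    # Two CHW images of different spatial sizes; both are zero-padded to
    # (3, 32, 32), so the stacked batch has shape (2, 3, 32, 32).
    data = {"inputs": [torch.rand(3, 17, 23), torch.rand(3, 30, 14)]}
    out = preprocessor(data)
    print(out["inputs"].shape)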