somatosmpl / soma /geometry /_utils.py
zirobtc's picture
Upload folder using huggingface_hub
bd95c9c verified
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import torch
def require_torch_tensors(*tensors, name="inputs"):
    """Check that every argument is a torch.Tensor and that all agree on dtype and device.

    Args:
        *tensors: The objects to validate; at least one is required.
        name: Label used for the arguments in error messages.

    Returns:
        A ``(dtype, device)`` tuple holding the single dtype and device
        shared by all of the tensors.

    Raises:
        ValueError: If no tensors are passed.
        TypeError: If any argument is not a ``torch.Tensor``, or if the
            tensors disagree on dtype or on device.
    """
    if len(tensors) == 0:
        raise ValueError(f"{name} must not be empty.")

    for candidate in tensors:
        if not isinstance(candidate, torch.Tensor):
            raise TypeError(f"All {name} must be torch.Tensor.")

    # Collect the distinct dtypes/devices; exactly one of each is acceptable.
    seen_dtypes = set()
    seen_devices = set()
    for tensor in tensors:
        seen_dtypes.add(tensor.dtype)
        seen_devices.add(tensor.device)

    # The dtype check comes first, so it wins when both properties disagree.
    if len(seen_dtypes) > 1:
        raise TypeError(f"All {name} must share dtype; got {seen_dtypes}.")
    if len(seen_devices) > 1:
        raise TypeError(f"All {name} must be on the same device; got {seen_devices}.")

    (common_dtype,) = seen_dtypes
    (common_device,) = seen_devices
    return common_dtype, common_device
def one_hot_1d(L, idx, *, dtype, device):
    """Return a length-``L`` vector that is 1 at position ``idx`` and 0 elsewhere.

    Args:
        L: Length of the vector.
        idx: Position of the single 1. A plain Python ``int`` (negative
            values count from the end) takes the direct path. Any other
            index object (tensor, list, slice, ...) is applied to an
            ``L x L`` identity matrix, so the result is whatever rows that
            index selects.
        dtype: dtype of the returned tensor.
        device: Device on which the tensor is created.

    Returns:
        For an integer ``idx``, a freshly allocated tensor of shape ``(L,)``.

    Raises:
        IndexError: If an integer ``idx`` is out of range for length ``L``.
    """
    if isinstance(idx, int) and not isinstance(idx, bool):
        # Write the single 1 directly instead of materialising an L x L
        # identity matrix: O(L) memory rather than O(L^2), and the result is
        # not a view that keeps the whole matrix alive.
        vec = torch.zeros((L,), dtype=dtype, device=device)
        vec[idx] = 1
        return vec
    # Non-integer indices keep the identity-row selection semantics.
    return torch.eye(L, dtype=dtype, device=device)[idx]
def mask_1d(L, indices, *, dtype, device):
    """Return a length-``L`` vector with 1 at every position in ``indices``.

    For example, ``indices=[0, 2]`` with ``L=3`` selects the x/z components.

    Args:
        L: Length of the mask.
        indices: Iterable of positions to set (negative values count from
            the end). ``None`` or an empty iterable gives an all-zero mask.
            A position listed more than once is still set to exactly 1.
        dtype: dtype of the returned tensor; it is preserved for every
            dtype, including integer and bool types.
        device: Device on which the tensor is created.

    Returns:
        A tensor of shape ``(L,)`` with the requested dtype and device.

    Raises:
        IndexError: If any position is out of range for length ``L``.
    """
    mask = torch.zeros((L,), dtype=dtype, device=device)
    if indices is None:
        return mask
    # Assign in place rather than summing stacked one-hot rows: a sum would
    # promote integer/bool dtypes to int64 and count duplicate positions.
    for position in indices:
        mask[position] = 1
    return mask
def one_hot_2d(R, C, r, c, *, dtype, device):
    """Return an ``(R, C)`` matrix that is 1 at entry ``(r, c)`` and 0 elsewhere.

    The matrix is formed as the broadcast product of a length-``R`` one-hot
    column (from ``one_hot_1d``) and a length-``C`` one-hot row.

    Args:
        R: Number of rows.
        C: Number of columns.
        r: Row position of the single 1.
        c: Column position of the single 1.
        dtype: dtype passed to ``one_hot_1d`` for both factors.
        device: Device passed to ``one_hot_1d`` for both factors.
    """
    row_selector = one_hot_1d(R, r, dtype=dtype, device=device)
    col_selector = one_hot_1d(C, c, dtype=dtype, device=device)
    # (R, 1) * (1, C) broadcasts to the (R, C) outer product.
    return row_selector.unsqueeze(1) * col_selector.unsqueeze(0)