| |
| |
|
|
| import torch |
|
|
|
|
def require_torch_tensors(*tensors, name="inputs"):
    """Check that every positional argument is a tensor sharing one dtype and device.

    Args:
        *tensors: Values to validate; at least one must be given.
        name: Label interpolated into the error messages.

    Returns:
        A ``(dtype, device)`` tuple common to all of ``tensors``.

    Raises:
        ValueError: if no tensors were passed.
        TypeError: if any value is not a ``torch.Tensor``, or the tensors
            disagree on dtype, or they disagree on device.
    """
    if len(tensors) == 0:
        raise ValueError(f"{name} must not be empty.")
    if any(not isinstance(t, torch.Tensor) for t in tensors):
        raise TypeError(f"All {name} must be torch.Tensor.")

    seen_dtypes = set()
    seen_devices = set()
    for t in tensors:
        seen_dtypes.add(t.dtype)
        seen_devices.add(t.device)

    # A dtype disagreement is reported before a device disagreement.
    if len(seen_dtypes) > 1:
        raise TypeError(f"All {name} must share dtype; got {seen_dtypes}.")
    if len(seen_devices) > 1:
        raise TypeError(f"All {name} must be on the same device; got {seen_devices}.")

    # Both sets now hold exactly one element (tensors is non-empty).
    (dtype,) = seen_dtypes
    (device,) = seen_devices
    return dtype, device
|
|
|
|
def one_hot_1d(L, idx, *, dtype, device):
    """Return a length-``L`` vector that is 1 at position ``idx`` and 0 elsewhere.

    Args:
        L: Length of the output vector.
        idx: Position of the 1. Negative values count from the end.
        dtype: dtype of the result.
        device: device of the result.

    Returns:
        Tensor of shape ``(L,)`` for an integer ``idx``.

    Raises:
        IndexError: if ``idx`` is out of range for ``L``.
    """
    # Fast path for plain Python ints: fill one zero vector in O(L) instead of
    # materialising an (L, L) identity matrix just to read a single row.
    # bool is excluded because indexing with True/False adds a dimension; the
    # fallback below keeps that (and any other index type's) original semantics.
    if isinstance(idx, int) and not isinstance(idx, bool):
        out = torch.zeros((L,), dtype=dtype, device=device)
        out[idx] = 1
        return out
    return torch.eye(L, dtype=dtype, device=device)[idx]
|
|
|
|
def mask_1d(L, indices, *, dtype, device):
    """Return a length-``L`` mask that is 1 at every position in ``indices``.

    For example, ``indices=[0, 2]`` selects the x/z components of a length-3
    vector.

    Args:
        L: Length of the output vector.
        indices: Iterable of integer positions to set; may be empty. Negative
            positions count from the end. A repeated position still yields 1.
        dtype: dtype of the result (honoured for float, integer and bool).
        device: device of the result.

    Returns:
        Tensor of shape ``(L,)``.

    Raises:
        IndexError: if any position is out of range for ``L``.
    """
    out = torch.zeros((L,), dtype=dtype, device=device)
    # Assign in place rather than summing stacked one-hots. A sum would promote
    # integer/bool dtypes to int64, count duplicates twice, and allocate an
    # (L, L) identity per index. Per-element assignment keeps the bounds check
    # (and its IndexError) on the host.
    for pos in indices:
        out[pos] = 1
    return out
|
|
|
|
def one_hot_2d(R, C, r, c, *, dtype, device):
    """Return an ``(R, C)`` matrix that is 1 at ``(r, c)`` and 0 elsewhere.

    Args:
        R: Number of rows.
        C: Number of columns.
        r: Row of the 1. Negative values count from the end.
        c: Column of the 1. Negative values count from the end.
        dtype: dtype of the result.
        device: device of the result.

    Returns:
        Tensor of shape ``(R, C)``.

    Raises:
        IndexError: if ``r`` or ``c`` is out of range.
    """
    # Write the single nonzero entry directly. This avoids building two one-hot
    # rows (each sliced from a full identity matrix) and their outer product.
    out = torch.zeros((R, C), dtype=dtype, device=device)
    out[r, c] = 1
    return out
|
|