| |
| from __future__ import annotations |
| import torch |
|
|
def log_shape(tag: str, t: torch.Tensor) -> None:
    """Print a one-line summary of ``t``: shape, dtype, device and value range.

    Empty tensors report a ``nan`` range. This is best-effort diagnostics: any
    exception raised while building the summary is caught and printed in place
    of the summary rather than propagated.
    """
    try:
        if t.numel():
            lo, hi = float(t.min()), float(t.max())
        else:
            # min()/max() are undefined on an empty tensor.
            lo = hi = float("nan")
        print(
            f"[interop] {tag}: shape={tuple(t.shape)} dtype={t.dtype} device={t.device} "
            f"range=[{lo:.4f},{hi:.4f}]"
        )
    except Exception as err:
        print(f"[interop] {tag}: <log failed: {err!r}>")
|
|
def _to_float01(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` as float32 values clamped to [0, 1], without modifying ``x``.

    If the largest value exceeds 1.0 the data is assumed to be on a 0-255 scale
    and is divided by 255 before clamping. Note that the decision is based on
    the values, not the dtype: e.g. a uint8 tensor whose maximum is <= 1 is
    treated as already being in [0, 1].

    Empty tensors are returned as float32 unchanged (``max()`` would raise).
    """
    x = x.to(torch.float32)
    if x.numel() == 0:
        return x
    if x.max() > 1.0:
        x = x / 255.0
    # Out-of-place clamp: ``.to`` returns the *same* tensor when it is already
    # float32 on this device, so an in-place ``clamp_`` would overwrite the
    # caller's data.
    return x.clamp(0.0, 1.0)
|
|
def _squeeze_bt(x: torch.Tensor) -> torch.Tensor:
    """Drop a singleton leading axis from a 5D tensor.

    For a 5D input, axis 1 is removed if it has size 1; if the tensor is still
    5D afterwards and axis 0 has size 1, axis 0 is removed instead. At most one
    axis is dropped. Tensors of any other rank are returned unchanged.
    """
    if x.ndim == 5 and x.shape[1] == 1:
        x = x.squeeze(1)
    if x.ndim == 5 and x.shape[0] == 1:
        x = x.squeeze(0)
    # A previous 4D branch required shape[1] == 1 and shape[-3] == 3 at once;
    # for a 4D tensor those are the same axis, so it could never run.
    return x
|
|
def ensure_image_nchw(
    img: torch.Tensor,
    device: torch.device | str = "cuda",
    want_batched: bool = True,
) -> torch.Tensor:
    """Convert an image tensor to channels-first float layout on ``device``.

    The tensor is moved to ``device``, passed through ``_squeeze_bt``, and the
    result is converted with ``_to_float01``.

    * 3D input: a leading dim of 1 or 3 is taken as CHW; anything else is
      assumed to be HWC and permuted to CHW. The result is returned as
      ``(1, C, H, W)`` when ``want_batched`` is true, else as ``(C, H, W)``.
    * 4D input: axis 1 == 3 is taken as NCHW (checked first); otherwise
      axis 3 == 3 is taken as NHWC and permuted to NCHW.

    NOTE(review): ``want_batched`` is only honoured for 3D input; 4D input is
    always returned 4D, even when ``want_batched`` is false.

    Raises:
        AssertionError: if a 4D tensor has no size-3 channel axis at position
            1 or 3, or if the (squeezed) tensor is not 3D/4D.
    """
    img = _squeeze_bt(img.to(device))
    rank = img.ndim

    if rank == 3:
        channels_first = img.shape[0] in (1, 3)
        chw = img if channels_first else img.permute(2, 0, 1)
        chw = _to_float01(chw.contiguous())
        if not want_batched:
            return chw
        return chw.unsqueeze(0)

    if rank == 4:
        _, dim1, _, dim3 = img.shape
        if dim1 == 3:
            nchw = img
        elif dim3 == 3:
            nchw = img.permute(0, 3, 1, 2)
        else:
            raise AssertionError(f"Cannot infer channels in image: {tuple(img.shape)}")
        return _to_float01(nchw.contiguous())

    raise AssertionError(f"Image must be 3D/4D; got {tuple(img.shape)}")
|
|
def ensure_mask_for_matanyone(
    mask: torch.Tensor,
    *,
    idx_mask: bool = False,
    threshold: float = 0.5,
    keep_soft: bool = False,
    device: torch.device | str = "cuda",
) -> torch.Tensor:
    """Normalise a mask tensor into one of two output forms.

    ``mask`` is first moved to ``device`` and passed through ``_squeeze_bt``.
    The input tensor itself is never modified.

    * ``idx_mask=True``: returns a 2D ``torch.long`` tensor of 0/1 values.
      A 2D mask, or a 3D mask with a single leading channel, is binarised with
      ``>= threshold``. A 3D mask with several leading channels takes the
      per-pixel argmax over axis 0 and marks every pixel whose winning channel
      is not channel 0 as 1. ``keep_soft`` is ignored in this mode.
    * otherwise: returns a contiguous float32 tensor of shape ``(1, H, W)``
      (``H, W`` being the mask's last two dims). A 2D mask gains a leading
      axis; a 3D mask with several leading channels keeps only the channel with
      the largest sum (first one on ties). Values are binarised with
      ``>= threshold`` unless ``keep_soft`` is true, in which case they are
      only clamped to [0, 1].

    Raises:
        AssertionError: if the (squeezed) mask is not 2D or 3D.
    """
    mask = _squeeze_bt(mask.to(device))

    if idx_mask:
        if mask.ndim == 2:
            return (mask >= threshold).to(torch.long)
        if mask.ndim == 3:
            if mask.shape[0] == 1:
                return (mask[0] >= threshold).to(torch.long)
            winner = torch.argmax(mask, dim=0)  # already int64
            return (winner > 0).to(torch.long)
        raise AssertionError(f"idx mask must be 2D/3D; got {tuple(mask.shape)}")

    if mask.ndim == 2:
        out = mask.unsqueeze(0)
    elif mask.ndim == 3:
        if mask.shape[0] == 1:
            out = mask
        else:
            # Keep the channel with the largest total mass.
            best = int(mask.sum(dim=(-2, -1)).argmax())
            out = mask[best : best + 1]
    else:
        raise AssertionError(f"mask must be 2D/3D; got {tuple(mask.shape)}")

    out = out.to(torch.float32)
    if keep_soft:
        # Out-of-place clamp: when the input is already float32 on `device`,
        # `out` is a view of the caller's tensor and clamp_ would overwrite it.
        out = out.clamp(0.0, 1.0)
    else:
        # Comparison yields a fresh tensor whose values are exactly 0 or 1.
        out = (out >= threshold).to(torch.float32)
    return out.contiguous()
|
|