| import functools
|
|
|
| import jax
|
| import jax.numpy as jnp
|
| import torch
|
| import torch.nn.functional as F
|
|
|
| import openpi.shared.array_typing as at
|
|
|
|
|
@functools.partial(jax.jit, static_argnums=(1, 2, 3))
@at.typecheck
def resize_with_pad(
    images: at.UInt8[at.Array, "*b h w c"] | at.Float[at.Array, "*b h w c"],
    height: int,
    width: int,
    method: jax.image.ResizeMethod = jax.image.ResizeMethod.LINEAR,
) -> at.UInt8[at.Array, "*b {height} {width} c"] | at.Float[at.Array, "*b {height} {width} c"]:
    """Replicates tf.image.resize_with_pad: aspect-preserving resize followed by centered padding.

    The image is scaled by a single factor so that it fits inside ``height`` x ``width``, then
    the leftover border is filled with black (0 for uint8 inputs, -1.0 for float32 inputs).
    If the image is float32, it must be in the range [-1, 1].

    Args:
        images: A single image ``[h, w, c]`` or a batch ``[b, h, w, c]``, uint8 or float32.
        height: Target height (static under jit).
        width: Target width (static under jit).
        method: Interpolation method forwarded to ``jax.image.resize`` (static under jit).

    Returns:
        The resized and padded image(s), with the same rank and dtype as the input.

    Raises:
        ValueError: If the image dtype is neither uint8 nor float32.
    """
    # Anything that is not already 4-D gets a temporary leading batch axis.
    unbatched = images.ndim != 4
    batched = images[None] if unbatched else images

    src_h, src_w = batched.shape[1:3]
    # One scale factor for both axes keeps the aspect ratio; the larger ratio is the binding one.
    scale = max(src_w / width, src_h / height)
    target_h = int(src_h / scale)
    target_w = int(src_w / scale)

    scaled = jax.image.resize(
        batched,
        (batched.shape[0], target_h, target_w, batched.shape[3]),
        method=method,
    )

    source_dtype = batched.dtype
    if source_dtype == jnp.uint8:
        # Round and cast back so the result is a valid uint8 image again.
        scaled = jnp.round(scaled).clip(0, 255).astype(jnp.uint8)
        fill_value = 0
    elif source_dtype == jnp.float32:
        scaled = scaled.clip(-1.0, 1.0)
        fill_value = -1.0
    else:
        raise ValueError(f"Unsupported image dtype: {batched.dtype}")

    # Split the spare rows/columns evenly; any odd remainder goes to the bottom/right edge.
    spare_h = height - target_h
    spare_w = width - target_w
    pad_top = spare_h // 2
    pad_left = spare_w // 2
    pad_bottom = spare_h - pad_top
    pad_right = spare_w - pad_left

    padded = jnp.pad(
        scaled,
        ((0, 0), (pad_top, pad_bottom), (pad_left, pad_right), (0, 0)),
        constant_values=fill_value,
    )

    return padded[0] if unbatched else padded
|
|
|
|
|
def resize_with_pad_torch(
    images: torch.Tensor,
    height: int,
    width: int,
    mode: str = "bilinear",
) -> torch.Tensor:
    """PyTorch version of resize_with_pad.

    Resizes an image to a target height and width without distortion by padding with black
    (0 for uint8, -1.0 for float32). If the image is float32, it must be in the range [-1, 1].

    The layout is guessed from the last dimension: a size of 4 or less is treated as the
    channel axis (channels-last), anything larger as the width (channels-first).

    Args:
        images: Tensor of shape [h, w, c] / [b, h, w, c] (channels-last) or
            [c, h, w] / [b, c, h, w] (channels-first), dtype uint8 or float32.
        height: Target height
        width: Target width
        mode: Interpolation mode ('bilinear', 'nearest', etc.)

    Returns:
        Resized and padded tensor with the same layout, rank and dtype as the input. A batch
        dimension is removed only if this function added it, so a real batch of size 1 is kept.

    Raises:
        ValueError: If the tensor is not 3-D or 4-D, or its dtype is not uint8 / float32.
    """
    if images.dim() not in (3, 4):
        raise ValueError(f"Expected a 3D or 4D image tensor, got shape {tuple(images.shape)}")
    # Validate the dtype up front so unsupported inputs fail before any resizing work.
    if images.dtype not in (torch.uint8, torch.float32):
        raise ValueError(f"Unsupported image dtype: {images.dtype}")

    is_uint8 = images.dtype == torch.uint8
    channels_last = images.shape[-1] <= 4

    # Remember whether the batch dimension is ours, so only an added one is removed at the end.
    added_batch_dim = images.dim() == 3
    if added_batch_dim:
        images = images.unsqueeze(0)
    if channels_last:
        # [b, h, w, c] -> [b, c, h, w], the layout torch's resize/pad ops expect.
        images = images.permute(0, 3, 1, 2)

    cur_height, cur_width = images.shape[-2:]

    # A single scale factor for both axes preserves the aspect ratio.
    ratio = max(cur_width / width, cur_height / height)
    resized_height = int(cur_height / ratio)
    resized_width = int(cur_width / ratio)

    # Interpolate uint8 images in float32: integer interpolation is not available for every
    # mode/backend, and the rounding step below is only meaningful on a float result.
    source = images.to(torch.float32) if is_uint8 else images
    resized_images = F.interpolate(
        source,
        size=(resized_height, resized_width),
        mode=mode,
        align_corners=False if mode == "bilinear" else None,
    )

    if is_uint8:
        resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
    else:
        resized_images = resized_images.clamp(-1.0, 1.0)

    # Center the resized image; an odd leftover pixel goes to the bottom/right edge.
    pad_h0, remainder_h = divmod(height - resized_height, 2)
    pad_h1 = pad_h0 + remainder_h
    pad_w0, remainder_w = divmod(width - resized_width, 2)
    pad_w1 = pad_w0 + remainder_w

    # F.pad takes pads for the last dimension first: (left, right, top, bottom).
    padded_images = F.pad(
        resized_images,
        (pad_w0, pad_w1, pad_h0, pad_h1),
        mode="constant",
        value=0 if is_uint8 else -1.0,
    )

    if channels_last:
        # [b, c, h, w] -> [b, h, w, c]
        padded_images = padded_images.permute(0, 2, 3, 1)
    if added_batch_dim:
        padded_images = padded_images.squeeze(0)

    return padded_images
|
|
|
|
|
def replace_padding_0to1_torch(image: torch.Tensor,) -> torch.Tensor:
    """PyTorch version of replace_padding_0to1.

    OpenPI requires images with 0 value paddings, while VGGT series requires 1 value paddings.
    The padding is located with a bounding box: the tightest box around all non-zero pixels is
    treated as image content, and every all-zero pixel outside that box is set to 1. Zero
    pixels inside the box are left untouched, as is an image with no non-zero pixel at all.

    Args:
        image: Tensor of shape [h, w, c] or [b, h, w, c]

    Returns:
        Padding-replaced tensor with same shape as input
    """

    def _first_and_last(flags: torch.Tensor, extent: int):
        # argmax is applied to a float copy; on ties it yields the first index, so the
        # flipped copy gives the last True position once mirrored back.
        as_float = flags.to(torch.float32)
        first = as_float.argmax(dim=1)
        last = extent - 1 - as_float.flip(dims=[1]).argmax(dim=1)
        return first, last

    was_unbatched = image.dim() == 3
    if was_unbatched:
        image = image.unsqueeze(0)

    batch, height, width, _ = image.shape
    dev = image.device

    # [b, h, w]: True where at least one channel is non-zero.
    has_content = (image != 0).any(dim=-1)
    content_rows = has_content.any(dim=2)  # [b, h]
    content_cols = has_content.any(dim=1)  # [b, w]

    y_min, y_max = _first_and_last(content_rows, height)
    x_min, x_max = _first_and_last(content_cols, width)

    # An image without any content gets a box covering the whole frame, so nothing is replaced.
    non_empty = content_rows.any(dim=1)
    y_min = torch.where(non_empty, y_min, torch.zeros_like(y_min))
    y_max = torch.where(non_empty, y_max, torch.full_like(y_max, height - 1))
    x_min = torch.where(non_empty, x_min, torch.zeros_like(x_min))
    x_max = torch.where(non_empty, x_max, torch.full_like(x_max, width - 1))

    ys = torch.arange(height, device=dev).view(1, height, 1)
    xs = torch.arange(width, device=dev).view(1, 1, width)
    y_min, y_max = y_min.view(batch, 1, 1), y_max.view(batch, 1, 1)
    x_min, x_max = x_min.view(batch, 1, 1), x_max.view(batch, 1, 1)

    # Outside the box means above/below it or left/right of it.
    outside_rows = (ys < y_min) | (ys > y_max)
    outside_cols = (xs < x_min) | (xs > x_max)
    outside_box = outside_rows | outside_cols

    # Only pixels that are zero in every channel count as padding.
    to_replace = outside_box & ~has_content

    if to_replace.any():
        fill = torch.ones((), dtype=image.dtype, device=dev)
        image = torch.where(to_replace.unsqueeze(-1).expand_as(image), fill, image)

    return image.squeeze(0) if was_unbatched else image