| import numpy as np |
| import torch |
| from PIL import Image |
| from transformers.image_processing_utils import BaseImageProcessor |
| from transformers.utils import logging |
|
|
# Module-level logger, following the transformers convention of one logger per module.
logger = logging.get_logger(__name__)
|
|
|
|
class VQModelImageProcessor(BaseImageProcessor):
    """Image processor for a VQ model.

    Resizes inputs to a fixed square ``size``, optionally flattens
    transparency onto a white background, and converts between PIL images
    and CHW float tensors normalized to ``[-1, 1]``. ``preprocess`` records
    the original resolution so ``postprocess`` can restore it.
    """

    def __init__(
        self,
        size: int = 256,
        convert_rgb: bool = False,
        resample: Image.Resampling = Image.Resampling.LANCZOS,
        **kwargs: dict,
    ) -> None:
        """
        Args:
            size: Target side length; images are resized to ``(size, size)``.
            convert_rgb: If True, composite the (RGBA-converted) image onto a
                white background, producing 3-channel RGB output instead of RGBA.
            resample: PIL resampling filter used for all resizing.
            **kwargs: Forwarded to ``BaseImageProcessor``.
        """
        # Fix: kwargs were previously accepted but silently dropped and the
        # base class was never initialized, bypassing the mixin machinery
        # (config attributes, save/load support) that BaseImageProcessor provides.
        super().__init__(**kwargs)
        self.size = size
        self.convert_rgb = convert_rgb
        self.resample = resample

    def __call__(self, image: Image.Image) -> dict:
        """Alias for :meth:`preprocess`."""
        return self.preprocess(image)

    def preprocess(self, image: Image.Image) -> dict:
        """Resize and normalize *image* for the model.

        Returns:
            dict with keys:
                - ``"image"``: float32 tensor in ``[-1, 1]``, CHW layout
                  (3 channels when ``convert_rgb`` is set, otherwise 4/RGBA).
                - ``"width"`` / ``"height"``: the original input size, so
                  :meth:`postprocess` can restore the source resolution.
        """
        width, height = image.size  # capture original size before resizing
        image = image.resize((self.size, self.size), resample=self.resample)
        image = image.convert("RGBA")

        if self.convert_rgb:
            # Flatten transparency onto a white background, using the alpha
            # channel (band 3 of RGBA) as the paste mask.
            background = Image.new("RGB", image.size, (255, 255, 255))
            background.paste(image, mask=image.split()[3])
            image = background

        return {
            "image": self.to_tensor(image),
            "width": width,
            "height": height,
        }

    def to_tensor(self, image: Image.Image) -> torch.Tensor:
        """Convert a PIL image to a float32 CHW tensor scaled to ``[-1, 1]``."""
        x = np.array(image) / 127.5 - 1.0  # uint8 [0, 255] -> float [-1, 1]
        x = x.transpose(2, 0, 1).astype(np.float32)  # HWC -> CHW
        return torch.as_tensor(x)

    def postprocess(
        self,
        x: torch.Tensor,
        width: int | None = None,
        height: int | None = None,
    ) -> Image.Image:
        """Convert a model-output tensor back into a PIL image.

        Args:
            x: Tensor of shape ``(C, H, W)`` with values in ``[-1, 1]``.
            width: Output width; defaults to ``self.size`` when falsy.
            height: Output height; defaults to ``self.size`` when falsy.

        Returns:
            PIL image resized to ``(width, height)``, typically the original
            resolution recorded by :meth:`preprocess`.
        """
        x_np = x.detach().cpu().numpy()
        x_np = x_np.transpose(1, 2, 0)  # CHW -> HWC
        x_np = (x_np + 1.0) * 127.5  # [-1, 1] -> [0, 255]
        x_np = np.clip(x_np, 0, 255).astype(np.uint8)
        image = Image.fromarray(x_np)

        # Restore the requested (usually original) resolution.
        width = width or self.size
        height = height or self.size
        image = image.resize((width, height), resample=self.resample)

        return image
|
|