| import PIL.Image |
| import PIL.ImageOps |
| from packaging import version |
| from PIL import Image |
|
|
|
|
# Lookup table from interpolation names to Pillow resampling filters.
# Pillow 9.1.0 introduced the ``Image.Resampling`` enum; earlier releases only
# provide the module-level constants, so the table is built from whichever
# source the installed Pillow offers. ``base_version`` strips pre/post/dev
# suffixes before comparing, so e.g. a 9.1.0 dev build is treated as 9.1.0.
if version.parse(version.parse(PIL.__version__).base_version) < version.parse("9.1.0"):
    # Legacy Pillow: plain module-level constants.
    PIL_INTERPOLATION = {
        "linear": PIL.Image.LINEAR,
        "bilinear": PIL.Image.BILINEAR,
        "bicubic": PIL.Image.BICUBIC,
        "lanczos": PIL.Image.LANCZOS,
        "nearest": PIL.Image.NEAREST,
    }
else:
    # Modern Pillow: enum members; "linear" is mapped onto BILINEAR.
    PIL_INTERPOLATION = {
        "linear": PIL.Image.Resampling.BILINEAR,
        "bilinear": PIL.Image.Resampling.BILINEAR,
        "bicubic": PIL.Image.Resampling.BICUBIC,
        "lanczos": PIL.Image.Resampling.LANCZOS,
        "nearest": PIL.Image.Resampling.NEAREST,
    }
|
|
|
|
def pt_to_pil(images):
    """
    Convert a torch image or a batch of torch images to a list of PIL images.

    Args:
        images: A channels-first tensor, either a batch shaped
            ``(batch, channels, height, width)`` or a single image shaped
            ``(channels, height, width)``. Values are expected in ``[-1, 1]``;
            they are rescaled to ``[0, 1]`` and anything outside is clamped.

    Returns:
        A list of ``PIL.Image.Image``, one per image in the batch (a list even
        when a single image is passed).
    """
    if images.ndim == 3:
        # Promote a single image to a batch of one, mirroring ``numpy_to_pil``;
        # otherwise the 4-axis ``permute`` below would fail.
        images = images.unsqueeze(0)
    # Map [-1, 1] -> [0, 1] and saturate anything outside that range.
    images = (images / 2 + 0.5).clamp(0, 1)
    # ``detach`` allows tensors that require grad (``.numpy()`` refuses them);
    # move to CPU, go channels-last, and upcast to float32 before numpy.
    images = images.detach().cpu().permute(0, 2, 3, 1).float().numpy()
    return numpy_to_pil(images)
|
|
|
|
def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of numpy images to a list of PIL images.

    Args:
        images: A channels-last float array, either a batch shaped
            ``(batch, height, width, channels)`` or a single image shaped
            ``(height, width, channels)``. Values are expected in ``[0, 1]``;
            values outside that range are clipped. One channel produces
            grayscale (``"L"``) images. Any other channel count is passed to
            ``Image.fromarray`` as uint8 data (e.g. 3 -> RGB, 4 -> RGBA).

    Returns:
        A list of ``PIL.Image.Image``, one per image in the batch (a list even
        when a single image is passed).
    """
    if images.ndim == 3:
        # Single image: add a leading batch axis so the loops below are uniform.
        images = images[None, ...]
    # Clip before the uint8 cast: out-of-range floats would otherwise wrap
    # around (e.g. 1.01 -> 258 -> 2) instead of saturating at 0 / 255.
    images = (images.clip(0, 1) * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # Drop only the channel axis. A bare ``squeeze()`` would also collapse
        # a height or width of 1 and hand ``fromarray`` a 1-D array.
        pil_images = [Image.fromarray(image[..., 0], mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]
    return pil_images
|
|