| import io |
| from PIL import Image |
| import numpy as np |
| import torch |
|
|
|
|
class ImageEncoder:
    """Encode image tensors into in-memory JPEG byte streams."""

    @staticmethod
    def _to_hwc_uint8(img: torch.Tensor) -> torch.Tensor:
        """Return *img* as a contiguous ``(H, W, 3)`` uint8 tensor.

        Accepted layouts, checked in this order for 3-D input:
        ``(3, H, W)``, ``(H, W, 3)``, ``(1, H, W)``, ``(H, W, 1)``.
        A 2-D ``(H, W)`` tensor is treated as single-channel. Single-channel
        input is replicated across three channels.

        Raises:
            ValueError: if the number of dimensions is not 2 or 3, or a 3-D
                tensor has neither 1 nor 3 channels in its first or last axis.
        """
        if img.ndim == 2:
            hwc = img.unsqueeze(-1)  # (H, W) -> (H, W, 1)
        elif img.ndim == 3:
            # The 3-channel tests come first so every shape that was accepted
            # before single-channel support was added is interpreted the same.
            if img.shape[0] == 3:
                hwc = img.permute(1, 2, 0)
            elif img.shape[2] == 3:
                hwc = img
            elif img.shape[0] == 1:
                hwc = img.permute(1, 2, 0)
            elif img.shape[2] == 1:
                hwc = img
            else:
                raise ValueError(f"Unsupported image shape: {img.shape}")
        else:
            raise ValueError(f"Unsupported image num dims: {img.ndim}")

        # Clamp and cast before replicating channels so grayscale input is
        # processed once instead of three times.
        hwc = hwc.clamp(0, 255).to(torch.uint8)
        if hwc.shape[2] == 1:
            hwc = hwc.expand(-1, -1, 3)
        return hwc.contiguous()

    @torch.inference_mode()
    def encode_torch(self, img: torch.Tensor, quality: int = 95) -> io.BytesIO:
        """Encode an image tensor as JPEG and return it as a ``BytesIO``.

        Values are clamped to ``[0, 255]`` and cast to uint8; no rescaling is
        applied.

        Args:
            img: 2-D ``(H, W)`` tensor, or 3-D tensor laid out as
                ``(3, H, W)``, ``(H, W, 3)``, ``(1, H, W)`` or ``(H, W, 1)``.
            quality: value forwarded as the ``quality`` option of the JPEG
                save call.

        Returns:
            A ``BytesIO`` holding the JPEG data, positioned at offset 0.

        Raises:
            ValueError: if the tensor shape or dimensionality is unsupported.
        """
        # The tensor is already uint8, so the NumPy array needs no extra cast.
        pixels = ImageEncoder._to_hwc_uint8(img).cpu().numpy()
        buffer = io.BytesIO()
        Image.fromarray(pixels).save(buffer, format="JPEG", quality=quality)
        buffer.seek(0)
        return buffer
|
|