| from typing import Dict, List, Any |
| from PIL import Image |
| import torch |
| from diffusers import StableDiffusionUpscalePipeline |
| import base64 |
| from io import BytesIO |
| from transformers.utils import logging |
|
|
# Raise the transformers library verbosity to INFO and reuse its "transformers"
# logger for this handler, so the INFO-level request logs below are emitted.
logging.set_verbosity_info()
logger = logging.get_logger("transformers")

# Run on the GPU when torch can see one, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
| |
| |
|
|
class EndpointHandler():
    """Inference Endpoint handler that upscales an image with Stable Diffusion.

    Each request carries a base64-encoded image and a text prompt. The image is
    run through a ``StableDiffusionUpscalePipeline`` loaded from ``path``, and
    the result is returned as a base64-encoded JPEG.
    """

    def __init__(self, path: str = ""):
        # Location of the pipeline weights, forwarded to ``from_pretrained``.
        self.path = path
        # The pipeline is loaded on first use and then reused across requests.
        # Loading it inside every call re-read and re-moved the whole model
        # for each request.
        self.pipe = None

    def _get_pipeline(self):
        """Return the upscale pipeline, loading it onto ``device`` on first use."""
        if self.pipe is None:
            # Many float16 kernels are unimplemented or very slow on CPU in
            # PyTorch, so half precision is only used when running on CUDA.
            dtype = torch.float16 if device.type == "cuda" else torch.float32
            pipe = StableDiffusionUpscalePipeline.from_pretrained(self.path, torch_dtype=dtype)
            self.pipe = pipe.to(device)
        return self.pipe

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Upscale the image in ``data`` and return it base64-encoded.

        Args:
            data: request payload of the form
                ``{"inputs": {"image": <base64 str>, "prompt": <str>}}``.

        Returns:
            ``{"image": <base64-encoded JPEG of the upscaled image>}``.

        Raises:
            ValueError: if ``data["inputs"]`` is missing or is not a dict.
            KeyError: if ``inputs`` lacks the ``"image"`` or ``"prompt"`` key.
        """
        # Log only the keys: the payload holds a full base64 image, which
        # would flood the logs and record user content on every request.
        logger.info('data received with keys %s', list(data))
        inputs = data.get("inputs")
        if not isinstance(inputs, dict):
            raise ValueError(
                "request must contain an 'inputs' object with 'image' and 'prompt' fields"
            )

        image_bytes = base64.b64decode(inputs['image'])
        # Normalize palette/RGBA/greyscale uploads to the RGB input the pipeline expects.
        image = Image.open(BytesIO(image_bytes)).convert("RGB")
        prompt = inputs['prompt']
        logger.info('decoded input image of size %s', image.size)

        upscaled_image = self._get_pipeline()(prompt, image).images[0]

        buffered = BytesIO()
        upscaled_image.save(buffered, format="JPEG")
        return {"image": base64.b64encode(buffered.getvalue()).decode()}
| |
|
|