| from basicsr.archs.rrdbnet_arch import RRDBNet |
| from realesrgan import RealESRGANer |
| from diffusers import StableDiffusionPipeline |
| import base64 |
| from PIL import Image |
| from io import BytesIO |
| import torch |
| from torch.cuda.amp import autocast |
| from typing import Dict, Any |
| import numpy as np |
|
|
| |
# Module-wide compute device: CUDA when PyTorch reports it as available, CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
|
class EndpointHandler:
    """Generate an image from a text prompt and upscale it with Real-ESRGAN.

    The handler owns two models: a Stable Diffusion pipeline that renders the
    prompt, and an RRDBNet (scale=4) loaded from the Real-ESRGAN anime
    checkpoint that enlarges the rendered image. The result is returned as a
    base64-encoded PNG.
    """

    # Checkpoint file for the RRDBNet upscaler, shipped alongside the handler.
    ESRGAN_WEIGHTS = "/repository/RealESRGAN_x4plus_anime_6B.pth"

    def __init__(self, path=""):
        """Load both models and move them to the selected device.

        Args:
            path: Model directory or identifier handed to
                ``StableDiffusionPipeline.from_pretrained``.
        """
        # One device for the whole handler, so loading and inference can never
        # disagree about where tensors live.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float32)
        self.pipe = self.pipe.to(self.device)

        # NOTE(review): torch.load unpickles the file, so the checkpoint must
        # come from a trusted source. map_location lets a checkpoint that was
        # saved from a GPU load on a CPU-only host.
        checkpoint = torch.load(self.ESRGAN_WEIGHTS, map_location=self.device)

        # Some checkpoints nest the weights under "params_ema"; others are the
        # bare state dict.
        state_dict = checkpoint["params_ema"] if "params_ema" in checkpoint else checkpoint

        self.model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        # Not used by __call__; kept as a public attribute for existing callers.
        self.upsampler = RealESRGANer(scale=4, model=self.model, tile=0, model_path=self.ESRGAN_WEIGHTS)

    def __call__(self, data: Dict[str, Any], output_size=(512, )) -> Dict[str, str]:
        """Render the prompt in ``data`` and return the upscaled image.

        Args:
            data: Request payload. ``"inputs"`` holds the prompt (required);
                ``"negative_prompt"`` is optional.
            output_size: Accepted for interface compatibility; currently
                ignored.

        Returns:
            ``{"image": <base64-encoded PNG as str>}``.

        Raises:
            ValueError: If ``data`` has no ``"inputs"`` entry.
        """
        prompt = data.get("inputs")
        if prompt is None:
            # Fail early with a clear message instead of an obscure error from
            # inside the pipeline.
            raise ValueError('Request payload is missing the required "inputs" field.')
        negative_prompt = data.get("negative_prompt", None)

        # Mixed precision only applies on CUDA; on CPU the context is a no-op.
        use_amp = self.device.type == "cuda"
        with torch.autocast(device_type=self.device.type, enabled=use_amp):
            output = self.pipe(prompt, guidance_scale=7.5, negative_prompt=negative_prompt)
        image = output['images'][0]

        # Explicit conversion of the generated image to an H x W x C array.
        upscaled = self._upscale(np.asarray(image))

        buffered = BytesIO()
        Image.fromarray(upscaled).save(buffered, format="PNG")
        return {"image": base64.b64encode(buffered.getvalue()).decode()}

    def _upscale(self, pixels: np.ndarray) -> np.ndarray:
        """Run ``self.model`` on an H x W x C image with values in 0..255.

        Returns the model output as a uint8 H x W x C array.
        """
        normalized = np.clip(pixels, 0, 255).astype(np.float32) / 255.0

        # HWC -> 1 x C x H x W, the layout the network is called with.
        batch = torch.from_numpy(normalized).permute(2, 0, 1).unsqueeze(0).to(self.device)

        with torch.no_grad():
            result = self.model(batch)

        # squeeze(0) removes only the batch axis, so an image whose height or
        # width is 1 keeps its shape.
        result = result.squeeze(0).permute(1, 2, 0).cpu().numpy()
        result = np.clip(result, 0, 1)

        # Round to nearest before casting; plain truncation would darken the
        # image by up to one level per channel.
        return np.rint(result * 255).astype(np.uint8)