| import os |
| import io |
| import torch |
| import logging |
| import base64 |
| import requests |
| import numpy as np |
| import cv2 |
| from PIL import Image |
|
|
| from gfpgan import GFPGANer |
| from realesrgan import RealESRGANer |
| from realesrgan.archs.srvgg_arch import SRVGGNetCompact |
|
|
# Module-scoped logger used by EndpointHandler for progress and error messages.
logger = logging.getLogger(name=__name__)
|
|
class EndpointHandler:
    """Inference endpoint: GFPGAN face restoration + Real-ESRGAN background upscaling.

    Model weights are downloaded (if not already cached under ``path``) at
    construction time; the networks themselves are only built on the first
    call, see ``_init_models``.
    """

    # Seconds to wait when fetching an input image from a URL. Without a
    # timeout a stalled remote server would block the request forever.
    IMAGE_DOWNLOAD_TIMEOUT = 30
    # Seconds to wait (per socket operation) on model weight downloads.
    MODEL_DOWNLOAD_TIMEOUT = 60
    # Chunk size used when streaming model weights to disk.
    _DOWNLOAD_CHUNK_SIZE = 1 << 20

    def __init__(self, path="."):
        """Resolve model paths and make sure both weight files exist locally.

        Args:
            path: Directory in which the model weight files are cached.

        Raises:
            requests.RequestException: If a missing model cannot be downloaded.
        """
        logger.info("[INIT] GFPGAN + Real-ESRGAN handler starting...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Half precision is enabled only when running on CUDA.
        self.half = self.device == "cuda"
        self.path = path

        self.gfpgan_model_url = (
            "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
        )
        self.realesr_model_url = (
            "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
        )

        self.gfpgan_model_path = os.path.join(path, "GFPGANv1.4.pth")
        self.realesr_model_path = os.path.join(path, "realesr-general-x4v3.pth")

        # Built lazily by _init_models() on the first __call__.
        self.bg_upsampler = None
        self.restorer = None

        self._ensure_model(self.gfpgan_model_url, self.gfpgan_model_path)
        self._ensure_model(self.realesr_model_url, self.realesr_model_path)

        logger.info("Device: %s, half precision: %s", self.device, self.half)

    def _ensure_model(self, url, path):
        """Download the file at ``url`` to ``path`` unless it already exists.

        The body is streamed into ``<path>.part`` and atomically renamed into
        place only after the download completes, so an interrupted download
        never leaves a truncated file that later runs would mistake for a
        cached model.

        Raises:
            requests.RequestException: On network errors or a non-2xx status.
        """
        if os.path.exists(path):
            logger.info("Found cached model: %s", path)
            return

        logger.info("Downloading model from %s", url)
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)

        tmp_path = path + ".part"
        try:
            with requests.get(url, stream=True, timeout=self.MODEL_DOWNLOAD_TIMEOUT) as resp:
                resp.raise_for_status()
                with open(tmp_path, "wb") as fh:
                    for chunk in resp.iter_content(chunk_size=self._DOWNLOAD_CHUNK_SIZE):
                        fh.write(chunk)
            os.replace(tmp_path, path)
        finally:
            # Only present if something failed before the rename.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        logger.info("Model saved to %s", path)

    def _init_models(self):
        """Build the Real-ESRGAN upsampler and GFPGAN restorer if not built yet.

        Safe to call repeatedly: each model is constructed at most once.
        """
        if self.bg_upsampler is None:
            logger.info("Initializing Real-ESRGAN upsampler...")
            model = SRVGGNetCompact(
                num_in_ch=3, num_out_ch=3, num_feat=64,
                num_conv=32, upscale=4, act_type="prelu",
            )
            self.bg_upsampler = RealESRGANer(
                scale=4,
                model_path=self.realesr_model_path,
                model=model,
                tile=400,
                tile_pad=10,
                pre_pad=0,
                half=self.half,
                device=self.device,
            )

        if self.restorer is None:
            logger.info("Initializing GFPGAN restorer...")
            self.restorer = GFPGANer(
                model_path=self.gfpgan_model_path,
                upscale=2,
                arch="clean",
                channel_multiplier=2,
                bg_upsampler=self.bg_upsampler,
            )
            logger.info("Models ready!")

    @staticmethod
    def _is_url(text):
        """Return True if ``text`` starts with an http:// or https:// scheme."""
        return text.startswith(("http://", "https://"))

    @staticmethod
    def _strip_data_uri(text):
        """Return the payload of a ``data:<mime>;base64,<payload>`` URI, else ``text`` unchanged."""
        if text.startswith("data:"):
            _, sep, payload = text.partition(",")
            if sep:
                return payload
        return text

    @staticmethod
    def _pil_to_bgr(image):
        """Convert an RGB image (PIL image or HxWx3 array) to a contiguous uint8 BGR array."""
        rgb = np.asarray(image, dtype=np.uint8)
        # Reversing the channel axis swaps R and B; OpenCV-based code expects
        # C-contiguous arrays, so materialize the reversed view.
        return np.ascontiguousarray(rgb[:, :, ::-1])

    def _load_image(self, data):
        """Decode a request payload into an RGB PIL image.

        Accepts raw bytes, an http(s) URL, a base64 string (optionally as a
        ``data:`` URI), or a dict wrapping any of these under ``"inputs"``.

        Raises:
            ValueError: If the payload type is unsupported or the base64
                string cannot be decoded into an image.
            requests.RequestException: If downloading an image URL fails or
                returns a non-2xx status.
        """
        if isinstance(data, dict) and "inputs" in data:
            data = data["inputs"]

        if isinstance(data, (bytes, bytearray)):
            logger.info("Received raw bytes input")
            return Image.open(io.BytesIO(data)).convert("RGB")

        if not isinstance(data, str):
            raise ValueError("Unsupported input type")

        if self._is_url(data):
            logger.info("Downloading image from URL: %s", data)
            resp = requests.get(data, timeout=self.IMAGE_DOWNLOAD_TIMEOUT)
            # Don't hand an HTTP error page to PIL as if it were an image.
            resp.raise_for_status()
            return Image.open(io.BytesIO(resp.content)).convert("RGB")

        logger.info("Decoding base64 image input")
        try:
            decoded = base64.b64decode(self._strip_data_uri(data))
            return Image.open(io.BytesIO(decoded)).convert("RGB")
        except Exception as e:
            logger.error("Failed to decode base64: %s", e)
            raise ValueError("Invalid base64 image input") from e

    def __call__(self, data):
        """Restore faces and upscale the background of one input image.

        Args:
            data: Request payload in any form accepted by ``_load_image``.

        Returns:
            dict with the JPEG result base64-encoded under ``"image"`` plus
            ``"status"`` and ``"info"`` fields.

        Raises:
            ValueError: For unsupported or undecodable input.
            RuntimeError: If restoration yields no image or JPEG encoding fails.
        """
        logger.info("Starting GFPGAN inference pipeline...")
        self._init_models()

        image = self._load_image(data)
        # GFPGANer (and the Real-ESRGAN background upsampler) follow the OpenCV
        # convention of BGR input/output. Passing the PIL RGB array directly
        # would make the networks see swapped R/B channels.
        input_img = self._pil_to_bgr(image)
        logger.info("Input image shape: %s", input_img.shape)

        _, _, restored_img = self.restorer.enhance(
            input_img, has_aligned=False, only_center_face=False, paste_back=True
        )
        if restored_img is None:
            raise RuntimeError("GFPGAN did not return a restored image")

        logger.info("Restoration complete, preparing output...")

        # The result is already BGR, which is what cv2.imencode expects, so no
        # colour conversion is needed before encoding.
        restored_img = np.clip(restored_img, 0, 255).astype(np.uint8)
        ok, buffer = cv2.imencode(".jpg", restored_img)
        if not ok:
            raise RuntimeError("Failed to encode restored image as JPEG")
        b64_output = base64.b64encode(buffer).decode("utf-8")

        logger.info("Returning base64-encoded image JSON response")

        return {
            "image": b64_output,
            "status": "success",
            "info": "Restored with GFPGAN v1.4 + Real-ESRGAN x4v3 (RGB fixed)",
        }
|
|
|
|