Spaces:
Sleeping
Sleeping
| """ | |
| Sprite Image Enhancement Module | |
| Uses Real-ESRGAN for high-quality upscaling | |
| """ | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| import os | |
class SpriteProcessor:
    """Processor for enhancing sprite sheet images.

    Uses Real-ESRGAN for high-quality upscaling when the package and the
    pretrained weights are available; otherwise every enhancement method
    degrades gracefully to a pure-OpenCV pipeline (resize + sharpen +
    denoise).
    """

    def __init__(self) -> None:
        # Prefer GPU when available; Real-ESRGAN inference is slow on CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = None  # RealESRGANer instance, or None when unavailable
        self._load_model()

    def _load_model(self) -> None:
        """Load the Real-ESRGAN x4 upscaler; leave ``self.model`` as None on failure.

        A missing package or missing weights file is non-fatal: enhancement
        then falls back to :meth:`_fallback_enhance`.
        """
        try:
            from realesrgan import RealESRGANer
            from basicsr.archs.rrdbnet_arch import RRDBNet

            # Standard RRDBNet architecture matching the x4plus checkpoint.
            model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
                            num_block=23, num_grow_ch=32, scale=4)

            model_path = "weights/RealESRGAN_x4plus.pth"
            if os.path.exists(model_path):
                self.model = RealESRGANer(
                    scale=4,
                    model_path=model_path,
                    model=model,
                    tile=0,       # 0 = process the whole image in one pass
                    pre_pad=0,
                    half=False,   # full precision; safer on CPU
                    device=self.device
                )
            else:
                print("Warning: Real-ESRGAN model not found, using fallback enhancement")
                self.model = None
        except Exception as e:
            # Deliberate broad catch: any loading problem downgrades to fallback.
            print(f"Error loading Real-ESRGAN: {e}")
            self.model = None

    def enhance_image(self, image: np.ndarray, scale: int = 4) -> np.ndarray:
        """
        Enhance image quality using Real-ESRGAN or fallback methods.

        Args:
            image: Input image (BGR or BGRA).
            scale: Upscaling factor (2 or 4).

        Returns:
            Enhanced image. A BGRA input yields a BGRA output: the alpha
            channel is preserved at every scale, including ``scale <= 1``.
        """
        # Split off the alpha channel: neither the model nor the fallback
        # pipeline handles 4-channel input.
        has_alpha = image.ndim == 3 and image.shape[2] == 4
        if has_alpha:
            bgr = image[:, :, :3]
            alpha = image[:, :, 3]
        else:
            bgr = image
            alpha = None

        # Enhance the color channels.
        if self.model is not None and scale > 1:
            try:
                # Real-ESRGAN expects RGB channel ordering.
                rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
                enhanced_rgb, _ = self.model.enhance(rgb, outscale=scale)
                enhanced_bgr = cv2.cvtColor(enhanced_rgb, cv2.COLOR_RGB2BGR)
            except Exception as e:
                print(f"Real-ESRGAN failed, using fallback: {e}")
                enhanced_bgr = self._fallback_enhance(bgr, scale)
        else:
            enhanced_bgr = self._fallback_enhance(bgr, scale)

        if alpha is None:
            return enhanced_bgr

        # Fix: previously the alpha channel was silently dropped whenever
        # scale <= 1. Resize it to the enhanced image's exact dimensions
        # (avoids a merge-size mismatch with the fallback's truncated size);
        # nearest-neighbour keeps sprite edges hard instead of introducing
        # semi-transparent fringes.
        out_h, out_w = enhanced_bgr.shape[:2]
        enhanced_alpha = cv2.resize(alpha, (out_w, out_h),
                                    interpolation=cv2.INTER_NEAREST)
        return cv2.merge([enhanced_bgr, enhanced_alpha])

    def _fallback_enhance(self, image: np.ndarray, scale: int) -> np.ndarray:
        """
        Fallback enhancement using OpenCV only.

        Args:
            image: Input BGR image.
            scale: Upscaling factor.

        Returns:
            Upscaled, sharpened and denoised BGR image.
        """
        # Resize with high-quality interpolation.
        new_width = int(image.shape[1] * scale)
        new_height = int(image.shape[0] * scale)
        enhanced = cv2.resize(image, (new_width, new_height),
                              interpolation=cv2.INTER_CUBIC)
        # Classic 3x3 sharpen kernel (identity + Laplacian; sums to 1).
        kernel = np.array([[-1, -1, -1],
                           [-1, 9, -1],
                           [-1, -1, -1]])
        enhanced = cv2.filter2D(enhanced, -1, kernel)
        # Light denoise to suppress ringing introduced by the sharpening.
        enhanced = cv2.fastNlMeansDenoisingColored(enhanced, None, 5, 5, 7, 21)
        return enhanced

    def sharpen_image(self, image: np.ndarray, strength: float = 1.0) -> np.ndarray:
        """
        Apply a sharpening filter.

        Args:
            image: Input image.
            strength: Sharpening strength. 0 leaves the image unchanged;
                1.0 reproduces the classic 3x3 sharpen kernel.

        Returns:
            Sharpened image.
        """
        # Fix: multiplying the whole kernel by `strength` made the kernel
        # sum equal `strength` instead of 1, scaling overall brightness
        # (e.g. strength=0.5 halved the image intensity). Blending an
        # identity kernel with a Laplacian keeps the sum at 1 for every
        # strength, and at strength == 1.0 is exactly the original kernel.
        identity = np.array([[0, 0, 0],
                             [0, 1, 0],
                             [0, 0, 0]], dtype=np.float64)
        laplacian = np.array([[-1, -1, -1],
                              [-1, 8, -1],
                              [-1, -1, -1]], dtype=np.float64)
        kernel = identity + strength * laplacian
        return cv2.filter2D(image, -1, kernel)

    def remove_blur(self, image: np.ndarray) -> np.ndarray:
        """
        Reduce blur via Wiener-filter deconvolution with a uniform 5x5 PSF.

        Args:
            image: Input image with at least 3 channels (BGR/BGRA; only the
                first three channels are deblurred, any extra channel is
                copied through unchanged).

        Returns:
            Deblurred uint8 image with the same shape as the input.
        """
        # Uniform box point-spread function, normalized to sum to 1.
        psf_size = 5
        psf = np.ones((psf_size, psf_size)) / (psf_size ** 2)

        # NOTE(review): the PSF is zero-padded at the array origin rather
        # than centred, so the deconvolved output may be shifted by roughly
        # psf_size // 2 pixels — confirm whether that offset matters here.
        result = image.copy()
        for channel_idx in range(3):  # B, G, R channels
            channel = image[:, :, channel_idx].astype(np.float32) / 255.0
            # Frequency-domain representation of PSF and channel.
            psf_fft = np.fft.fft2(psf, s=channel.shape)
            channel_fft = np.fft.fft2(channel)
            # Wiener deconvolution with a fixed noise-to-signal ratio.
            nsr = 0.01
            deconv_fft = channel_fft * np.conj(psf_fft) / (np.abs(psf_fft) ** 2 + nsr)
            deconv = np.fft.ifft2(deconv_fft).real
            # Clip back into the valid uint8 range.
            result[:, :, channel_idx] = np.clip(deconv * 255, 0, 255).astype(np.uint8)
        return result

    def enhance_contrast(self, image: np.ndarray) -> np.ndarray:
        """
        Enhance contrast using CLAHE on the lightness channel only, so
        colours are left untouched.

        Args:
            image: Input BGR image.

        Returns:
            Contrast-enhanced BGR image.
        """
        lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
        l, a, b = cv2.split(lab)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        l = clahe.apply(l)
        enhanced = cv2.merge([l, a, b])
        return cv2.cvtColor(enhanced, cv2.COLOR_LAB2BGR)