# sprite_processor.py
"""
Sprite Image Enhancement Module
Uses Real-ESRGAN for high-quality upscaling
"""
import cv2
import numpy as np
import torch
from PIL import Image
import os
class SpriteProcessor:
"""Processor for enhancing sprite sheet images"""
def __init__(self):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.model = None
self._load_model()
def _load_model(self):
"""Load Real-ESRGAN model"""
try:
from realesrgan import RealESRGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
# Create model
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
num_block=23, num_grow_ch=32, scale=4)
# Initialize Real-ESRGAN
model_path = "weights/RealESRGAN_x4plus.pth"
if os.path.exists(model_path):
self.model = RealESRGANer(
scale=4,
model_path=model_path,
model=model,
tile=0,
pre_pad=0,
half=False,
device=self.device
)
else:
print("Warning: Real-ESRGAN model not found, using fallback enhancement")
self.model = None
except Exception as e:
print(f"Error loading Real-ESRGAN: {e}")
self.model = None
def enhance_image(self, image: np.ndarray, scale: int = 4) -> np.ndarray:
"""
Enhance image quality using Real-ESRGAN or fallback methods
Args:
image: Input image (BGR or BGRA)
scale: Upscaling factor (2 or 4)
Returns:
Enhanced image
"""
# Handle alpha channel
has_alpha = len(image.shape) == 3 and image.shape[2] == 4
if has_alpha:
# Separate alpha channel
bgr = image[:, :, :3]
alpha = image[:, :, 3]
else:
bgr = image
alpha = None
# Enhance RGB channels
if self.model is not None and scale > 1:
try:
# Convert BGR to RGB for the model
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
# Apply Real-ESRGAN
enhanced_rgb, _ = self.model.enhance(rgb, outscale=scale)
# Convert back to BGR
enhanced_bgr = cv2.cvtColor(enhanced_rgb, cv2.COLOR_RGB2BGR)
except Exception as e:
print(f"Real-ESRGAN failed, using fallback: {e}")
enhanced_bgr = self._fallback_enhance(bgr, scale)
else:
enhanced_bgr = self._fallback_enhance(bgr, scale)
# Enhance alpha channel if present
if alpha is not None and scale > 1:
enhanced_alpha = cv2.resize(alpha, None, fx=scale, fy=scale,
interpolation=cv2.INTER_NEAREST)
# Merge channels
enhanced_image = cv2.merge([enhanced_bgr, enhanced_alpha])
else:
enhanced_image = enhanced_bgr
return enhanced_image
def _fallback_enhance(self, image: np.ndarray, scale: int) -> np.ndarray:
"""
Fallback enhancement using OpenCV
Args:
image: Input BGR image
scale: Upscaling factor
Returns:
Enhanced image
"""
# Resize with high-quality interpolation
new_width = int(image.shape[1] * scale)
new_height = int(image.shape[0] * scale)
enhanced = cv2.resize(image, (new_width, new_height),
interpolation=cv2.INTER_CUBIC)
# Apply sharpening
kernel = np.array([[-1, -1, -1],
[-1, 9, -1],
[-1, -1, -1]])
enhanced = cv2.filter2D(enhanced, -1, kernel)
# Denoise
enhanced = cv2.fastNlMeansDenoisingColored(enhanced, None, 5, 5, 7, 21)
return enhanced
def sharpen_image(self, image: np.ndarray, strength: float = 1.0) -> np.ndarray:
"""
Apply sharpening filter
Args:
image: Input image
strength: Sharpening strength
Returns:
Sharpened image
"""
kernel = np.array([[-1, -1, -1],
[-1, 9, -1],
[-1, -1, -1]]) * strength
sharpened = cv2.filter2D(image, -1, kernel)
return sharpened
def remove_blur(self, image: np.ndarray) -> np.ndarray:
"""
Reduce blur using deconvolution
Args:
image: Input image
Returns:
Deblurred image
"""
# Create a point spread function (PSF)
psf_size = 5
psf = np.ones((psf_size, psf_size)) / (psf_size ** 2)
# Simple deconvolution (Wiener filter approximation)
result = image.copy()
for i in range(3): # For each channel
channel = image[:, :, i].astype(np.float32) / 255.0
# FFT
psf_fft = np.fft.fft2(psf, s=channel.shape)
channel_fft = np.fft.fft2(channel)
# Wiener deconvolution
K = 0.01 # Noise to signal ratio
deconv_fft = channel_fft * np.conj(psf_fft) / (np.abs(psf_fft) ** 2 + K)
# Inverse FFT
deconv = np.fft.ifft2(deconv_fft).real
# Clip and convert back
deconv = np.clip(deconv * 255, 0, 255).astype(np.uint8)
result[:, :, i] = deconv
return result
def enhance_contrast(self, image: np.ndarray) -> np.ndarray:
"""
Enhance contrast using CLAHE
Args:
image: Input image
Returns:
Contrast-enhanced image
"""
lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
l, a, b = cv2.split(lab)
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
l = clahe.apply(l)
enhanced = cv2.merge([l, a, b])
enhanced = cv2.cvtColor(enhanced, cv2.COLOR_LAB2BGR)
return enhanced