Spaces:
Sleeping
Sleeping
Upload 3 files
Browse files- frame_detector.py +346 -0
- frame_namer.py +329 -0
- sprite_processor.py +206 -0
frame_detector.py
ADDED
|
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Frame Detection Module
|
| 3 |
+
Automatically detects and extracts sprite frames from sprite sheets
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import cv2
|
| 7 |
+
import numpy as np
|
| 8 |
+
from typing import List, Tuple
|
| 9 |
+
from scipy import ndimage
|
| 10 |
+
from skimage import measure
|
| 11 |
+
|
| 12 |
+
class FrameDetector:
    """Detector for finding sprite frames in sprite sheets.

    Detection strategies are tried in order of decreasing sophistication:

    1. ``detect_frames_auto`` -- connected-component analysis on a
       foreground mask;
    2. ``_detect_grid_based`` -- row/column projection profiles;
    3. ``_detect_equal_division`` -- blind equal horizontal division.

    Each strategy falls back to the next when it finds nothing usable.
    """

    def __init__(self):
        self.min_frame_size = 16   # Minimum accepted frame side, in pixels
        self.max_frame_size = 512  # Maximum accepted frame side, in pixels

    def _binary_mask(self, image: np.ndarray) -> np.ndarray:
        """Build a 0/255 uint8 foreground mask for ``image``.

        Uses the alpha channel when the image is BGRA, otherwise a
        grayscale conversion; pixels with value > 10 count as foreground.
        Centralizes the binarization that several detectors share.
        """
        if len(image.shape) == 3 and image.shape[2] == 4:
            channel = image[:, :, 3]
        else:
            channel = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        _, binary = cv2.threshold(channel, 10, 255, cv2.THRESH_BINARY)
        return binary

    @staticmethod
    def _segment_boundaries(is_gap: np.ndarray, limit: int) -> List[int]:
        """Convert a boolean gap profile into a flat list of segment edges.

        The returned list alternates segment-start / segment-end indices.
        ``limit`` (the image extent along that axis) is appended when the
        last segment runs to the image edge, so the length is always even.
        """
        boundaries = []
        in_gap = True  # start "in a gap" so the first content run opens a segment
        for i, gap in enumerate(is_gap):
            if in_gap and not gap:
                boundaries.append(i)   # content starts here
                in_gap = False
            elif not in_gap and gap:
                boundaries.append(i)   # content ends here
                in_gap = True
        if len(boundaries) % 2 != 0:
            boundaries.append(limit)
        return boundaries

    def detect_frames_auto(self, image: np.ndarray, padding: int = 2) -> Tuple[List[np.ndarray], List[Tuple]]:
        """
        Automatically detect frames in a sprite sheet.

        Args:
            image: Input sprite sheet image (BGR or BGRA)
            padding: Padding to add around each frame

        Returns:
            Tuple of (list of frame images, list of bounding boxes),
            boxes as (minr, minc, maxr, maxc).
        """
        binary = self._binary_mask(image)

        # Find connected components of the foreground
        labels = measure.label(binary, connectivity=2)
        regions = measure.regionprops(labels)

        # Keep only components whose bounding box is plausibly frame-sized
        valid_regions = []
        for region in regions:
            minr, minc, maxr, maxc = region.bbox
            width = maxc - minc
            height = maxr - minr
            if (self.min_frame_size <= width <= self.max_frame_size and
                    self.min_frame_size <= height <= self.max_frame_size):
                valid_regions.append(region)

        # If no valid regions found, try grid-based detection
        if len(valid_regions) == 0:
            return self._detect_grid_based(image, padding)

        # Sort left to right by leftmost column.
        # NOTE(review): this ignores the row position, so multi-row sheets
        # may be ordered column-first -- confirm the intended frame order.
        valid_regions.sort(key=lambda r: r.bbox[1])

        frames = []
        frame_boxes = []
        for region in valid_regions:
            minr, minc, maxr, maxc = region.bbox

            # Pad, clamped to the image bounds
            minr = max(0, minr - padding)
            minc = max(0, minc - padding)
            maxr = min(image.shape[0], maxr + padding)
            maxc = min(image.shape[1], maxc + padding)

            frames.append(image[minr:maxr, minc:maxc])
            frame_boxes.append((minr, minc, maxr, maxc))

        # Many tiny components usually mean a fragmented/noisy sheet;
        # the projection-based detector copes better with that.
        if len(frames) > 20:
            return self._detect_grid_based(image, padding)

        return frames, frame_boxes

    def _detect_grid_based(self, image: np.ndarray, padding: int = 2) -> Tuple[List[np.ndarray], List[Tuple]]:
        """
        Detect frames using a grid-based approach with improved filtering.

        Row and column projection profiles of the foreground mask are
        scanned for low-valued gaps; the content runs between gaps define
        a grid of candidate frames.

        Args:
            image: Input sprite sheet image
            padding: Padding to add around each frame

        Returns:
            Tuple of (list of frame images, list of bounding boxes)
        """
        binary = self._binary_mask(image)

        # Morphological close + open to bridge small holes and drop speckle
        kernel = np.ones((3, 3), np.uint8)
        binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)
        binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel)

        # Horizontal / vertical projection profiles
        h_proj = np.sum(binary, axis=1)
        v_proj = np.sum(binary, axis=0)

        # A row/column is a "gap" when its projection is below 10% of the peak
        h_gaps = h_proj < np.max(h_proj) * 0.1
        v_gaps = v_proj < np.max(v_proj) * 0.1

        # Segment edges along each axis (always even-length lists)
        row_boundaries = self._segment_boundaries(h_gaps, binary.shape[0])
        col_boundaries = self._segment_boundaries(v_gaps, binary.shape[1])

        frames = []
        frame_boxes = []

        if len(row_boundaries) >= 2 and len(col_boundaries) >= 2:
            # Walk the grid cell by cell (pairs of start/end edges)
            for i in range(0, len(row_boundaries), 2):
                for j in range(0, len(col_boundaries), 2):
                    minr = row_boundaries[i]
                    maxr = row_boundaries[i + 1]
                    minc = col_boundaries[j]
                    maxc = col_boundaries[j + 1]

                    # Pad, clamped to the image bounds
                    minr_p = max(0, minr - padding)
                    minc_p = max(0, minc - padding)
                    maxr_p = min(image.shape[0], maxr + padding)
                    maxc_p = min(image.shape[1], maxc + padding)

                    frame = image[minr_p:maxr_p, minc_p:maxc_p]

                    frame_width = maxc - minc
                    frame_height = maxr - minr

                    # Filter out very small cells (likely effects/particles):
                    # at least 15px or 2% of image width, and at least 20px
                    # or 33% of image height (sprites are usually tall).
                    min_content_width = max(15, image.shape[1] // 50)
                    min_content_height = max(20, image.shape[0] // 3)

                    has_content = np.sum(frame) > 0
                    has_valid_size = (frame_width >= min_content_width and
                                      frame_height >= min_content_height)

                    if has_content and has_valid_size:
                        frames.append(frame)
                        frame_boxes.append((minr_p, minc_p, maxr_p, maxc_p))

        # Last resort: blind equal division
        if len(frames) == 0:
            return self._detect_equal_division(image, padding)

        return frames, frame_boxes

    def _detect_equal_division(self, image: np.ndarray, padding: int = 2,
                               num_frames: int = 8) -> Tuple[List[np.ndarray], List[Tuple]]:
        """
        Detect frames by equal division.

        Assumes a single-row horizontal layout and slices the sheet into
        ``num_frames`` equal-width strips (the last strip absorbs any
        remainder).

        Args:
            image: Input sprite sheet image
            padding: Padding to add around each frame
            num_frames: Number of frames to divide into

        Returns:
            Tuple of (list of frame images, list of bounding boxes)
        """
        frames = []
        frame_boxes = []

        img_width = image.shape[1]
        img_height = image.shape[0]

        frame_width = img_width // num_frames

        for i in range(num_frames):
            minc = i * frame_width
            # Last frame runs to the edge so no pixels are dropped
            maxc = (i + 1) * frame_width if i < num_frames - 1 else img_width
            minr = 0
            maxr = img_height

            # Pad, clamped to the image bounds
            minc_p = max(0, minc - padding)
            minr_p = max(0, minr - padding)
            maxc_p = min(img_width, maxc + padding)
            maxr_p = min(img_height, maxr + padding)

            frames.append(image[minr_p:maxr_p, minc_p:maxc_p])
            frame_boxes.append((minr_p, minc_p, maxr_p, maxc_p))

        return frames, frame_boxes

    def detect_frames_manual(self, image: np.ndarray, num_frames: int,
                             padding: int = 2) -> Tuple[List[np.ndarray], List[Tuple]]:
        """
        Manually specify number of frames.

        Args:
            image: Input sprite sheet image
            num_frames: Number of frames
            padding: Padding to add around each frame

        Returns:
            Tuple of (list of frame images, list of bounding boxes)
        """
        return self._detect_equal_division(image, padding, num_frames)

    def refine_frame_boundaries(self, image: np.ndarray, frame: np.ndarray,
                                bbox: Tuple) -> Tuple[np.ndarray, Tuple]:
        """
        Refine frame boundaries to remove excess transparent space.

        Args:
            image: Original image
            frame: Extracted frame
            bbox: Bounding box (minr, minc, maxr, maxc)

        Returns:
            Refined frame and bounding box; the inputs are returned
            unchanged when the frame has no foreground pixels.
        """
        minr, minc, maxr, maxc = bbox

        # Rows/columns that contain at least one foreground pixel
        mask = self._binary_mask(frame)
        rows = np.any(mask > 0, axis=1)
        cols = np.any(mask > 0, axis=0)

        row_indices = np.where(rows)[0]
        col_indices = np.where(cols)[0]

        if len(row_indices) > 0 and len(col_indices) > 0:
            # Tight bounds, translated back into original-image coordinates
            new_minr = minr + row_indices[0]
            new_maxr = minr + row_indices[-1] + 1
            new_minc = minc + col_indices[0]
            new_maxc = minc + col_indices[-1] + 1

            refined_frame = image[new_minr:new_maxr, new_minc:new_maxc]
            return refined_frame, (new_minr, new_minc, new_maxr, new_maxc)

        return frame, bbox

    def detect_frame_size(self, image: np.ndarray) -> Tuple[int, int]:
        """
        Detect the size of individual frames.

        Estimates the frame width from the spacing of vertical gaps in the
        foreground mask; frame height is assumed to be the full image height.

        Args:
            image: Input sprite sheet image

        Returns:
            Tuple of (frame_width, frame_height)
        """
        binary = self._binary_mask(image)

        v_proj = np.sum(binary, axis=0)

        # A column is a gap when its projection is below 10% of the peak
        threshold = np.max(v_proj) * 0.1
        gaps = v_proj < threshold

        # Collect the start index of every gap run (gap ends are not needed)
        gap_starts = []
        in_gap = False
        for i, is_gap in enumerate(gaps):
            if not in_gap and is_gap:
                gap_starts.append(i)
                in_gap = True
            elif in_gap and not is_gap:
                in_gap = False

        if len(gap_starts) > 0:
            if len(gap_starts) > 1:
                # Average spacing between successive gaps ~= frame pitch
                frame_width = int(np.mean(np.diff(gap_starts)))
            else:
                # Single gap: content before it is taken as one frame
                frame_width = gap_starts[0]
        else:
            # No gaps found, assume single frame
            frame_width = image.shape[1]

        frame_height = image.shape[0]

        return frame_width, frame_height
|
frame_namer.py
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Smart Frame Naming Module
|
| 3 |
+
Automatically names frames based on their content/pose
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import cv2
|
| 7 |
+
import numpy as np
|
| 8 |
+
from typing import List
|
| 9 |
+
from scipy.spatial import distance
|
| 10 |
+
|
| 11 |
+
# Optional imports for advanced features
|
| 12 |
+
try:
|
| 13 |
+
import torch
|
| 14 |
+
from transformers import pipeline
|
| 15 |
+
HAS_TRANSFORMERS = True
|
| 16 |
+
except ImportError:
|
| 17 |
+
HAS_TRANSFORMERS = False
|
| 18 |
+
torch = None
|
| 19 |
+
pipeline = None
|
| 20 |
+
|
| 21 |
+
class FrameNamer:
    """Intelligent frame naming based on pose analysis.

    Frames are characterized by simple geometric features (bounding box,
    centroid, area); the sequence of features is matched against motion
    heuristics to guess the animation type, which drives the naming.
    """

    def __init__(self):
        # Keyword vocabulary per animation type. NOTE(review): currently
        # unused by the heuristics below -- presumably reserved for a
        # future text/classifier-based matcher.
        self.pose_keywords = {
            'idle': ['standing', 'still', 'neutral', 'waiting'],
            'walk': ['walking', 'moving', 'step'],
            'run': ['running', 'fast', 'sprint'],
            'jump': ['jumping', 'leap', 'air'],
            'attack': ['attacking', 'strike', 'hit', 'swing'],
            'hurt': ['hurt', 'damage', 'hit', 'pain'],
            'die': ['dying', 'dead', 'fall'],
            'cast': ['casting', 'spell', 'magic'],
            'block': ['blocking', 'defend', 'guard'],
            'shoot': ['shooting', 'bow', 'arrow', 'ranged']
        }

        # Initialize pose classifier if available
        self.classifier = None
        self._init_classifier()

    def _init_classifier(self):
        """Initialize image classifier for pose detection."""
        try:
            # Try to load a lightweight classifier.
            # Note: In production, you'd use a custom-trained model.
            self.classifier = None  # Placeholder for actual model
        except Exception as e:
            print(f"Could not load classifier: {e}")
            self.classifier = None

    def name_frames(self, frames: List[np.ndarray]) -> List[str]:
        """
        Generate intelligent names for frames.

        Args:
            frames: List of frame images

        Returns:
            List of frame names, one per input frame
        """
        if len(frames) == 0:
            return []

        # Analyze each frame, then classify the whole sequence
        frame_features = [self._extract_features(frame) for frame in frames]
        animation_type = self._detect_animation_type(frame_features)

        total = len(frames)
        return [self._frame_name(animation_type, i, total) for i in range(total)]

    @staticmethod
    def _frame_name(animation_type: str, i: int, total: int) -> str:
        """Build the name for frame ``i`` of ``total`` given the animation type.

        Sequence-style animations get zero-padded indices; phase-style
        animations (jump, attack, cast, shoot, die) name their key frames.
        """
        if animation_type in ('idle', 'walk', 'run', 'hurt', 'block'):
            return f"{animation_type}_{i+1:02d}"
        if animation_type == 'jump':
            if i == 0:
                return "jump_start"
            if i == total - 1:
                return "jump_land"
            return f"jump_{i:02d}"
        if animation_type == 'attack':
            if i == 0:
                return "attack_windup"
            if i == total // 2:
                return "attack_strike"
            if i == total - 1:
                return "attack_recover"
            return f"attack_{i:02d}"
        if animation_type == 'die':
            if i == total - 1:
                return "die_dead"
            return f"die_{i+1:02d}"
        if animation_type == 'cast':
            if i == 0:
                return "cast_start"
            if i == total - 1:
                return "cast_release"
            return f"cast_{i:02d}"
        if animation_type == 'shoot':
            if i == 0:
                return "shoot_draw"
            if i == total - 1:
                return "shoot_release"
            return f"shoot_{i:02d}"
        # Unknown type: generic numbered frame
        return f"frame_{i+1:03d}"

    def _extract_features(self, frame: np.ndarray) -> dict:
        """
        Extract features from a frame for analysis.

        Args:
            frame: Input frame image (BGR or BGRA)

        Returns:
            Dictionary of features (bbox, width, height, center, centroid,
            area, com_height_ratio). ``area`` is a foreground pixel count
            in both paths (the BGR path previously summed a 0/255 mask,
            inflating it 255x; only relative comparisons consume it, so
            the classification heuristics are unaffected).
        """
        features = {}

        # Boolean foreground mask: alpha > 10 for BGRA, gray > 10 otherwise
        if len(frame.shape) == 3 and frame.shape[2] == 4:
            mask = frame[:, :, 3] > 10
        else:
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            mask = gray > 10

        # Bounding box of the foreground content
        coords = np.column_stack(np.where(mask))
        if len(coords) > 0:
            y_min, x_min = coords.min(axis=0)
            y_max, x_max = coords.max(axis=0)

            features['bbox'] = (x_min, y_min, x_max, y_max)
            features['width'] = x_max - x_min
            features['height'] = y_max - y_min
            features['center_x'] = (x_min + x_max) / 2
            features['center_y'] = (y_min + y_max) / 2
            features['aspect_ratio'] = features['width'] / max(features['height'], 1)

            # Centroid via image moments; fall back to the bbox center
            # for a degenerate (empty) mask
            moments = cv2.moments(mask.astype(np.uint8))
            if moments['m00'] > 0:
                features['centroid_x'] = moments['m10'] / moments['m00']
                features['centroid_y'] = moments['m01'] / moments['m00']
            else:
                features['centroid_x'] = features['center_x']
                features['centroid_y'] = features['center_y']

            # Foreground pixel count
            features['area'] = np.sum(mask)

            # Vertical position of the center of mass, 0 = top, 1 = bottom
            features['com_height_ratio'] = features['centroid_y'] / frame.shape[0]
        else:
            # Fully transparent frame: neutral defaults covering the frame
            features['bbox'] = (0, 0, frame.shape[1], frame.shape[0])
            features['width'] = frame.shape[1]
            features['height'] = frame.shape[0]
            features['center_x'] = frame.shape[1] / 2
            features['center_y'] = frame.shape[0] / 2
            features['aspect_ratio'] = 1.0
            features['centroid_x'] = frame.shape[1] / 2
            features['centroid_y'] = frame.shape[0] / 2
            features['area'] = 0
            features['com_height_ratio'] = 0.5

        return features

    def _detect_animation_type(self, frame_features: List[dict]) -> str:
        """
        Detect the type of animation based on frame features.

        Heuristics are checked in priority order: jump, attack, hurt/die,
        cast, block, shoot, run/walk; anything else is idle.

        Args:
            frame_features: List of feature dictionaries

        Returns:
            Animation type string
        """
        if len(frame_features) < 2:
            return 'idle'

        # Frame-to-frame motion metrics
        center_x_changes = []
        center_y_changes = []
        area_changes = []

        for i in range(1, len(frame_features)):
            prev = frame_features[i - 1]
            curr = frame_features[i]

            center_x_changes.append(abs(curr['center_x'] - prev['center_x']))
            center_y_changes.append(abs(curr['center_y'] - prev['center_y']))
            area_changes.append(abs(curr['area'] - prev['area']))

        avg_x_change = np.mean(center_x_changes) if center_x_changes else 0
        avg_y_change = np.mean(center_y_changes) if center_y_changes else 0
        avg_area_change = np.mean(area_changes) if area_changes else 0

        # Silhouette size variation across the sequence
        heights = [f['height'] for f in frame_features]
        height_variance = np.var(heights)
        max_height = max(heights)
        min_height = min(heights)
        height_range = max_height - min_height

        widths = [f['width'] for f in frame_features]
        width_variance = np.var(widths)

        # Jump: vertical movement dominates and the silhouette height varies
        if avg_y_change > avg_x_change * 1.5 and height_range > max_height * 0.15:
            return 'jump'

        # Attack: large area changes (weapon swing) or horizontal extension
        if avg_area_change > np.mean([f['area'] for f in frame_features]) * 0.1:
            return 'attack'

        # Hurt/Die: center of mass moves down; die if the sprite also shrinks
        if frame_features[-1]['com_height_ratio'] > frame_features[0]['com_height_ratio'] + 0.1:
            if frame_features[-1]['height'] < frame_features[0]['height'] * 0.7:
                return 'die'
            return 'hurt'

        # Cast: arms up (height rises toward the middle frame, then falls).
        # NOTE(review): variance (squared px) is compared to a length
        # threshold here -- dimensionally odd but kept for compatibility.
        if height_variance > max_height * 0.1:
            mid_idx = len(frame_features) // 2
            if (frame_features[mid_idx]['height'] > frame_features[0]['height'] and
                    frame_features[mid_idx]['height'] > frame_features[-1]['height']):
                return 'cast'

        # Block: stance widens noticeably at some point
        if frame_features[0]['width'] * 1.2 < max(widths):
            return 'block'

        # Shoot: moderate width variation (one arm extended)
        if width_variance > np.mean(widths) * 0.05:
            return 'shoot'

        # Run vs Walk: speed of horizontal movement relative to sprite width
        if avg_x_change > frame_features[0]['width'] * 0.15:
            return 'run'
        elif avg_x_change > frame_features[0]['width'] * 0.05:
            return 'walk'

        # Default to idle
        return 'idle'

    def _get_pose_variation(self, features: dict, all_features: List[dict],
                            index: int) -> str:
        """
        Get pose variation descriptor (start / mid / end of the sequence).

        NOTE(review): not used by ``name_frames`` anymore; kept for
        backward compatibility with any external callers.

        Args:
            features: Current frame features
            all_features: All frame features
            index: Current frame index

        Returns:
            Variation descriptor
        """
        if index == 0:
            return 'start'
        elif index == len(all_features) - 1:
            return 'end'
        else:
            return 'mid'

    def suggest_animation_name(self, frames: List[np.ndarray]) -> str:
        """
        Suggest a name for the entire animation.

        Args:
            frames: List of frame images

        Returns:
            Suggested animation name
        """
        animation_type = self._detect_animation_type([self._extract_features(f) for f in frames])

        suggestions = {
            'idle': 'character_idle',
            'walk': 'character_walk',
            'run': 'character_run',
            'jump': 'character_jump',
            'attack': 'character_attack',
            'hurt': 'character_hurt',
            'die': 'character_die',
            'cast': 'character_cast_spell',
            'block': 'character_block',
            'shoot': 'character_shoot'
        }

        return suggestions.get(animation_type, 'character_animation')
|
sprite_processor.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sprite Image Enhancement Module
|
| 3 |
+
Uses Real-ESRGAN for high-quality upscaling
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import cv2
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from PIL import Image
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
class SpriteProcessor:
|
| 13 |
+
"""Processor for enhancing sprite sheet images"""
|
| 14 |
+
|
| 15 |
+
    def __init__(self):
        # Run inference on the GPU when CUDA is available, else on CPU;
        # all Real-ESRGAN work happens on this device.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Real-ESRGAN upsampler instance; stays None when the library or
        # weights are unavailable (see _load_model), which callers treat
        # as "use fallback enhancement".
        self.model = None
        self._load_model()
|
| 19 |
+
|
| 20 |
+
    def _load_model(self):
        """Load Real-ESRGAN model.

        Best-effort: on any failure (missing realesrgan/basicsr packages,
        missing weights file, bad checkpoint) ``self.model`` is left as
        ``None`` and a warning is printed, so enhancement can fall back
        to non-ML methods.
        """
        try:
            # Imported lazily so the class is usable without these
            # optional third-party packages installed.
            from realesrgan import RealESRGANer
            from basicsr.archs.rrdbnet_arch import RRDBNet

            # Create model
            # These RRDBNet hyper-parameters must match the checkpoint
            # being loaded (RealESRGAN_x4plus).
            model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
                            num_block=23, num_grow_ch=32, scale=4)

            # Initialize Real-ESRGAN
            # NOTE(review): path is relative to the current working
            # directory -- confirm the app always launches from the
            # project root.
            model_path = "weights/RealESRGAN_x4plus.pth"

            if os.path.exists(model_path):
                self.model = RealESRGANer(
                    scale=4,
                    model_path=model_path,
                    model=model,
                    tile=0,       # 0 = no tiling: upscale the whole image at once
                    pre_pad=0,
                    half=False,   # full precision (fp16 unsafe on CPU)
                    device=self.device
                )
            else:
                print("Warning: Real-ESRGAN model not found, using fallback enhancement")
                self.model = None

        except Exception as e:
            # ImportError (packages absent) or checkpoint/load errors land
            # here; degrade gracefully to the non-ML path.
            print(f"Error loading Real-ESRGAN: {e}")
            self.model = None
|
| 50 |
+
|
| 51 |
+
def enhance_image(self, image: np.ndarray, scale: int = 4) -> np.ndarray:
|
| 52 |
+
"""
|
| 53 |
+
Enhance image quality using Real-ESRGAN or fallback methods
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
image: Input image (BGR or BGRA)
|
| 57 |
+
scale: Upscaling factor (2 or 4)
|
| 58 |
+
|
| 59 |
+
Returns:
|
| 60 |
+
Enhanced image
|
| 61 |
+
"""
|
| 62 |
+
# Handle alpha channel
|
| 63 |
+
has_alpha = len(image.shape) == 3 and image.shape[2] == 4
|
| 64 |
+
|
| 65 |
+
if has_alpha:
|
| 66 |
+
# Separate alpha channel
|
| 67 |
+
bgr = image[:, :, :3]
|
| 68 |
+
alpha = image[:, :, 3]
|
| 69 |
+
else:
|
| 70 |
+
bgr = image
|
| 71 |
+
alpha = None
|
| 72 |
+
|
| 73 |
+
# Enhance RGB channels
|
| 74 |
+
if self.model is not None and scale > 1:
|
| 75 |
+
try:
|
| 76 |
+
# Convert BGR to RGB for the model
|
| 77 |
+
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
|
| 78 |
+
|
| 79 |
+
# Apply Real-ESRGAN
|
| 80 |
+
enhanced_rgb, _ = self.model.enhance(rgb, outscale=scale)
|
| 81 |
+
|
| 82 |
+
# Convert back to BGR
|
| 83 |
+
enhanced_bgr = cv2.cvtColor(enhanced_rgb, cv2.COLOR_RGB2BGR)
|
| 84 |
+
|
| 85 |
+
except Exception as e:
|
| 86 |
+
print(f"Real-ESRGAN failed, using fallback: {e}")
|
| 87 |
+
enhanced_bgr = self._fallback_enhance(bgr, scale)
|
| 88 |
+
else:
|
| 89 |
+
enhanced_bgr = self._fallback_enhance(bgr, scale)
|
| 90 |
+
|
| 91 |
+
# Enhance alpha channel if present
|
| 92 |
+
if alpha is not None and scale > 1:
|
| 93 |
+
enhanced_alpha = cv2.resize(alpha, None, fx=scale, fy=scale,
|
| 94 |
+
interpolation=cv2.INTER_NEAREST)
|
| 95 |
+
|
| 96 |
+
# Merge channels
|
| 97 |
+
enhanced_image = cv2.merge([enhanced_bgr, enhanced_alpha])
|
| 98 |
+
else:
|
| 99 |
+
enhanced_image = enhanced_bgr
|
| 100 |
+
|
| 101 |
+
return enhanced_image
|
| 102 |
+
|
| 103 |
+
def _fallback_enhance(self, image: np.ndarray, scale: int) -> np.ndarray:
|
| 104 |
+
"""
|
| 105 |
+
Fallback enhancement using OpenCV
|
| 106 |
+
|
| 107 |
+
Args:
|
| 108 |
+
image: Input BGR image
|
| 109 |
+
scale: Upscaling factor
|
| 110 |
+
|
| 111 |
+
Returns:
|
| 112 |
+
Enhanced image
|
| 113 |
+
"""
|
| 114 |
+
# Resize with high-quality interpolation
|
| 115 |
+
new_width = int(image.shape[1] * scale)
|
| 116 |
+
new_height = int(image.shape[0] * scale)
|
| 117 |
+
|
| 118 |
+
enhanced = cv2.resize(image, (new_width, new_height),
|
| 119 |
+
interpolation=cv2.INTER_CUBIC)
|
| 120 |
+
|
| 121 |
+
# Apply sharpening
|
| 122 |
+
kernel = np.array([[-1, -1, -1],
|
| 123 |
+
[-1, 9, -1],
|
| 124 |
+
[-1, -1, -1]])
|
| 125 |
+
enhanced = cv2.filter2D(enhanced, -1, kernel)
|
| 126 |
+
|
| 127 |
+
# Denoise
|
| 128 |
+
enhanced = cv2.fastNlMeansDenoisingColored(enhanced, None, 5, 5, 7, 21)
|
| 129 |
+
|
| 130 |
+
return enhanced
|
| 131 |
+
|
| 132 |
+
def sharpen_image(self, image: np.ndarray, strength: float = 1.0) -> np.ndarray:
|
| 133 |
+
"""
|
| 134 |
+
Apply sharpening filter
|
| 135 |
+
|
| 136 |
+
Args:
|
| 137 |
+
image: Input image
|
| 138 |
+
strength: Sharpening strength
|
| 139 |
+
|
| 140 |
+
Returns:
|
| 141 |
+
Sharpened image
|
| 142 |
+
"""
|
| 143 |
+
kernel = np.array([[-1, -1, -1],
|
| 144 |
+
[-1, 9, -1],
|
| 145 |
+
[-1, -1, -1]]) * strength
|
| 146 |
+
|
| 147 |
+
sharpened = cv2.filter2D(image, -1, kernel)
|
| 148 |
+
return sharpened
|
| 149 |
+
|
| 150 |
+
def remove_blur(self, image: np.ndarray) -> np.ndarray:
|
| 151 |
+
"""
|
| 152 |
+
Reduce blur using deconvolution
|
| 153 |
+
|
| 154 |
+
Args:
|
| 155 |
+
image: Input image
|
| 156 |
+
|
| 157 |
+
Returns:
|
| 158 |
+
Deblurred image
|
| 159 |
+
"""
|
| 160 |
+
# Create a point spread function (PSF)
|
| 161 |
+
psf_size = 5
|
| 162 |
+
psf = np.ones((psf_size, psf_size)) / (psf_size ** 2)
|
| 163 |
+
|
| 164 |
+
# Simple deconvolution (Wiener filter approximation)
|
| 165 |
+
result = image.copy()
|
| 166 |
+
|
| 167 |
+
for i in range(3): # For each channel
|
| 168 |
+
channel = image[:, :, i].astype(np.float32) / 255.0
|
| 169 |
+
|
| 170 |
+
# FFT
|
| 171 |
+
psf_fft = np.fft.fft2(psf, s=channel.shape)
|
| 172 |
+
channel_fft = np.fft.fft2(channel)
|
| 173 |
+
|
| 174 |
+
# Wiener deconvolution
|
| 175 |
+
K = 0.01 # Noise to signal ratio
|
| 176 |
+
deconv_fft = channel_fft * np.conj(psf_fft) / (np.abs(psf_fft) ** 2 + K)
|
| 177 |
+
|
| 178 |
+
# Inverse FFT
|
| 179 |
+
deconv = np.fft.ifft2(deconv_fft).real
|
| 180 |
+
|
| 181 |
+
# Clip and convert back
|
| 182 |
+
deconv = np.clip(deconv * 255, 0, 255).astype(np.uint8)
|
| 183 |
+
result[:, :, i] = deconv
|
| 184 |
+
|
| 185 |
+
return result
|
| 186 |
+
|
| 187 |
+
def enhance_contrast(self, image: np.ndarray) -> np.ndarray:
|
| 188 |
+
"""
|
| 189 |
+
Enhance contrast using CLAHE
|
| 190 |
+
|
| 191 |
+
Args:
|
| 192 |
+
image: Input image
|
| 193 |
+
|
| 194 |
+
Returns:
|
| 195 |
+
Contrast-enhanced image
|
| 196 |
+
"""
|
| 197 |
+
lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
|
| 198 |
+
l, a, b = cv2.split(lab)
|
| 199 |
+
|
| 200 |
+
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
| 201 |
+
l = clahe.apply(l)
|
| 202 |
+
|
| 203 |
+
enhanced = cv2.merge([l, a, b])
|
| 204 |
+
enhanced = cv2.cvtColor(enhanced, cv2.COLOR_LAB2BGR)
|
| 205 |
+
|
| 206 |
+
return enhanced
|