Skydata001 committed on
Commit
925fbb1
·
verified ·
1 Parent(s): 55a599e

Upload 3 files

Browse files
Files changed (3) hide show
  1. frame_detector.py +346 -0
  2. frame_namer.py +329 -0
  3. sprite_processor.py +206 -0
frame_detector.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Frame Detection Module
3
+ Automatically detects and extracts sprite frames from sprite sheets
4
+ """
5
+
6
+ import cv2
7
+ import numpy as np
8
+ from typing import List, Tuple
9
+ from scipy import ndimage
10
+ from skimage import measure
11
+
12
class FrameDetector:
    """Detector for finding sprite frames in sprite sheets.

    Detection strategy, in order of preference:
      1. Connected-component analysis on a foreground mask
         (``detect_frames_auto``).
      2. Projection-based grid splitting (``_detect_grid_based``).
      3. Equal horizontal division (``_detect_equal_division``).
    """

    def __init__(self):
        self.min_frame_size = 16   # Minimum accepted frame side, in pixels
        self.max_frame_size = 512  # Maximum accepted frame side, in pixels

    @staticmethod
    def _binary_mask(image: np.ndarray) -> np.ndarray:
        """Build a 0/255 foreground mask for *image*.

        Uses the alpha channel when the image is BGRA, otherwise the
        grayscale intensity; pixels with value > 10 count as foreground.
        (Extracted helper: this logic was previously duplicated in three
        methods.)
        """
        if len(image.shape) == 3 and image.shape[2] == 4:
            alpha = image[:, :, 3]
            _, binary = cv2.threshold(alpha, 10, 255, cv2.THRESH_BINARY)
        else:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            _, binary = cv2.threshold(gray, 10, 255, cv2.THRESH_BINARY)
        return binary

    @staticmethod
    def _run_boundaries(gaps: np.ndarray) -> List[int]:
        """Return alternating [start, end, start, end, ...] indices of the
        non-gap runs in the boolean array *gaps*.

        A run that extends to the end of the array produces no closing
        index; the caller is responsible for closing it.
        """
        boundaries = []
        in_gap = True
        for i, is_gap in enumerate(gaps):
            if in_gap and not is_gap:
                boundaries.append(i)   # content run starts
                in_gap = False
            elif not in_gap and is_gap:
                boundaries.append(i)   # content run ends
                in_gap = True
        return boundaries

    def detect_frames_auto(self, image: np.ndarray, padding: int = 2) -> Tuple[List[np.ndarray], List[Tuple]]:
        """
        Automatically detect frames in a sprite sheet.

        Args:
            image: Input sprite sheet image (BGR or BGRA)
            padding: Padding to add around each frame, in pixels

        Returns:
            Tuple of (list of frame images, list of bounding boxes
            (minr, minc, maxr, maxc))
        """
        binary = self._binary_mask(image)

        # Label connected foreground components (connectivity=2 -> 8-connected)
        labels = measure.label(binary, connectivity=2)
        regions = measure.regionprops(labels)

        # Keep only components whose bounding box is a plausible frame size
        valid_regions = []
        for region in regions:
            minr, minc, maxr, maxc = region.bbox
            width = maxc - minc
            height = maxr - minr

            if (self.min_frame_size <= width <= self.max_frame_size and
                    self.min_frame_size <= height <= self.max_frame_size):
                valid_regions.append(region)

        # No plausible components: fall back to grid-based detection
        if len(valid_regions) == 0:
            return self._detect_grid_based(image, padding)

        # Sort left to right by column.
        # NOTE(review): this assumes a single-row sheet; multi-row sheets
        # still sort purely by x — confirm that is acceptable upstream.
        valid_regions.sort(key=lambda r: r.bbox[1])

        frames = []
        frame_boxes = []

        for region in valid_regions:
            minr, minc, maxr, maxc = region.bbox

            # Expand by `padding`, clamped to the image bounds
            minr = max(0, minr - padding)
            minc = max(0, minc - padding)
            maxr = min(image.shape[0], maxr + padding)
            maxc = min(image.shape[1], maxc + padding)

            frames.append(image[minr:maxr, minc:maxc])
            frame_boxes.append((minr, minc, maxr, maxc))

        # Many tiny components usually mean the sprites shattered into
        # fragments; the projection-based grid split handles that better.
        if len(frames) > 20:
            return self._detect_grid_based(image, padding)

        return frames, frame_boxes

    def _detect_grid_based(self, image: np.ndarray, padding: int = 2) -> Tuple[List[np.ndarray], List[Tuple]]:
        """
        Detect frames by splitting on low-density projection gaps.

        Args:
            image: Input sprite sheet image
            padding: Padding to add around each frame, in pixels

        Returns:
            Tuple of (list of frame images, list of bounding boxes)
        """
        binary = self._binary_mask(image)

        # Remove small holes and speckles before projecting
        kernel = np.ones((3, 3), np.uint8)
        binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)
        binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel)

        # Row / column density profiles
        h_proj = np.sum(binary, axis=1)
        v_proj = np.sum(binary, axis=0)

        # A row/column is a "gap" when its density drops below 10% of max
        h_gaps = h_proj < np.max(h_proj) * 0.1
        v_gaps = v_proj < np.max(v_proj) * 0.1

        row_boundaries = self._run_boundaries(h_gaps)
        col_boundaries = self._run_boundaries(v_gaps)

        # A run reaching the image edge has no closing boundary; close it
        # at the edge so boundaries pair up as (start, end)
        if len(row_boundaries) % 2 != 0:
            row_boundaries.append(binary.shape[0])
        if len(col_boundaries) % 2 != 0:
            col_boundaries.append(binary.shape[1])

        frames = []
        frame_boxes = []

        if len(row_boundaries) >= 2 and len(col_boundaries) >= 2:
            for i in range(0, len(row_boundaries), 2):
                for j in range(0, len(col_boundaries), 2):
                    minr = row_boundaries[i]
                    maxr = row_boundaries[i + 1]
                    minc = col_boundaries[j]
                    maxc = col_boundaries[j + 1]

                    # Expand by `padding`, clamped to the image bounds
                    minr_p = max(0, minr - padding)
                    minc_p = max(0, minc - padding)
                    maxr_p = min(image.shape[0], maxr + padding)
                    maxc_p = min(image.shape[1], maxc + padding)

                    frame = image[minr_p:maxr_p, minc_p:maxc_p]

                    # Size of the cell before padding
                    frame_width = maxc - minc
                    frame_height = maxr - minr

                    # Filter out very small cells (likely effects/particles):
                    # at least 15 px or 2% of image width, and at least
                    # 20 px or a third of image height
                    min_content_width = max(15, image.shape[1] // 50)
                    min_content_height = max(20, image.shape[0] // 3)

                    has_content = np.sum(frame) > 0
                    has_valid_size = (frame_width >= min_content_width and
                                      frame_height >= min_content_height)

                    if has_content and has_valid_size:
                        frames.append(frame)
                        frame_boxes.append((minr_p, minc_p, maxr_p, maxc_p))

        # Still nothing: last resort is equal division
        if len(frames) == 0:
            return self._detect_equal_division(image, padding)

        return frames, frame_boxes

    def _detect_equal_division(self, image: np.ndarray, padding: int = 2,
                               num_frames: int = 8) -> Tuple[List[np.ndarray], List[Tuple]]:
        """
        Detect frames by equal division (assumes a horizontal strip layout).

        Args:
            image: Input sprite sheet image
            padding: Padding to add around each frame, in pixels
            num_frames: Number of frames to divide into

        Returns:
            Tuple of (list of frame images, list of bounding boxes)
        """
        frames = []
        frame_boxes = []

        img_width = image.shape[1]
        img_height = image.shape[0]

        # Assume horizontal layout: one row, `num_frames` columns
        frame_width = img_width // num_frames

        for i in range(num_frames):
            minc = i * frame_width
            # Last frame absorbs the rounding remainder
            maxc = (i + 1) * frame_width if i < num_frames - 1 else img_width
            minr = 0
            maxr = img_height

            # Expand by `padding`, clamped to the image bounds
            minc_p = max(0, minc - padding)
            minr_p = max(0, minr - padding)
            maxc_p = min(img_width, maxc + padding)
            maxr_p = min(img_height, maxr + padding)

            frames.append(image[minr_p:maxr_p, minc_p:maxc_p])
            frame_boxes.append((minr_p, minc_p, maxr_p, maxc_p))

        return frames, frame_boxes

    def detect_frames_manual(self, image: np.ndarray, num_frames: int,
                             padding: int = 2) -> Tuple[List[np.ndarray], List[Tuple]]:
        """
        Manually specify number of frames (equal horizontal division).

        Args:
            image: Input sprite sheet image
            num_frames: Number of frames
            padding: Padding to add around each frame, in pixels

        Returns:
            Tuple of (list of frame images, list of bounding boxes)
        """
        return self._detect_equal_division(image, padding, num_frames)

    def refine_frame_boundaries(self, image: np.ndarray, frame: np.ndarray,
                                bbox: Tuple) -> Tuple[np.ndarray, Tuple]:
        """
        Refine frame boundaries to remove excess transparent space.

        Args:
            image: Original image
            frame: Extracted frame
            bbox: Bounding box (minr, minc, maxr, maxc) of `frame` in `image`

        Returns:
            Refined frame and bounding box; the inputs unchanged if the
            frame has no visible content.
        """
        minr, minc, maxr, maxc = bbox

        # Per-row / per-column content flags (> 10 counts as content)
        if len(frame.shape) == 3 and frame.shape[2] == 4:
            alpha = frame[:, :, 3]
            rows = np.any(alpha > 10, axis=1)
            cols = np.any(alpha > 10, axis=0)
        else:
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            rows = np.any(gray > 10, axis=1)
            cols = np.any(gray > 10, axis=0)

        row_indices = np.where(rows)[0]
        col_indices = np.where(cols)[0]

        if len(row_indices) > 0 and len(col_indices) > 0:
            # Tight bounds in original-image coordinates
            new_minr = minr + row_indices[0]
            new_maxr = minr + row_indices[-1] + 1
            new_minc = minc + col_indices[0]
            new_maxc = minc + col_indices[-1] + 1

            refined_frame = image[new_minr:new_maxr, new_minc:new_maxc]

            return refined_frame, (new_minr, new_minc, new_maxr, new_maxc)

        return frame, bbox

    def detect_frame_size(self, image: np.ndarray) -> Tuple[int, int]:
        """
        Detect the size of individual frames from vertical gap spacing.

        Args:
            image: Input sprite sheet image

        Returns:
            Tuple of (frame_width, frame_height); frame_height is always
            the full image height.
        """
        binary = self._binary_mask(image)

        # Column density profile; columns below 10% of max are gaps
        v_proj = np.sum(binary, axis=0)
        gaps = v_proj < np.max(v_proj) * 0.1

        # Collect the column index where each gap run begins
        # (previous dead `gap_ends` bookkeeping removed)
        gap_starts = []
        in_gap = False
        for i, is_gap in enumerate(gaps):
            if not in_gap and is_gap:
                gap_starts.append(i)
                in_gap = True
            elif in_gap and not is_gap:
                in_gap = False

        if len(gap_starts) > 1:
            # Average spacing between successive gaps
            frame_width = int(np.mean(np.diff(gap_starts)))
        elif len(gap_starts) == 1:
            # Single gap: frame ends where the gap starts
            frame_width = gap_starts[0]
        else:
            # No gaps found: assume a single frame spanning the image
            frame_width = image.shape[1]

        frame_height = image.shape[0]

        return frame_width, frame_height
frame_namer.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Smart Frame Naming Module
3
+ Automatically names frames based on their content/pose
4
+ """
5
+
6
+ import cv2
7
+ import numpy as np
8
+ from typing import List
9
+ from scipy.spatial import distance
10
+
11
+ # Optional imports for advanced features
12
+ try:
13
+ import torch
14
+ from transformers import pipeline
15
+ HAS_TRANSFORMERS = True
16
+ except ImportError:
17
+ HAS_TRANSFORMERS = False
18
+ torch = None
19
+ pipeline = None
20
+
21
class FrameNamer:
    """Intelligent frame naming based on pose analysis.

    Heuristically classifies a frame sequence into an animation type
    (idle, walk, jump, ...) from simple geometric features, then produces
    per-frame names such as ``jump_start`` or ``attack_strike``.
    """

    def __init__(self):
        # Keyword hints per animation type (reserved for classifier use).
        self.pose_keywords = {
            'idle': ['standing', 'still', 'neutral', 'waiting'],
            'walk': ['walking', 'moving', 'step'],
            'run': ['running', 'fast', 'sprint'],
            'jump': ['jumping', 'leap', 'air'],
            'attack': ['attacking', 'strike', 'hit', 'swing'],
            'hurt': ['hurt', 'damage', 'hit', 'pain'],
            'die': ['dying', 'dead', 'fall'],
            'cast': ['casting', 'spell', 'magic'],
            'block': ['blocking', 'defend', 'guard'],
            'shoot': ['shooting', 'bow', 'arrow', 'ranged']
        }

        # Optional ML pose classifier; currently always None.
        self.classifier = None
        self._init_classifier()

    def _init_classifier(self):
        """Initialize the (placeholder) image classifier for pose detection."""
        try:
            # A custom-trained model would be loaded here in production.
            self.classifier = None  # Placeholder for actual model
        except Exception as e:
            print(f"Could not load classifier: {e}")
            self.classifier = None

    def name_frames(self, frames: List[np.ndarray]) -> List[str]:
        """Generate intelligent names for frames.

        Args:
            frames: List of frame images

        Returns:
            List of frame names, one per input frame
        """
        if not frames:
            return []

        # Per-frame geometric features drive the classification.
        frame_features = [self._extract_features(f) for f in frames]
        animation_type = self._detect_animation_type(frame_features)

        total = len(frames)
        return [self._build_name(animation_type, idx, total)
                for idx in range(total)]

    @staticmethod
    def _build_name(animation_type: str, i: int, total: int) -> str:
        """Return the name of frame *i* (0-based) of *total* frames for the
        given animation type."""
        last = total - 1

        # Simple numbered sequences
        if animation_type in ('idle', 'walk', 'run', 'hurt', 'block'):
            return f"{animation_type}_{i + 1:02d}"

        if animation_type == 'jump':
            if i == 0:
                return "jump_start"
            if i == last:
                return "jump_land"
            return f"jump_{i:02d}"

        if animation_type == 'attack':
            # Checked in this order on purpose: the midpoint wins over the
            # last frame when they coincide.
            if i == 0:
                return "attack_windup"
            if i == total // 2:
                return "attack_strike"
            if i == last:
                return "attack_recover"
            return f"attack_{i:02d}"

        if animation_type == 'die':
            return "die_dead" if i == last else f"die_{i + 1:02d}"

        if animation_type == 'cast':
            if i == 0:
                return "cast_start"
            if i == last:
                return "cast_release"
            return f"cast_{i:02d}"

        if animation_type == 'shoot':
            if i == 0:
                return "shoot_draw"
            if i == last:
                return "shoot_release"
            return f"shoot_{i:02d}"

        # Unknown type: plain numbering
        return f"frame_{i + 1:03d}"

    def _extract_features(self, frame: np.ndarray) -> dict:
        """Extract geometric features from a frame for analysis.

        Args:
            frame: Input frame image (BGR or BGRA)

        Returns:
            Dictionary with bbox, width/height, center, centroid, area and
            center-of-mass height ratio.
        """
        frame_h = frame.shape[0]
        frame_w = frame.shape[1]

        # Foreground mask: alpha channel if present, else intensity.
        if len(frame.shape) == 3 and frame.shape[2] == 4:
            mask = frame[:, :, 3] > 10
        else:
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            _, mask = cv2.threshold(gray, 10, 255, cv2.THRESH_BINARY)

        coords = np.column_stack(np.where(mask))
        if len(coords) == 0:
            # No visible content: whole-frame defaults.
            return {
                'bbox': (0, 0, frame_w, frame_h),
                'width': frame_w,
                'height': frame_h,
                'center_x': frame_w / 2,
                'center_y': frame_h / 2,
                'aspect_ratio': 1.0,
                'centroid_x': frame_w / 2,
                'centroid_y': frame_h / 2,
                'area': 0,
                'com_height_ratio': 0.5,
            }

        y_min, x_min = coords.min(axis=0)
        y_max, x_max = coords.max(axis=0)

        features = {
            'bbox': (x_min, y_min, x_max, y_max),
            'width': x_max - x_min,
            'height': y_max - y_min,
            'center_x': (x_min + x_max) / 2,
            'center_y': (y_min + y_max) / 2,
        }
        features['aspect_ratio'] = features['width'] / max(features['height'], 1)

        # Centroid from image moments; fall back to bbox center when empty.
        moments = cv2.moments(mask.astype(np.uint8))
        if moments['m00'] > 0:
            features['centroid_x'] = moments['m10'] / moments['m00']
            features['centroid_y'] = moments['m01'] / moments['m00']
        else:
            features['centroid_x'] = features['center_x']
            features['centroid_y'] = features['center_y']

        features['area'] = np.sum(mask)
        features['com_height_ratio'] = features['centroid_y'] / frame_h
        return features

    def _detect_animation_type(self, frame_features: List[dict]) -> str:
        """Detect the type of animation based on frame features.

        Args:
            frame_features: List of feature dictionaries

        Returns:
            Animation type string (e.g. 'idle', 'jump', 'attack')
        """
        if len(frame_features) < 2:
            return 'idle'

        # Frame-to-frame deltas of position, area and center of mass.
        dx, dy, darea = [], [], []
        for prev, curr in zip(frame_features, frame_features[1:]):
            dx.append(abs(curr['center_x'] - prev['center_x']))
            dy.append(abs(curr['center_y'] - prev['center_y']))
            darea.append(abs(curr['area'] - prev['area']))

        avg_x_change = np.mean(dx) if dx else 0
        avg_y_change = np.mean(dy) if dy else 0
        avg_area_change = np.mean(darea) if darea else 0

        first = frame_features[0]
        last = frame_features[-1]

        heights = [f['height'] for f in frame_features]
        widths = [f['width'] for f in frame_features]
        max_height = max(heights)
        height_range = max_height - min(heights)
        height_variance = np.var(heights)
        width_variance = np.var(widths)

        # Jump: vertical motion dominates and silhouette height varies
        if avg_y_change > avg_x_change * 1.5 and height_range > max_height * 0.15:
            return 'jump'

        # Attack: large area swings (weapon arcs)
        if avg_area_change > np.mean([f['area'] for f in frame_features]) * 0.1:
            return 'attack'

        # Hurt / Die: center of mass drops; die if the figure also collapses
        if last['com_height_ratio'] > first['com_height_ratio'] + 0.1:
            if last['height'] < first['height'] * 0.7:
                return 'die'
            return 'hurt'

        # Cast: height peaks mid-animation (arms raised)
        if height_variance > max_height * 0.1:
            mid = frame_features[len(frame_features) // 2]
            if mid['height'] > first['height'] and mid['height'] > last['height']:
                return 'cast'

        # Block: stance widens
        if first['width'] * 1.2 < max(widths):
            return 'block'

        # Shoot: width varies (arm extension)
        if width_variance > np.mean(widths) * 0.05:
            return 'shoot'

        # Run vs walk: speed of horizontal movement
        if avg_x_change > first['width'] * 0.15:
            return 'run'
        if avg_x_change > first['width'] * 0.05:
            return 'walk'

        return 'idle'

    def _get_pose_variation(self, features: dict, all_features: List[dict],
                            index: int) -> str:
        """Classify a frame's position within the sequence.

        Args:
            features: Current frame features (unused by this heuristic)
            all_features: All frame features
            index: Current frame index

        Returns:
            'start', 'mid' or 'end'
        """
        if index == 0:
            return 'start'
        return 'end' if index == len(all_features) - 1 else 'mid'

    def suggest_animation_name(self, frames: List[np.ndarray]) -> str:
        """Suggest a name for the entire animation.

        Args:
            frames: List of frame images

        Returns:
            Suggested animation name
        """
        detected = self._detect_animation_type(
            [self._extract_features(f) for f in frames])

        suggestions = {
            'idle': 'character_idle',
            'walk': 'character_walk',
            'run': 'character_run',
            'jump': 'character_jump',
            'attack': 'character_attack',
            'hurt': 'character_hurt',
            'die': 'character_die',
            'cast': 'character_cast_spell',
            'block': 'character_block',
            'shoot': 'character_shoot'
        }

        return suggestions.get(detected, 'character_animation')
sprite_processor.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sprite Image Enhancement Module
3
+ Uses Real-ESRGAN for high-quality upscaling
4
+ """
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+ from PIL import Image
10
+ import os
11
+
12
class SpriteProcessor:
    """Processor for enhancing sprite sheet images.

    Uses Real-ESRGAN for high-quality upscaling when the model weights are
    available, otherwise falls back to OpenCV-based resize + sharpen +
    denoise.
    """

    def __init__(self):
        # Prefer GPU when available; Real-ESRGAN runs on either device.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = None
        self._load_model()

    def _load_model(self):
        """Load the Real-ESRGAN x4 model if its weights exist on disk.

        Leaves ``self.model`` as ``None`` (fallback mode) when the package
        or the weight file is missing.
        """
        try:
            from realesrgan import RealESRGANer
            from basicsr.archs.rrdbnet_arch import RRDBNet

            # Standard RealESRGAN_x4plus architecture
            model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
                            num_block=23, num_grow_ch=32, scale=4)

            model_path = "weights/RealESRGAN_x4plus.pth"

            if os.path.exists(model_path):
                self.model = RealESRGANer(
                    scale=4,
                    model_path=model_path,
                    model=model,
                    tile=0,
                    pre_pad=0,
                    half=False,
                    device=self.device
                )
            else:
                print("Warning: Real-ESRGAN model not found, using fallback enhancement")
                self.model = None

        except Exception as e:
            print(f"Error loading Real-ESRGAN: {e}")
            self.model = None

    def enhance_image(self, image: np.ndarray, scale: int = 4) -> np.ndarray:
        """
        Enhance image quality using Real-ESRGAN or fallback methods.

        Args:
            image: Input image (BGR or BGRA)
            scale: Upscaling factor (2 or 4)

        Returns:
            Enhanced image with the same channel count as the input.
            BUG FIX: the alpha channel is now preserved even when
            scale <= 1 (it was previously dropped in that case).
        """
        has_alpha = len(image.shape) == 3 and image.shape[2] == 4

        if has_alpha:
            # Process color and alpha separately
            bgr = image[:, :, :3]
            alpha = image[:, :, 3]
        else:
            bgr = image
            alpha = None

        # Enhance color channels
        if self.model is not None and scale > 1:
            try:
                # Real-ESRGAN expects RGB ordering
                rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
                enhanced_rgb, _ = self.model.enhance(rgb, outscale=scale)
                enhanced_bgr = cv2.cvtColor(enhanced_rgb, cv2.COLOR_RGB2BGR)
            except Exception as e:
                print(f"Real-ESRGAN failed, using fallback: {e}")
                enhanced_bgr = self._fallback_enhance(bgr, scale)
        else:
            enhanced_bgr = self._fallback_enhance(bgr, scale)

        if alpha is not None:
            # Resize alpha to the enhanced color plane's actual size so the
            # merge below cannot fail on a size mismatch. Nearest-neighbor
            # keeps hard transparency edges crisp.
            if alpha.shape[:2] != enhanced_bgr.shape[:2]:
                enhanced_alpha = cv2.resize(
                    alpha, (enhanced_bgr.shape[1], enhanced_bgr.shape[0]),
                    interpolation=cv2.INTER_NEAREST)
            else:
                enhanced_alpha = alpha
            enhanced_image = cv2.merge([enhanced_bgr, enhanced_alpha])
        else:
            enhanced_image = enhanced_bgr

        return enhanced_image

    def _fallback_enhance(self, image: np.ndarray, scale: int) -> np.ndarray:
        """
        Fallback enhancement using OpenCV: cubic upscale, sharpen, denoise.

        Args:
            image: Input BGR image
            scale: Upscaling factor

        Returns:
            Enhanced image
        """
        new_width = int(image.shape[1] * scale)
        new_height = int(image.shape[0] * scale)

        enhanced = cv2.resize(image, (new_width, new_height),
                              interpolation=cv2.INTER_CUBIC)

        # Unsharp-style kernel (sums to 1, so brightness is preserved)
        kernel = np.array([[-1, -1, -1],
                           [-1, 9, -1],
                           [-1, -1, -1]])
        enhanced = cv2.filter2D(enhanced, -1, kernel)

        # Denoise the sharpened result
        enhanced = cv2.fastNlMeansDenoisingColored(enhanced, None, 5, 5, 7, 21)

        return enhanced

    def sharpen_image(self, image: np.ndarray, strength: float = 1.0) -> np.ndarray:
        """
        Apply a brightness-preserving sharpening filter.

        Args:
            image: Input image
            strength: Sharpening strength (1.0 reproduces the classic
                [[-1,-1,-1],[-1,9,-1],[-1,-1,-1]] kernel)

        Returns:
            Sharpened image

        BUG FIX: the previous ``kernel * strength`` summed to ``strength``,
        scaling overall brightness. This kernel sums to 1 for every
        strength and is identical to the original at strength == 1.0.
        """
        s = strength
        kernel = np.array([[-s, -s, -s],
                           [-s, 1 + 8 * s, -s],
                           [-s, -s, -s]])

        return cv2.filter2D(image, -1, kernel)

    def remove_blur(self, image: np.ndarray) -> np.ndarray:
        """
        Reduce blur using Wiener-filter deconvolution (per color channel).

        Args:
            image: Input image with at least 3 channels; channels beyond
                the first three (e.g. alpha) are left untouched.

        Returns:
            Deblurred image

        NOTE(review): the PSF is not centered before the FFT, so the
        output may be shifted by ~psf_size // 2 pixels — confirm this is
        acceptable for the pipeline.
        """
        # Uniform box point spread function
        psf_size = 5
        psf = np.ones((psf_size, psf_size)) / (psf_size ** 2)

        result = image.copy()

        for i in range(3):  # BGR channels only
            channel = image[:, :, i].astype(np.float32) / 255.0

            # Frequency-domain representations
            psf_fft = np.fft.fft2(psf, s=channel.shape)
            channel_fft = np.fft.fft2(channel)

            # Wiener deconvolution with a fixed noise-to-signal ratio
            K = 0.01
            deconv_fft = channel_fft * np.conj(psf_fft) / (np.abs(psf_fft) ** 2 + K)

            deconv = np.fft.ifft2(deconv_fft).real

            # Back to 8-bit range
            deconv = np.clip(deconv * 255, 0, 255).astype(np.uint8)
            result[:, :, i] = deconv

        return result

    def enhance_contrast(self, image: np.ndarray) -> np.ndarray:
        """
        Enhance contrast using CLAHE on the L channel in LAB space.

        Args:
            image: Input BGR image

        Returns:
            Contrast-enhanced BGR image
        """
        lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
        l, a, b = cv2.split(lab)

        # Equalize lightness only, leaving chroma untouched
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        l = clahe.apply(l)

        enhanced = cv2.merge([l, a, b])
        enhanced = cv2.cvtColor(enhanced, cv2.COLOR_LAB2BGR)

        return enhanced