# Source: welcom / frame_detector.py
# Uploaded by Skydata001 ("Upload 3 files"), commit 925fbb1 (verified)
"""
Frame Detection Module
Automatically detects and extracts sprite frames from sprite sheets
"""
import cv2
import numpy as np
from typing import List, Tuple
from scipy import ndimage
from skimage import measure
class FrameDetector:
    """Detector for finding sprite frames in sprite sheets.

    :meth:`detect_frames_auto` tries, in order, connected-component detection,
    projection-based grid detection, and an equal horizontal split. Every
    bounding box produced by this class is a ``(minr, minc, maxr, maxc)``
    half-open pixel range that already includes the requested padding and is
    clamped to the image.
    """

    # Default content cut-off: a pixel is content when its alpha (4-channel
    # images) or grayscale value is strictly greater than this.
    content_threshold = 10

    # When connected-component detection yields more valid regions than this,
    # the sheet is treated as noisy and grid detection is used instead.
    _MAX_AUTO_REGIONS = 20

    def __init__(self, min_frame_size: int = 16, max_frame_size: int = 512,
                 content_threshold: int = 10):
        """
        Args:
            min_frame_size: Minimum width/height in pixels for a connected
                component to be accepted as a frame
            max_frame_size: Maximum width/height in pixels for a connected
                component to be accepted as a frame
            content_threshold: Alpha/grayscale value a pixel must exceed to be
                treated as sprite content
        """
        self.min_frame_size = min_frame_size  # Minimum frame size in pixels
        self.max_frame_size = max_frame_size  # Maximum frame size in pixels
        self.content_threshold = content_threshold

    # ------------------------------------------------------------------ helpers

    def _binary_mask(self, image: np.ndarray) -> np.ndarray:
        """Return a uint8 content mask (255 = content, 0 = background).

        4-channel images are masked on their alpha channel. Single-channel
        images (2-D or ``(H, W, 1)``) are masked on their own values. Any other
        image is converted with ``cv2.COLOR_BGR2GRAY`` first. For 8-bit input
        this equals ``cv2.threshold(src, content_threshold, 255, THRESH_BINARY)``.
        """
        if image.ndim == 3 and image.shape[2] == 4:
            source = image[:, :, 3]
        elif image.ndim == 2:
            source = image
        elif image.ndim == 3 and image.shape[2] == 1:
            source = image[:, :, 0]
        else:
            source = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        return np.where(source > self.content_threshold, 255, 0).astype(np.uint8)

    @staticmethod
    def _true_runs(flags: np.ndarray) -> List[Tuple[int, int]]:
        """Return half-open ``(start, end)`` index ranges of consecutive True values."""
        runs = []
        start = None
        for i, flag in enumerate(flags):
            if flag and start is None:
                start = i
            elif not flag and start is not None:
                runs.append((start, i))
                start = None
        if start is not None:
            # Run extends to the end of the array.
            runs.append((start, len(flags)))
        return runs

    @staticmethod
    def _pad_box(bbox: Tuple, padding: int, shape: Tuple) -> Tuple[int, int, int, int]:
        """Grow ``bbox`` by ``padding`` on every side, clamped to an image of ``shape``."""
        minr, minc, maxr, maxc = bbox
        return (max(0, minr - padding), max(0, minc - padding),
                min(shape[0], maxr + padding), min(shape[1], maxc + padding))

    # --------------------------------------------------------------- detection

    def detect_frames_auto(self, image: np.ndarray, padding: int = 2) -> Tuple[List[np.ndarray], List[Tuple]]:
        """
        Automatically detect frames in a sprite sheet

        Connected components of the content mask (8-connectivity) whose bounding
        box lies within ``[min_frame_size, max_frame_size]`` on both sides become
        frames, ordered by their left edge. Falls back to grid-based detection
        when no component qualifies or when more than ``_MAX_AUTO_REGIONS`` do.

        Args:
            image: Input sprite sheet image (BGR, BGRA or single-channel)
            padding: Padding to add around each frame
        Returns:
            Tuple of (list of frame images, list of bounding boxes)
        """
        binary = self._binary_mask(image)
        labels = measure.label(binary, connectivity=2)

        lo, hi = self.min_frame_size, self.max_frame_size
        valid_regions = []
        for region in measure.regionprops(labels):
            minr, minc, maxr, maxc = region.bbox
            if lo <= maxc - minc <= hi and lo <= maxr - minr <= hi:
                valid_regions.append(region)

        # Nothing usable, or too many small regions: use the grid approach.
        if not valid_regions or len(valid_regions) > self._MAX_AUTO_REGIONS:
            return self._detect_grid_based(image, padding)

        # Left-to-right order (stable sort keeps label order for equal left edges).
        valid_regions.sort(key=lambda r: r.bbox[1])

        frames = []
        frame_boxes = []
        for region in valid_regions:
            minr, minc, maxr, maxc = self._pad_box(region.bbox, padding, image.shape)
            frames.append(image[minr:maxr, minc:maxc])
            frame_boxes.append((minr, minc, maxr, maxc))
        return frames, frame_boxes

    def _detect_grid_based(self, image: np.ndarray, padding: int = 2) -> Tuple[List[np.ndarray], List[Tuple]]:
        """
        Detect frames using grid-based approach with improved filtering

        Rows and columns whose content-mask projection is below 10% of the
        projection maximum are treated as gaps. Every (content row span,
        content column span) cell that is large enough and contains mask pixels
        becomes a frame. Falls back to equal division if no cell qualifies.

        NOTE(review): cells shorter than max(20, height // 3) are rejected, so a
        sheet with three or more rows of frames will usually lose every cell and
        fall through to equal division. Confirm this heuristic is intended.

        Args:
            image: Input sprite sheet image
            padding: Padding to add around each frame
        Returns:
            Tuple of (list of frame images, list of bounding boxes)
        """
        binary = self._binary_mask(image)

        # Morphological close then open to fill pin-holes and drop specks.
        kernel = np.ones((3, 3), np.uint8)
        binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)
        binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel)

        h_proj = binary.sum(axis=1)
        v_proj = binary.sum(axis=0)
        # initial=0 keeps max() defined for zero-size images.
        row_spans = self._true_runs(h_proj >= h_proj.max(initial=0) * 0.1)
        col_spans = self._true_runs(v_proj >= v_proj.max(initial=0) * 0.1)

        # Filter out very small cells (likely effects/particles).
        min_content_width = max(15, image.shape[1] // 50)   # At least 15px or 2% of image width
        min_content_height = max(20, image.shape[0] // 3)   # At least 20px or 33% of image height

        frames = []
        frame_boxes = []
        for minr, maxr in row_spans:
            if maxr - minr < min_content_height:
                continue
            for minc, maxc in col_spans:
                if maxc - minc < min_content_width:
                    continue
                # Judge content on the mask, not raw pixels: transparent pixels
                # may still carry non-zero color values.
                if not binary[minr:maxr, minc:maxc].any():
                    continue
                box = self._pad_box((minr, minc, maxr, maxc), padding, image.shape)
                frames.append(image[box[0]:box[2], box[1]:box[3]])
                frame_boxes.append(box)

        # If still no frames, use equal division
        if not frames:
            return self._detect_equal_division(image, padding)
        return frames, frame_boxes

    def _detect_equal_division(self, image: np.ndarray, padding: int = 2,
                               num_frames: int = 8) -> Tuple[List[np.ndarray], List[Tuple]]:
        """
        Detect frames by equal division

        Assumes a single horizontal strip. Each frame spans the full height and
        ``width // num_frames`` columns. The last frame absorbs any remainder.

        Args:
            image: Input sprite sheet image
            padding: Padding to add around each frame
            num_frames: Number of frames to divide into (must be >= 1)
        Returns:
            Tuple of (list of frame images, list of bounding boxes)
        Raises:
            ValueError: if ``num_frames`` is less than 1
        """
        if num_frames < 1:
            raise ValueError(f"num_frames must be >= 1, got {num_frames}")

        img_height, img_width = image.shape[:2]
        frame_width = img_width // num_frames

        frames = []
        frame_boxes = []
        for i in range(num_frames):
            minc = i * frame_width
            maxc = img_width if i == num_frames - 1 else minc + frame_width
            box = self._pad_box((0, minc, img_height, maxc), padding, image.shape)
            frames.append(image[box[0]:box[2], box[1]:box[3]])
            frame_boxes.append(box)
        return frames, frame_boxes

    def detect_frames_manual(self, image: np.ndarray, num_frames: int,
                             padding: int = 2) -> Tuple[List[np.ndarray], List[Tuple]]:
        """
        Manually specify number of frames (equal horizontal division)

        Args:
            image: Input sprite sheet image
            num_frames: Number of frames (must be >= 1)
            padding: Padding to add around each frame
        Returns:
            Tuple of (list of frame images, list of bounding boxes)
        """
        return self._detect_equal_division(image, padding, num_frames)

    def refine_frame_boundaries(self, image: np.ndarray, frame: np.ndarray,
                                bbox: Tuple) -> Tuple[np.ndarray, Tuple]:
        """
        Refine frame boundaries to remove excess transparent space

        ``frame`` is expected to be ``image[minr:maxr, minc:maxc]`` for ``bbox``.
        The tight box is mapped back into ``image`` coordinates.

        Args:
            image: Original image
            frame: Extracted frame
            bbox: Bounding box (minr, minc, maxr, maxc)
        Returns:
            Refined frame and bounding box. The inputs are returned unchanged
            when the frame has no content pixels.
        """
        minr, minc = bbox[0], bbox[1]
        mask = self._binary_mask(frame) > 0
        row_indices = np.flatnonzero(mask.any(axis=1))
        col_indices = np.flatnonzero(mask.any(axis=0))
        if row_indices.size == 0 or col_indices.size == 0:
            return frame, bbox

        new_minr = int(minr + row_indices[0])
        new_maxr = int(minr + row_indices[-1] + 1)
        new_minc = int(minc + col_indices[0])
        new_maxc = int(minc + col_indices[-1] + 1)
        refined_frame = image[new_minr:new_maxr, new_minc:new_maxc]
        return refined_frame, (new_minr, new_minc, new_maxr, new_maxc)

    def detect_frame_size(self, image: np.ndarray) -> Tuple[int, int]:
        """
        Detect the size of individual frames (horizontal strip assumed)

        Columns whose content projection is below 10% of its maximum are gaps.
        The frame width is the mean spacing between gap starts, or the position
        of the single separating gap. A gap beginning at column 0 is a left
        margin, not a separator, and is ignored.

        Args:
            image: Input sprite sheet image
        Returns:
            Tuple of (frame_width, frame_height). frame_height is always the
            full image height.
        """
        binary = self._binary_mask(image)
        v_proj = binary.sum(axis=0)
        gaps = v_proj < v_proj.max(initial=0) * 0.1

        gap_starts = [start for start, _ in self._true_runs(gaps) if start > 0]
        if len(gap_starts) > 1:
            # Average distance between gaps
            frame_width = int(np.mean(np.diff(gap_starts)))
        elif gap_starts:
            frame_width = gap_starts[0]
        else:
            # No separating gaps found, assume single frame
            frame_width = image.shape[1]
        return frame_width, image.shape[0]