# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
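
"""Utilities for loading images and preprocessing them into batched tensors for model input."""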

from concurrent.futures import ThreadPoolExecutor

import numpy as np
import torch
from PIL import Image
from torchvision import transforms as TF
from tqdm.auto import tqdm

def load_and_preprocess_images_square(image_path_list, target_size=1024):
    """
    Load and preprocess images by center padding to square and resizing to target size.
    Also returns the position information of original pixels after transformation.

    Args:
        image_path_list (list): List of paths to image files
        target_size (int, optional): Target size for both width and height. Defaults to 1024.

    Returns:
        tuple: (
            torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, target_size, target_size),
            torch.Tensor: Tensor of shape (N, 6) containing [x1, y1, x2, y2, width, height] for each image
        )

    Raises:
        ValueError: If the input list is empty
    """
    # Check for empty list
    if len(image_path_list) == 0:
        raise ValueError("At least 1 image is required")
    images = []
    original_coords = []  # [x1, y1, x2, y2, width, height] in target space, per image
    to_tensor = TF.ToTensor()
    for image_path in image_path_list:
        # Open image
        img = Image.open(image_path)

        # If there's an alpha channel, blend onto white background
        if img.mode == "RGBA":
            background = Image.new("RGBA", img.size, (255, 255, 255, 255))
            img = Image.alpha_composite(background, img)

        # Convert to RGB
        img = img.convert("RGB")

        # Get original dimensions
        width, height = img.size

        # Make the image square by padding the shorter dimension
        max_dim = max(width, height)

        # Calculate padding
        left = (max_dim - width) // 2
        top = (max_dim - height) // 2

        # Calculate scale factor for resizing
        scale = target_size / max_dim

        # Calculate final coordinates of original image in target space
        x1 = left * scale
        y1 = top * scale
        x2 = (left + width) * scale
        y2 = (top + height) * scale
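        # Worked example: a 640x480 input gives max_dim=640, left=0, top=80;
        # with target_size=1024, scale=1.6, so the original pixels land at
        # (x1, y1, x2, y2) = (0, 128, 1024, 896) in the padded-and-resized output.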
        # Store original image coordinates and scale
        original_coords.append(np.array([x1, y1, x2, y2, width, height]))

        # Create a new black square image and paste original
        square_img = Image.new("RGB", (max_dim, max_dim), (0, 0, 0))
        square_img.paste(img, (left, top))

        # Resize to target size
        square_img = square_img.resize((target_size, target_size), Image.Resampling.BICUBIC)

        # Convert to tensor
        img_tensor = to_tensor(square_img)
        images.append(img_tensor)

    # Stack all images
    images = torch.stack(images)
    original_coords = torch.from_numpy(np.array(original_coords)).float()

    # Defensive shape check for a single image; torch.stack already returns a
    # 4D tensor, so this branch is normally a no-op
    if len(image_path_list) == 1:
        if images.dim() == 3:
            images = images.unsqueeze(0)
            original_coords = original_coords.unsqueeze(0)

    return images, original_coords
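
# Minimal usage sketch for the function above, assuming two hypothetical image
# files on disk (see the runnable __main__ demo at the bottom of this file):
#
#   images, coords = load_and_preprocess_images_square(["a.png", "b.png"])
#   # images: (2, 3, 1024, 1024) float tensor with values in [0, 1]
#   # coords: (2, 6), rows of [x1, y1, x2, y2, orig_width, orig_height]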

def load_and_preprocess_images(image_path_list, fx=None, fy=None, cx=None, cy=None, mode="crop", image_size=512, patch_size=16):
    """
    A quick-start function to load and preprocess images for model input.
    This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes.

    Args:
        image_path_list (list): List of paths to image files
        fx, fy, cx, cy (sequence, optional): Per-image pinhole intrinsics, normalized by the
            original image width (fx, cx) and height (fy, cy). When fx is given, all four are
            expected; they are updated in place to pixel units of the preprocessed images and
            returned alongside the image tensor.
        mode (str, optional): Preprocessing mode, either "crop" or "pad".
            - "crop" (default): Sets width to image_size and center-crops the height if needed.
            - "pad": Preserves all pixels by making the largest dimension image_size
              and padding the smaller dimension to reach a square shape.
        image_size (int, optional): Target size in pixels. Defaults to 512.
        patch_size (int, optional): Multiple that output dimensions must be divisible by. Defaults to 16.

    Returns:
        torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W).
        When intrinsics are provided, returns (images, fx, fy, cx, cy) instead.

    Raises:
        ValueError: If the input list is empty or if mode is invalid

    Notes:
        - Images with different dimensions will be padded with white (value=1.0)
        - A warning is printed when images have different shapes
        - When mode="crop": The function ensures width=image_size while maintaining aspect
          ratio, and the height is center-cropped if larger than image_size
        - When mode="pad": The function ensures the largest dimension is image_size while
          maintaining aspect ratio, and the smaller dimension is padded to a square shape
          (image_size x image_size)
        - Dimensions are adjusted to be divisible by patch_size for compatibility with
          model requirements
    """
    # Check for empty list
    if len(image_path_list) == 0:
        raise ValueError("At least 1 image is required")

    # Validate mode
    if mode not in ["crop", "pad"]:
        raise ValueError("Mode must be either 'crop' or 'pad'")

    target_size = image_size
    to_tensor = TF.ToTensor()

    def _load_one(idx_path):
        i, image_path = idx_path
        img = Image.open(image_path)

        # Blend any alpha channel onto a white background, then force RGB
        if img.mode == "RGBA":
            background = Image.new("RGBA", img.size, (255, 255, 255, 255))
            img = Image.alpha_composite(background, img)
        img = img.convert("RGB")

        width, height = img.size
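        # fx/fy/cx/cy are expected in normalized units (divided by the original
        # width/height); convert them to pixels here before rescaling below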
        fx_val = fy_val = cx_val = cy_val = None
        if fx is not None:
            fx_val = fx[i] * width
            fy_val = fy[i] * height
            cx_val = cx[i] * width
            cy_val = cy[i] * height

        # Resized dimensions are rounded to multiples of patch_size so the
        # model's patch embedding divides the image evenly
        if mode == "pad":
            if width >= height:
                new_width = target_size
                new_height = round(height * (new_width / width) / patch_size) * patch_size
            else:
                new_height = target_size
                new_width = round(width * (new_height / height) / patch_size) * patch_size
        else:  # crop
            new_width = target_size
            new_height = round(height * (new_width / width) / patch_size) * patch_size

        img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
        img = to_tensor(img)

        # Center-crop the height if it still exceeds the target
        if mode == "crop" and new_height > target_size:
            start_y = (new_height - target_size) // 2
            img = img[:, start_y : start_y + target_size, :]

        if fx is not None:
            # Rescale focal lengths to the resized image and approximate the
            # principal point by the current image center
            fx_val = fx_val * new_width / width
            fy_val = fy_val * new_height / height
            cx_val = img.shape[2] / 2
            cy_val = img.shape[1] / 2
        if mode == "pad":
            h_padding = target_size - img.shape[1]
            w_padding = target_size - img.shape[2]
            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left
                img = torch.nn.functional.pad(
                    img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
                )

        return i, img, (fx_val, fy_val, cx_val, cy_val)
    # Parallel load with progress bar
    num_workers = min(16, len(image_path_list))
    results = [None] * len(image_path_list)
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        futures = pool.map(_load_one, enumerate(image_path_list))
        for i, img, calib in tqdm(futures, total=len(image_path_list), desc="Loading images"):
            results[i] = img
            if fx is not None:
                fx[i], fy[i], cx[i], cy[i] = calib
    images = results
    shapes = set((img.shape[1], img.shape[2]) for img in images)

    # Check if we have different shapes
    # In theory our model can also work well with different shapes
    if len(shapes) > 1:
        print(f"Warning: Found images with different shapes: {shapes}")

        # Find maximum dimensions
        max_height = max(shape[0] for shape in shapes)
        max_width = max(shape[1] for shape in shapes)

        # Pad images if necessary
        padded_images = []
        for img in images:
            h_padding = max_height - img.shape[1]
            w_padding = max_width - img.shape[2]
            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left
                img = torch.nn.functional.pad(
                    img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
                )
            padded_images.append(img)
        images = padded_images
    images = torch.stack(images)  # stack into a single (N, 3, H, W) batch
    # Ensure correct shape when single image
    if len(image_path_list) == 1:
        # Verify shape is (1, C, H, W)
        if images.dim() == 3:
            images = images.unsqueeze(0)

    if fx is not None:
        return images, fx, fy, cx, cy
    return images
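
# A minimal, runnable usage sketch of both entry points. The image paths and
# intrinsic values below are hypothetical placeholders chosen only to
# illustrate the calling conventions; substitute your own files.
if __name__ == "__main__":
    paths = ["examples/frame_000.png", "examples/frame_001.png"]  # hypothetical

    # Square-pad mode: keeps every pixel and reports where the original image
    # landed inside the target_size x target_size canvas
    images, coords = load_and_preprocess_images_square(paths, target_size=1024)
    print(images.shape)  # torch.Size([2, 3, 1024, 1024])
    print(coords)  # per image: [x1, y1, x2, y2, orig_width, orig_height]

    # Crop mode with made-up normalized intrinsics; the lists are updated in
    # place to pixel units of the preprocessed images and also returned
    fx, fy = [0.9, 0.9], [1.2, 1.2]
    cx, cy = [0.5, 0.5], [0.5, 0.5]
    images, fx, fy, cx, cy = load_and_preprocess_images(
        paths, fx=fx, fy=fy, cx=cx, cy=cy, mode="crop", image_size=512, patch_size=16
    )
    print(images.shape)  # (2, 3, H, 512) with H <= 512 and divisible by patch_size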