Diffusers
Safetensors
File size: 6,220 Bytes
40a3ea8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
from PIL import Image
import matplotlib
import numpy as np
import cv2
from PIL import Image
import torch
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize


def resize_image(image, target_size):
    """
    Resize output image to target size
    Args:
        image: Image in PIL.Image, numpy.array or torch.tensor format
        target_size: tuple, target size (H, W)
    Returns:
        Resized image in original format
    """
    if isinstance(image, list):
        return [resize_image(img, target_size) for img in image]
    
    if isinstance(image, Image.Image):
        return image.resize(target_size[::-1], Image.BILINEAR)
    elif isinstance(image, np.ndarray):
        # Handle numpy array with shape (1, H, W, 3)
        if image.ndim == 4:
            resized = np.stack([cv2.resize(img, target_size[::-1]) for img in image])
            return resized
        else:
            return cv2.resize(image, target_size[::-1])
    elif isinstance(image, torch.Tensor):
        # Handle tensor with shape (1, 3, H, W)
        if image.dim() == 4:
            return torch.nn.functional.interpolate(
                image,
                size=target_size,
                mode='bilinear',
                align_corners=False
            )
        else:
            return torch.nn.functional.interpolate(
                image.unsqueeze(0),
                size=target_size, 
                mode='bilinear',
                align_corners=False
            ).squeeze(0)
    else:
        raise ValueError(f"Unsupported image format: {type(image)}")

def resize_image_first(image_tensor, process_res=None):
    if process_res:
        max_edge = max(image_tensor.shape[2], image_tensor.shape[3])
        if max_edge > process_res:
            scale = process_res / max_edge
            new_height = int(image_tensor.shape[2] * scale)
            new_width = int(image_tensor.shape[3] * scale)
            image_tensor = resize_image(image_tensor, (new_height, new_width))
    
    image_tensor = resize_to_multiple_of_16(image_tensor)
    
    return image_tensor

def resize_to_multiple_of_16(image_tensor):
    """
    Resize image tensor to make shorter side closest multiple of 16 while maintaining aspect ratio
    Args:
        image_tensor: Input tensor of shape (B, C, H, W)
    Returns:
        Resized tensor where shorter side is multiple of 16
    """
    # Calculate scale ratio based on shorter side to make it closest multiple of 16
    h, w = image_tensor.shape[2], image_tensor.shape[3]
    min_side = min(h, w)
    scale = (min_side // 16) * 16 / min_side
    
    # Calculate new height and width
    new_h = int(h * scale)
    new_w = int(w * scale)
    
    # Ensure both height and width are multiples of 16
    new_h = (new_h // 16) * 16  
    new_w = (new_w // 16) * 16

    # Resize image while maintaining aspect ratio
    resized_tensor = torch.nn.functional.interpolate(
        image_tensor,
        size=(new_h, new_w),
        mode='bilinear',
        align_corners=False
    )
    return resized_tensor

def colorize_depth_map(depth, mask=None, reverse_color=False):
    cm = matplotlib.colormaps["Spectral"]
    # normalize
    depth = ((depth - depth.min()) / (depth.max() - depth.min()))
    # colorize
    if reverse_color:
        img_colored_np = cm(1 - depth, bytes=False)[:, :, 0:3]  # Invert the depth values before applying colormap
    else:
        img_colored_np = cm(depth, bytes=False)[:, :, 0:3] # (h,w,3)

    depth_colored = (img_colored_np * 255).astype(np.uint8) 
    if mask is not None:
        masked_image = np.zeros_like(depth_colored)
        masked_image[mask.numpy()] = depth_colored[mask.numpy()]
        depth_colored_img = Image.fromarray(masked_image)
    else:
        depth_colored_img = Image.fromarray(depth_colored)
    return depth_colored_img


def concatenate_images(*image_lists):
    # Ensure at least one image list is provided
    if not image_lists or not image_lists[0]:
        raise ValueError("At least one non-empty image list must be provided")
    
    # Determine the maximum width of any single row and the total height
    max_width = 0
    total_height = 0
    row_widths = []
    row_heights = []

    # Compute dimensions for each row
    for image_list in image_lists:
        if image_list:  # Ensure the list is not empty
            width = sum(img.width for img in image_list)
            height = image_list[0].height  # Assuming all images in the list have the same height
            max_width = max(max_width, width)
            total_height += height
            row_widths.append(width)
            row_heights.append(height)
    
    # Create a new image to concatenate everything into
    new_image = Image.new('RGB', (max_width, total_height))
    
    # Concatenate each row of images
    y_offset = 0
    for i, image_list in enumerate(image_lists):
        x_offset = 0
        for img in image_list:
            new_image.paste(img, (x_offset, y_offset))
            x_offset += img.width
        y_offset += row_heights[i]  # Move the offset down to the next row
    
    return new_image


def resize_max_res(
    img: torch.Tensor,
    max_edge_resolution: int,
    resample_method: InterpolationMode = InterpolationMode.BILINEAR,
) -> torch.Tensor:
    """
    Resize image to limit maximum edge length while keeping aspect ratio.

    Args:
        img (`torch.Tensor`):
            Image tensor to be resized. Expected shape: [B, C, H, W]
        max_edge_resolution (`int`):
            Maximum edge length (pixel).
        resample_method (`PIL.Image.Resampling`):
            Resampling method used to resize images.

    Returns:
        `torch.Tensor`: Resized image.
    """
    assert 4 == img.dim(), f"Invalid input shape {img.shape}"

    original_height, original_width = img.shape[-2:]
    downscale_factor = min(
        max_edge_resolution / original_width, max_edge_resolution / original_height
    )

    new_width = int(original_width * downscale_factor)
    new_height = int(original_height * downscale_factor)

    resized_img = resize(img, (new_height, new_width), resample_method, antialias=True)
    return resized_img