Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import matplotlib.pyplot as plt | |
def image_grid(
    images,
    rows=None,
    cols=None,
    fill: bool = True,
    show_axes: bool = False,
    rgb: bool = True,
):
    """
    A util function for plotting a grid of images.

    Args:
        images: (N, H, W, 4) array of RGBA images
        rows: number of rows in the grid
        cols: number of columns in the grid
        fill: boolean indicating if the space between images should be filled
        show_axes: boolean indicating if the axes of the plots should be visible
        rgb: boolean, If True, only RGB channels are plotted.
            If False, only the alpha channel is plotted.

    Returns:
        None

    Raises:
        ValueError: if exactly one of ``rows`` / ``cols`` is given.
    """
    if (rows is None) != (cols is None):
        raise ValueError("Specify either both rows and cols or neither.")

    if rows is None:
        # Default layout: a single column with one row per image.
        rows = len(images)
        cols = 1

    gridspec_kw = {"wspace": 0.0, "hspace": 0.0} if fill else {}
    # squeeze=False makes plt.subplots always return a 2D ndarray of Axes.
    # Without it a 1x1 grid returns a bare Axes, and .ravel() below would
    # raise AttributeError.
    fig, axarr = plt.subplots(
        rows, cols, gridspec_kw=gridspec_kw, figsize=(15, 9), squeeze=False
    )
    bleed = 0
    fig.subplots_adjust(left=bleed, bottom=bleed, right=(1 - bleed), top=(1 - bleed))

    # zip stops at the shorter sequence: images beyond rows * cols are not
    # drawn, and grid cells beyond len(images) receive no image.
    for ax, im in zip(axarr.ravel(), images):
        if rgb:
            # only render RGB channels
            ax.imshow(im[..., :3])
        else:
            # only render Alpha channel
            ax.imshow(im[..., 3])
        if not show_axes:
            ax.set_axis_off()