| import io |
| import IPython.display |
| import PIL.Image |
| import os |
| from pprint import pformat |
| import numpy as np |
|
|
def imgrid(imarray, cols=4, pad=1, padval=255, row_major=True):
    """Lays out a [N, H, W, C] image array as a single image grid.

    Args:
      imarray: Array-like of shape [N, H, W, C].
      cols: Number of grid columns (must be >= 1). The grid has
        ceil(N / cols) rows; cells left over after placing all N images are
        filled with ``padval``.
      pad: Non-negative number of ``padval`` pixels inserted between
        neighbouring images. No border is added around the outside.
      padval: Constant used for the separators and for empty cells.
      row_major: If True (default), images fill the grid left-to-right, then
        top-to-bottom. If False, they fill top-to-bottom, then left-to-right.

    Returns:
      For N >= 1, an array of shape
      [rows * (H + pad) - pad, cols * (W + pad) - pad, C].

    Raises:
      ValueError: If ``pad`` is negative or ``cols`` is less than 1.
    """
    imarray = np.asarray(imarray)
    pad = int(pad)
    if pad < 0:
        raise ValueError('pad must be non-negative')
    cols = int(cols)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if cols < 1:
        raise ValueError('cols must be at least 1')
    N, H, W, C = imarray.shape
    rows = N // cols + int(N % cols != 0)  # ceil(N / cols)
    batch_pad = rows * cols - N
    # Append blank images to complete the grid, and pad each image on its
    # bottom/right edge so adjacent images end up `pad` pixels apart.
    pad_arg = [[0, batch_pad], [0, pad], [0, pad], [0, 0]]
    imarray = np.pad(imarray, pad_arg, 'constant', constant_values=padval)
    H += pad
    W += pad
    if row_major:
        # Image i lands in cell (row=i // cols, col=i % cols).
        cells = imarray.reshape(rows, cols, H, W, C).transpose(0, 2, 1, 3, 4)
    else:
        # Image i lands in cell (row=i % rows, col=i // rows).
        cells = imarray.reshape(cols, rows, H, W, C).transpose(1, 2, 0, 3, 4)
    grid = cells.reshape(rows * H, cols * W, C)
    if pad:
        # Drop the trailing separator after the last row and last column.
        grid = grid[:-pad, :-pad]
    return grid
|
|
def interleave(*args):
    """Merges equally-shaped arrays by alternating their entries along axis 0.

    For inputs ``a, b`` of shape [N, ...] the result has shape
    [N * len(args), ...] and is ordered ``a[0], b[0], a[1], b[1], ...``.

    Raises:
      ValueError: If no arrays are given, if their shapes differ, or if they
        have no axes (0-dimensional).
    """
    if len(args) == 0:
        raise ValueError('At least one argument is required.')
    ref_shape = args[0].shape
    for arr in args:
        if arr.shape != ref_shape:
            raise ValueError('All inputs must have the same shape.')
    if len(ref_shape) == 0:
        raise ValueError('Inputs must have at least one axis.')
    # Stacking on axis 1 places the k-th entries of every input side by side,
    # so collapsing the first two axes yields the interleaved order.
    stacked = np.stack(args, axis=1)
    return stacked.reshape((-1,) + tuple(ref_shape[1:]))
|
|
def imshow(a, format='png', jpeg_fallback=True):
    """Displays an image array inline via IPython in the given format.

    Args:
      a: Array-like image data that ``PIL.Image.fromarray`` accepts once cast
        to uint8. The cast uses ``astype(np.uint8)`` with no clipping, so
        convert [-1, 1] floats beforehand (e.g. with ``image_to_uint8``).
      format: PIL format name used to encode the image (default 'png').
      jpeg_fallback: If True and displaying the encoded image raises IOError,
        print a warning and retry once with format 'jpeg'.

    Returns:
      The value returned by ``IPython.display.display``.

    Raises:
      IOError: If display fails and no jpeg fallback is attempted.
    """
    a = np.asarray(a).astype(np.uint8)
    data = io.BytesIO()
    PIL.Image.fromarray(a).save(data, format)
    im_data = data.getvalue()
    try:
        disp = IPython.display.display(IPython.display.Image(im_data))
    except IOError:
        if jpeg_fallback and format != 'jpeg':
            # In Python 3, print() returns None, so .format() must be applied
            # to the message string, not to print's result.
            print('Warning: image was too large to display in format "{}"; '
                  'trying jpeg instead.'.format(format))
            return imshow(a, format='jpeg')
        raise
    return disp
|
|
def image_to_uint8(x):
    """Quantizes values in [-1, 1] to uint8 pixel intensities in [0, 255].

    The input is shifted by 1 and scaled by 128. As a result, each of the 256
    output levels covers an equal-width slice of [-1, 1]. Exactly 1.0 maps to
    256 and is clipped to 255. Out-of-range inputs are clipped, and the final
    cast truncates the fractional part.
    """
    scaled = (np.asarray(x) + 1.) * 128.
    return np.clip(scaled, 0, 255).astype(np.uint8)
|
|