|
|
| import tensorflow as tf |
| import numpy as np |
| from einops import rearrange |
| from decord import VideoReader |
|
|
| num_frames = 16 |
| input_size = 224 |
| patch_size = (16, 16) |
| IMAGENET_MEAN = np.array([0.45, 0.45, 0.45]) |
| IMAGENET_STD = np.array([0.225, 0.225, 0.225]) |
|
|
def format_frames(frame, output_size):
    """Resize a frame (or batch of frames) and standardize its pixels.

    The input is cast to uint8, resized to ``output_size`` (resize yields
    floats), scaled into [0, 1], then normalized with the module-level
    IMAGENET_MEAN / IMAGENET_STD constants.
    """
    as_uint8 = tf.image.convert_image_dtype(frame, tf.uint8)
    resized = tf.image.resize(as_uint8, size=output_size)
    scaled = resized / 255.
    return (scaled - IMAGENET_MEAN) / IMAGENET_STD
|
|
def read_video(file_path):
    """Open ``file_path`` with decord and return the VideoReader handle."""
    return VideoReader(file_path)
|
|
def frame_sampling(container, num_frames):
    """Sample ``num_frames`` frames, one uniformly from each equal segment.

    The clip is split into ``num_frames`` segments of ``interval`` frames
    each; one random frame is drawn per segment, so indices never exceed
    ``len(container) - 1``.

    Args:
        container: decord VideoReader (supports ``len`` and ``get_batch``).
        num_frames: number of frames to sample.

    Returns:
        Preprocessed frames from ``format_frames`` (resized + normalized).

    Raises:
        ValueError: if the video has fewer frames than ``num_frames``
            (the original code crashed inside ``np.random.randint(0)``).
    """
    interval = len(container) // num_frames
    if interval == 0:
        raise ValueError(
            "Video has only {} frames; need at least {}.".format(
                len(container), num_frames
            )
        )
    bids = np.arange(num_frames) * interval
    # One random offset within each segment; offsets are < interval.
    offset = np.random.randint(interval, size=bids.shape)
    frame_index = bids + offset
    # asnumpy() already yields an (n, h, w, c) ndarray; the original
    # np.stack call was a redundant copy and has been dropped.
    frames = container.get_batch(frame_index).asnumpy()
    frames = format_frames(frames, [input_size] * 2)
    return frames
|
|
def denormalize(image):
    """Invert the mean/std normalization and return a uint8 image array.

    Accepts either a numpy array or an object with ``.numpy()`` (e.g. a
    TF tensor); values are rescaled to [0, 255] and clipped.
    """
    arr = image if isinstance(image, np.ndarray) else image.numpy()
    restored = arr * IMAGENET_STD + IMAGENET_MEAN
    return (restored * 255).clip(0, 255).astype('uint8')
|
|
def reconstrunction(input_frame, bool_mask, pretrained_pred):
    """Rebuild the full video by splicing model predictions into masked patches.

    Args:
        input_frame: (b, t, h, w, c) video tensor; ``.numpy()`` is called on it.
        bool_mask: boolean mask over patch tokens selecting the masked slots.
        pretrained_pred: model outputs for the masked patch tokens
            (normalized-pixel space).

    Returns:
        Tuple ``(rec_img[0], mask[0])`` — the reconstructed first video and a
        visibility mask that is 1 on visible patches, 0 on masked ones.
    """
    # Break the video into tubelets: p0 frames x p1*p2 pixels per token.
    # FIX: the original passed p2=patch_size[0] here, inconsistent with the
    # inverse rearranges below that use patch_size[1]. Identical behavior
    # for the current square 16x16 patches, but correct for any shape.
    img_squeeze = rearrange(
        input_frame.numpy(),
        'b (t p0) (h p1) (w p2) c -> b (t h w) (p0 p1 p2) c',
        p0=2, p1=patch_size[0], p2=patch_size[1]
    )

    # Per-token statistics, computed once and reused for de-normalization
    # (the original recomputed identical mean/variance a second time).
    img_mean = np.mean(img_squeeze, axis=-2, keepdims=True)
    img_variance = np.var(img_squeeze, axis=-2, ddof=1, keepdims=True)
    # NOTE(review): epsilon placement differs between the two directions
    # (std + 1e-6 here vs sqrt(var + 1e-6) below) — kept as in the original
    # to preserve numeric behavior.
    img_norm = (img_squeeze - img_mean) / (np.sqrt(img_variance) + 1e-6)
    img_patch = rearrange(img_norm, 'b n p c -> b n (p c)')
    # Splice the model's predictions into the masked token slots.
    img_patch[bool_mask] = pretrained_pred

    # Visibility mask: 1 for visible patches, 0 for masked ones, expanded
    # back to full video resolution (h=w=14 assumes 224/16 patches per side).
    mask = np.ones_like(img_patch)
    mask[bool_mask] = 0
    mask = rearrange(mask, 'b n (p c) -> b n p c', c=3)
    mask = rearrange(
        mask,
        'b (t h w) (p0 p1 p2) c -> b (t p0) (h p1) (w p2) c',
        p0=2, p1=patch_size[0], p2=patch_size[1], h=14, w=14
    )

    # Undo the per-token normalization and restore the video layout.
    rec_img = rearrange(img_patch, 'b n (p c) -> b n p c', c=3)
    img_std = np.sqrt(img_variance + 1e-6)
    rec_img = rec_img * img_std + img_mean
    rec_img = rearrange(
        rec_img,
        'b (t h w) (p0 p1 p2) c -> b (t p0) (h p1) (w p2) c',
        p0=2, p1=patch_size[0], p2=patch_size[1], h=14, w=14
    )

    return (
        rec_img[0],
        mask[0]
    )
|
|
|
|
class TubeMaskingGenerator:
    """Tube masking for video MAE pretraining.

    A single random 2-D spatial patch mask is drawn per call and replicated
    across every temporal slot ("tube" masking), so the same spatial
    locations are masked in all frames.

    Args:
        input_size: (frames, height, width) in patch units.
        mask_ratio: fraction of patches masked in each frame.
    """

    def __init__(self, input_size, mask_ratio):
        self.frames, self.height, self.width = input_size
        self.num_patches_per_frame = self.height * self.width
        self.total_patches = self.frames * self.num_patches_per_frame
        # int() truncates, so the realized ratio can be slightly below
        # mask_ratio when it does not divide the patch count evenly.
        self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame)
        self.total_masks = self.frames * self.num_masks_per_frame

    def __repr__(self):
        # FIX: corrected typo in the original message ("Maks" -> "Masks").
        repr_str = "Masks: total patches {}, mask patches {}".format(
            self.total_patches, self.total_masks
        )
        return repr_str

    def __call__(self):
        """Return a flat 0/1 float mask of length total_patches (1 = masked)."""
        mask_per_frame = np.hstack([
            np.zeros(self.num_patches_per_frame - self.num_masks_per_frame),
            np.ones(self.num_masks_per_frame),
        ])
        np.random.shuffle(mask_per_frame)
        # Replicate the per-frame mask over all temporal slots.
        mask = np.tile(mask_per_frame, (self.frames, 1)).flatten()
        return mask