| import os |
| from pathlib import Path |
| import numpy as np |
| import tempfile |
| import tensorflow as tf |
| import mediapy |
| from PIL import Image |
| import cog |
|
|
| from eval import interpolator, util |
|
|
| _UINT8_MAX_F = float(np.iinfo(np.uint8).max) |
|
|
|
|
class Predictor(cog.Predictor):
    """Cog predictor that interpolates between two input frames.

    With ``times_to_interpolate == 1`` a single in-between image is written;
    otherwise the recursively interpolated frames are written as a video.
    """

    # Lower-case file extensions accepted for the two input frames.
    _INPUT_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def setup(self):
        """Load the interpolation model once, before any prediction runs."""
        print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
        # NOTE(review): the meaning of the second positional argument (None)
        # is defined by interpolator.Interpolator -- confirm against its
        # signature before changing it.
        self.interpolator = interpolator.Interpolator(
            "pretrained_models/film_net/Style/saved_model", None)

        # One-element batch holding 0.5; passed to the interpolator as the
        # time argument when producing the single mid frame.
        self.batch_dt = np.full(shape=(1,), fill_value=0.5, dtype=np.float32)

    @cog.input(
        "frame1",
        type=Path,
        help="The first input frame",
    )
    @cog.input(
        "frame2",
        type=Path,
        help="The second input frame",
    )
    @cog.input(
        "times_to_interpolate",
        type=int,
        default=1,
        min=1,
        max=8,
        help="Controls the number of times the frame interpolator is invoked If set to 1, the output will be the "
             "sub-frame at t=0.5; when set to > 1, the output will be the interpolation video with "
             "(2^times_to_interpolate + 1) frames, fps of 30.",
    )
    def predict(self, frame1, frame2, times_to_interpolate):
        """Interpolate between ``frame1`` and ``frame2``.

        Args:
            frame1: Path to the first image (png, jpg or jpeg).
            frame2: Path to the second image (png, jpg or jpeg).
            times_to_interpolate: 1 yields a single mid frame; larger values
                yield a video of the recursively interpolated frames.

        Returns:
            Path to ``out.png`` (single frame) or ``out.mp4`` (video), each
            created inside a fresh temporary directory.

        Raises:
            ValueError: If either input does not have a supported extension.
        """
        # Raise explicitly rather than assert: asserts vanish under
        # ``python -O``. The comparison is case-insensitive so that e.g.
        # ``.PNG`` and ``.JPG`` are accepted as well.
        for frame in (frame1, frame2):
            if Path(str(frame)).suffix.lower() not in self._INPUT_EXTENSIONS:
                raise ValueError("Please provide png, jpg or jpeg images.")

        # Any cropped copies live in a per-call scratch directory, so repeated
        # or concurrent predictions never overwrite each other's files and
        # nothing is left behind in the working directory. All reads of the
        # (possibly cropped) frames happen inside this ``with`` block.
        with tempfile.TemporaryDirectory() as scratch_dir:
            frame1, frame2 = self._crop_to_common_size(frame1, frame2, scratch_dir)
            if times_to_interpolate == 1:
                return self._write_mid_frame(frame1, frame2)
            return self._write_video(frame1, frame2, times_to_interpolate)

    @staticmethod
    def _crop_to_common_size(frame1, frame2, scratch_dir):
        """Return paths to two frames that share the same pixel size.

        Equal-sized inputs are returned untouched. Otherwise both images are
        cropped from the top-left corner to the smaller width and height and
        saved as PNG files inside ``scratch_dir``.
        """
        # Context managers close the underlying image files when done.
        with Image.open(str(frame1)) as img1, Image.open(str(frame2)) as img2:
            if img1.size == img2.size:
                return frame1, frame2

            # Compute the shared box once, from the *original* sizes.
            width = min(img1.size[0], img2.size[0])
            height = min(img1.size[1], img2.size[1])
            box = (0, 0, width, height)

            cropped1 = os.path.join(scratch_dir, 'new_frame1.png')
            cropped2 = os.path.join(scratch_dir, 'new_frame2.png')
            img1.crop(box).save(cropped1)
            img2.crop(box).save(cropped2)
        return cropped1, cropped2

    def _write_mid_frame(self, frame1, frame2):
        """Interpolate the single mid frame and save it as ``out.png``."""
        # Each image gets a leading batch axis of size one before being
        # handed to the interpolator.
        image_batch_1 = np.expand_dims(util.read_image(str(frame1)), axis=0)
        image_batch_2 = np.expand_dims(util.read_image(str(frame2)), axis=0)

        # Take the first (only) element of the returned batch.
        mid_frame = self.interpolator.interpolate(
            image_batch_1, image_batch_2, self.batch_dt)[0]
        out_path = Path(tempfile.mkdtemp()) / "out.png"
        util.write_image(str(out_path), mid_frame)
        return out_path

    def _write_video(self, frame1, frame2, times_to_interpolate):
        """Recursively interpolate the frames and save them as ``out.mp4``."""
        input_frames = [str(frame1), str(frame2)]

        # Materialise the frames immediately so the input files are fully
        # consumed before the caller removes its scratch directory.
        frames = list(
            util.interpolate_recursively_from_files(
                input_frames, times_to_interpolate, self.interpolator))
        print('Interpolated frames generated, saving now as output video.')

        ffmpeg_path = util.get_ffmpeg_path()
        mediapy.set_ffmpeg(ffmpeg_path)
        out_path = Path(tempfile.mkdtemp()) / "out.mp4"
        mediapy.write_video(str(out_path), frames, fps=30)
        return out_path
|
|