| import torch |
| from torchvision import transforms |
|
|
| import numpy as np |
| from skimage.color import rgb2lab, lab2rgb |
| import skimage.transform |
| from PIL import Image |
|
|
| import os |
| from tqdm import tqdm |
| from moviepy.editor import VideoFileClip, AudioFileClip |
| from moviepy.tools import cvsecs |
import cv2
|
|
|
def lab_to_rgb(L, ab):
    """
    Converts a batch of L and ab channels back to RGB images.

    Both inputs are expected in [-1, 1]; they are rescaled to the standard
    Lab ranges (L in [0, 100], ab in roughly [-110, 110]) before conversion.
    """
    L = (L + 1.) * 50.
    ab = ab * 110.
    Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy()
| rgb_imgs = [] |
| for img in Lab: |
| img_rgb = lab2rgb(img) |
| rgb_imgs.append(img_rgb) |
| return np.stack(rgb_imgs, axis=0) |
|
|
|
|
SIZE = 256  # model input resolution; images are resized to SIZE x SIZE
|
|
|
|
def get_L(img):
    """Extracts the normalized L channel (in [-1, 1]) from a PIL image."""
    img = img.convert("RGB")  # ensure 3 channels before rgb2lab
    img = transforms.Resize(
        (SIZE, SIZE), interpolation=transforms.InterpolationMode.BICUBIC)(img)
    img = np.array(img)
    img_lab = rgb2lab(img).astype("float32")
    img_lab = transforms.ToTensor()(img_lab)
    L = img_lab[[0], ...] / 50. - 1.  # scale L from [0, 100] to [-1, 1]
    return L
|
|
|
|
def get_predictions(model, L):
    """Runs the model on a batch of L channels and returns predicted RGB images."""
    model.eval()
    with torch.no_grad():
        # Inference runs on CPU; lab_to_rgb concatenates the prediction
        # with L, so both must live on the same device.
        model.L = L.to(torch.device('cpu'))
        model.forward()
        fake_color = model.fake_color.detach()
    fake_imgs = lab_to_rgb(L, fake_color)
    return fake_imgs
|
|
|
|
def colorize_img(model, img):
    """Colorizes a single PIL image, returning an RGB float array in [0, 1]
    at the image's original resolution."""
    L = get_L(img)
    L = L[None]  # add a batch dimension
    fake_imgs = get_predictions(model, L)
    fake_img = fake_imgs[0]
    # PIL sizes are (width, height); skimage expects (height, width).
    resized_fake_img = skimage.transform.resize(
        fake_img, img.size[::-1])
    return resized_fake_img
|
|
|
|
def valid_start_end(duration, start_input, end_input):
    """Parses and validates start/end inputs against the clip duration.

    Empty strings default to the start and end of the clip; otherwise the
    inputs may be anything cvsecs understands (seconds, "mm:ss", "hh:mm:ss").
    """
    start = start_input
    end = end_input
    if start == '':
        start = 0
    if end == '':
        end = duration

    try:
        start = cvsecs(start)
        end = cvsecs(end)
    except Exception:
        raise ValueError("Invalid start/end values.")

    # Clamp to the clip boundaries.
    start = max(start, 0)
    end = min(duration, end)

    if start >= end:
        raise ValueError("Start must be before end.")

    return start, end
|
|
|
|
def colorize_vid(path_input, model, fps, start_input, end_input):
    """Colorizes a video clip frame by frame, muxes the original audio back
    in, and returns the path of the output file."""
    original_video = VideoFileClip(path_input)

    start, end = valid_start_end(
        original_video.duration, start_input, end_input)

    input_video = original_video.subclip(start, end)

    used_fps = fps if isinstance(fps, int) else input_video.fps
    # Estimate the frame count from the subclip duration; the reader's
    # nframes refers to the whole source file, not the subclip.
    nframes = int(np.round(used_fps * input_video.duration))
    print(
        f"Colorizing output with FPS: {used_fps}, nframes: {nframes}, "
        f"resolution: {input_video.size}.")

    frames = input_video.iter_frames(fps=used_fps)
|
|
| |
    base_path, suffix = os.path.splitext(path_input)
    path_video_tmp = base_path + "_tmp" + suffix

    # OpenCV expects the frame size as a (width, height) tuple.
    size = tuple(input_video.size)
    out = cv2.VideoWriter(
        path_video_tmp,
        cv2.VideoWriter_fourcc(*'mp4v'),
        used_fps,
        size)
| |
|
|
    for frame in tqdm(frames, total=nframes):
        color_frame = colorize_img(model, Image.fromarray(frame))

        # skimage returns floats in [0, 1]; convert to 8-bit for OpenCV.
        color_frame = (np.clip(color_frame, 0, 1) * 255).astype(np.uint8)

        # The colorized frame is RGB; OpenCV's writer expects BGR.
        color_frame = cv2.cvtColor(color_frame, cv2.COLOR_RGB2BGR)
        out.write(color_frame)
    out.release()
|
|
| |
    path_output = base_path + "_out" + suffix

    output_video = VideoFileClip(path_video_tmp)

    # Mux the original audio back in, if the source clip has any.
    path_audio_tmp = None
    if input_video.audio is not None:
        path_audio_tmp = base_path + "_audio_tmp.mp3"
        input_video.audio.write_audiofile(path_audio_tmp, logger=None)
        output_video = output_video.set_audio(AudioFileClip(path_audio_tmp))

    output_video.write_videofile(path_output, logger=None)

    os.remove(path_video_tmp)
    if path_audio_tmp is not None:
        os.remove(path_audio_tmp)

    print("Done.")
    return path_output
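

# --- Usage sketch ---
# A minimal, hedged example of wiring these helpers together. `MainModel`
# and the checkpoint path below are hypothetical placeholders; substitute
# any model that exposes the `.L`, `.forward()` and `.fake_color`
# attributes relied on by get_predictions() above.
#
#     model = MainModel()  # hypothetical model class
#     model.load_state_dict(
#         torch.load("colorizer.pt", map_location="cpu"))  # hypothetical path
#
#     # Colorize a single image:
#     img = Image.open("old_photo.jpg")
#     rgb = colorize_img(model, img)  # float RGB array in [0, 1]
#
#     # Colorize the first ten seconds of a video at 24 fps:
#     out_path = colorize_vid("old_film.mp4", model, fps=24,
#                             start_input="0", end_input="10")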
|
|