| |
| |
| |
| |
|
|
| import spaces |
| import gradio as gr |
| import cv2 |
| import numpy as np |
| import time |
| import random |
| from PIL import Image |
| import torch |
| import re |
| import os |
| import shutil |
| import subprocess |
| import tempfile |
|
|
def _no_jit_script(f):
    """Identity stand-in for ``torch.jit.script``: return *f* unchanged."""
    return f


# Replace TorchScript's `script` with an identity function. This must run
# before `transparent_background` is imported below, presumably so that any
# scripting it (or its dependencies) performs at import time is skipped.
# NOTE(review): confirm which call this works around.
torch.jit.script = _no_jit_script
|
|
| from transparent_background import Remover |
|
|
# GPU seconds requested from ZeroGPU for one `doo` call. The in-loop time guard
# is derived from this value so the two cannot drift apart.
_GPU_DURATION = 90
# Seconds held back from the GPU budget so partial output can still be
# finalized (writer release / ffmpeg encode) before the allocation ends.
# Heuristic value. NOTE(review): tune if long WebM encodes still get cut off.
_GPU_SAFETY_MARGIN = 15

# Matches "rgb(r, g, b)" and "rgba(r, g, b, a)" strings with arbitrary spacing.
_RGB_FUNC_RE = re.compile(r'rgba?\(\s*([\d.]+)\s*,\s*([\d.]+)\s*,\s*([\d.]+)')


def _parse_color(color):
    """Convert a ColorPicker value into the "[r, g, b]" string passed to Remover.

    Accepts "#rgb" / "#rrggbb" (any extra alpha digits are ignored) and
    "rgb(...)" / "rgba(...)" strings. Any other value is returned unchanged.

    Raises:
        gr.Error: if a "#"-prefixed value is not valid hexadecimal.
    """
    text = str(color).strip()
    if text.startswith('#'):
        digits = text.lstrip('#')
        if len(digits) in (3, 4):  # CSS shorthand: "#0f0" -> "#00ff00"
            digits = ''.join(ch * 2 for ch in digits)
        try:
            rgb = [int(digits[i:i + 2], 16) for i in (0, 2, 4)]
        except ValueError:
            raise gr.Error(f"Unrecognized color value: {color}") from None
        return str(rgb)
    match = _RGB_FUNC_RE.match(text)
    if match:
        return str([int(float(c)) for c in match.groups()])
    return color


def _iter_rgb_frames(cap):
    """Yield frames from an opened cv2.VideoCapture as RGB PIL images until EOF."""
    while cap.isOpened():
        ok, frame = cap.read()
        if not ok:
            return
        yield Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).convert('RGB')


@spaces.GPU(duration=_GPU_DURATION)
def doo(video, color, mode, out_format, progress=gr.Progress()):
    """Remove the background from every frame of *video*.

    Args:
        video: Path of the uploaded video file.
        color: Background color from the ColorPicker; used only for 'mp4'.
        mode: 'Fast' selects Remover(mode='fast'); anything else the default.
        out_format: 'mp4' composites each frame onto *color* (video only, no
            audio). Any other value writes a VP9 WebM with an alpha channel,
            carrying over the source audio track if one exists.
        progress: Gradio progress tracker.

    Returns:
        Path of the written video file. If the GPU time budget is about to run
        out, the frames processed so far are written and returned.

    Raises:
        gr.Error: if no video was given, no frame could be processed, or
            ffmpeg fails to encode the WebM file.
    """
    # Measure from entry: model loading below also consumes the GPU budget.
    start_time = time.monotonic()
    time_budget = _GPU_DURATION - _GPU_SAFETY_MARGIN

    def out_of_time():
        if time.monotonic() - start_time >= time_budget:
            print("GPU Timeout is coming")
            return True
        return False

    if not video:
        raise gr.Error("Please upload a video first.")

    print(str(color))
    color = _parse_color(color)
    print("Parsed color:", color)
    remover = Remover(mode='fast') if mode == 'Fast' else Remover()

    cap = cv2.VideoCapture(video)
    try:
        # Frame count can be 0/negative when the container does not report it.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:  # 0 (unknown) or NaN -> fall back to 25 fps
            fps = 25.0

        def report(n):
            print(f"Processing frame {n}")
            if total_frames > 0:
                # Clamp: the reported frame count is only an estimate.
                progress(min(n / total_frames, 1.0),
                         desc=f"Processing frame {n}/{total_frames}")
            else:
                progress((n, None), desc=f"Processing frame {n}")

        tmpname = random.randint(111111111, 999999999)
        mp4_path = str(tmpname) + '.mp4'
        webm_path = str(tmpname) + '.webm'

        if out_format == 'mp4':
            writer = None
            try:
                for n, img in enumerate(_iter_rgb_frames(cap), start=1):
                    if out_of_time():
                        break  # keep the frames written so far
                    if writer is None:
                        writer = cv2.VideoWriter(
                            mp4_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, img.size)
                    report(n)
                    out = remover.process(img, type=color)
                    writer.write(cv2.cvtColor(np.array(out), cv2.COLOR_RGB2BGR))
            finally:
                if writer is not None:
                    writer.release()
            if writer is None:
                raise gr.Error("No frames could be processed from this video.")
            return mp4_path

        temp_dir = tempfile.mkdtemp(prefix=f"tb_{tmpname}_")
        try:
            written = 0
            for n, img in enumerate(_iter_rgb_frames(cap), start=1):
                if out_of_time():
                    break  # encode what we have instead of returning nothing
                report(n)
                out = remover.process(img, type='rgba').convert('RGBA')
                out.save(os.path.join(temp_dir, f"frame_{n:06d}.png"), 'PNG')
                written = n
            if written == 0:
                raise gr.Error("No frames could be processed from this video.")

            ffmpeg_cmd = [
                "ffmpeg", "-y",
                # Exact (possibly fractional) rate keeps duration in sync with audio.
                "-framerate", str(fps),
                "-start_number", "1",  # frames are numbered from 1
                "-i", os.path.join(temp_dir, "frame_%06d.png"),
                "-i", str(video),
                "-map", "0:v",
                "-map", "1:a?",  # source audio, if any
                "-c:v", "libvpx-vp9",
                "-pix_fmt", "yuva420p",
                "-auto-alt-ref", "0",
                "-metadata:s:v:0", "alpha_mode=1",
                "-c:a", "libopus",
                "-shortest",
                webm_path,
            ]
            print("Running ffmpeg:", " ".join(ffmpeg_cmd))
            try:
                subprocess.run(ffmpeg_cmd, check=True)
            except subprocess.CalledProcessError as e:
                print("ffmpeg failed:", e)
                raise gr.Error(
                    f"ffmpeg failed to encode the WebM file (exit code {e.returncode})."
                ) from e
            return webm_path
        finally:
            shutil.rmtree(temp_dir, ignore_errors=True)
    finally:
        cap.release()
|
|
# Gradio UI definition for the background-removal tool.
title = "🎞️ Video Background Removal Tool 🎥"
description = """*Please note that if your video file is long (has a high number of frames), processing may be cut short by the GPU time limit. In this case, consider trying Fast mode.*"""


# Each row mirrors the `inputs` order: video, background color, mode, output format.
examples = [
    ['./input.mp4', '#00FF00', 'Normal', 'mp4'],
]


iface = gr.Interface(
    fn=doo,
    inputs=[
        "video",
        gr.ColorPicker(label="Background color", value="#00FF00"),
        gr.Radio(
            ['Normal', 'Fast'],
            label='Select mode',
            value='Normal',
            info='Normal is more accurate, but takes longer. | Fast has lower accuracy so the process will be faster.',
        ),
        gr.Radio(['mp4', 'webm'], label='Output format', value='mp4'),
    ],
    outputs="video",
    examples=examples,
    title=title,
    description=description,
)

if __name__ == "__main__":
    iface.launch()
|
|