# Shreyas1441AI — commit c9b6800 "Create app.py" (verified)
import cv2
import numpy as np
import gradio as gr
import tempfile
import os
from tqdm import tqdm
# ----------------------------
# LOAD MODEL (GLOBAL)
# ----------------------------
# The network is loaded once at import time; style_video reuses this global.
# The default model file is looked up next to this script (the repo root)
# instead of in the current working directory, so the app also works when it
# is launched from elsewhere. Set STYLE_MODEL_PATH to use a different Torch
# .t7 style model.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.environ.get(
    "STYLE_MODEL_PATH", os.path.join(_APP_DIR, "mosaic.t7")
)
if not os.path.isfile(MODEL_PATH):
    # Fail with a clear message instead of an opaque cv2.error from the loader.
    raise FileNotFoundError(f"Style model not found: {MODEL_PATH}")
net = cv2.dnn.readNetFromTorch(MODEL_PATH)
# Per-channel BGR means: blobFromImage subtracts them before inference and
# they are added back to the network output afterwards.
_BGR_MEAN = (103.939, 116.779, 123.680)

# Frame rate assumed when the container reports no usable FPS (0 or NaN),
# which could otherwise prevent cv2.VideoWriter from opening.
_FALLBACK_FPS = 25.0


def _stylize_frame(frame):
    """Return ``frame`` (as produced by ``cap.read()``) restyled by ``net``.

    The result is a uint8 image with the same height and width as ``frame``,
    in the same channel order as the input.
    """
    height, width = frame.shape[:2]
    blob = cv2.dnn.blobFromImage(
        frame,
        1.0,
        (width, height),
        _BGR_MEAN,
        swapRB=False,
        crop=False,
    )
    net.setInput(blob)
    output = net.forward()
    # Drop the batch axis: (1, 3, H', W') -> (3, H', W').
    output = output.reshape(3, output.shape[2], output.shape[3])
    # Undo the mean subtraction applied by blobFromImage.
    output += np.asarray(_BGR_MEAN, dtype=output.dtype).reshape(3, 1, 1)
    output = output.transpose(1, 2, 0)
    output = np.ascontiguousarray(np.clip(output, 0, 255).astype(np.uint8))
    # The network's output size may differ slightly from the input (e.g. for
    # dimensions that don't divide its stride). VideoWriter may silently
    # skip frames whose size differs from the one it was opened with.
    if output.shape[:2] != (height, width):
        output = cv2.resize(output, (width, height))
    return output


def style_video(input_video):
    """Apply the loaded style-transfer model to every frame of a video.

    Args:
        input_video: Path to the input video, as supplied by ``gr.Video``.

    Returns:
        Path to a temporary ``.mp4`` file with the styled frames. The file is
        created with ``delete=False`` and is not removed on success.

    Raises:
        RuntimeError: If no video is given, the video cannot be opened, the
            output writer cannot be created, or no frames can be decoded.
    """
    if input_video is None:
        raise RuntimeError("No video provided")

    # ----------------------------
    # OPEN INPUT VIDEO
    # ----------------------------
    cap = cv2.VideoCapture(input_video)
    if not cap.isOpened():
        raise RuntimeError("Could not open video")

    fps = cap.get(cv2.CAP_PROP_FPS)
    if not (fps > 0):  # also catches NaN
        fps = _FALLBACK_FPS
    # The frame count is only an estimate (it can be 0 or -1 for some
    # streams). Use it for the progress bar only; the loop reads until EOF.
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # ----------------------------
    # TEMP OUTPUT FILE
    # ----------------------------
    temp_out = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
    temp_out.close()
    out_path = temp_out.name

    # ----------------------------
    # PROCESS FRAMES
    # ----------------------------
    writer = None
    frames_written = 0
    succeeded = False
    try:
        with tqdm(
            total=total_frames if total_frames > 0 else None,
            desc="Styling frames",
        ) as progress:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                styled = _stylize_frame(frame)
                if writer is None:
                    # Size the writer from the first decoded frame instead of
                    # container metadata, which can be missing or stale.
                    frame_h, frame_w = frame.shape[:2]
                    writer = cv2.VideoWriter(
                        out_path,
                        cv2.VideoWriter_fourcc(*"mp4v"),
                        fps,
                        (frame_w, frame_h),
                    )
                    if not writer.isOpened():
                        raise RuntimeError("Could not create output video")
                writer.write(styled)
                frames_written += 1
                progress.update(1)
        if frames_written == 0:
            raise RuntimeError("No frames could be read from the video")
        succeeded = True
    finally:
        # ----------------------------
        # CLEANUP
        # ----------------------------
        cap.release()
        if writer is not None:
            writer.release()
        if not succeeded:
            # Best effort: don't leave a partial/empty temp file behind.
            try:
                os.remove(out_path)
            except OSError:
                pass
    return out_path
# ----------------------------
# GRADIO UI
# ----------------------------
# One video in, one video out: the uploaded file's path is passed to
# style_video and the path it returns is displayed as the result.
_upload_component = gr.Video(label="Upload Video")
_result_component = gr.Video(label="Styled Video")
_ui_title = "Neural Style Transfer on Video"
_ui_description = "Applies fast neural style transfer (Torch .t7) frame-by-frame using OpenCV."

app = gr.Interface(
    fn=style_video,
    inputs=_upload_component,
    outputs=_result_component,
    title=_ui_title,
    description=_ui_description,
)

if __name__ == "__main__":
    # Only start the server when run as a script, not when imported.
    app.launch()