| import os.path |
| import re |
| import torch |
| import time |
| import tempfile |
|
|
| import streamlit as st |
| from training.zoo.classifiers import DeepFakeClassifier |
| from kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video_set |
|
|
|
|
def load_model():
    """Build the deepfake classifier and load its weights from disk.

    The checkpoint at ``weights/best.pth`` is read onto the CPU. When the
    checkpoint has a ``"state_dict"`` entry that entry is used; otherwise the
    checkpoint itself is treated as the state dict. A leading ``module.``
    prefix is removed from every key before loading (presumably left by a
    ``DataParallel`` wrapper at training time -- confirm against the
    training code).

    Returns:
        The ``DeepFakeClassifier`` instance, switched to eval mode.
    """
    path = 'weights/best.pth'
    model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns")
    print("loading state dict {}".format(path))
    # NOTE(review): torch.load unpickles the file; only load checkpoints that
    # come from a trusted source.
    checkpoint = torch.load(path, map_location="cpu")
    state_dict = checkpoint.get("state_dict", checkpoint)
    # The dot must be escaped: an unescaped "." matches any character, so a
    # key such as "module_list.0.weight" would be mangled to "list.0.weight".
    cleaned_state = {
        re.sub(r"^module\.", "", key): value
        for key, value in state_dict.items()
    }
    model.load_state_dict(cleaned_state, strict=True)
    model.eval()
    del checkpoint
    return model
|
|
|
|
def write_bytesio_to_file(filename, bytesio):
    """Dump the full contents of an in-memory bytes buffer to *filename*.

    The target file is opened in binary write mode, so any existing content
    is replaced.

    Args:
        filename: Path of the file to create or overwrite.
        bytesio: Object exposing ``getbuffer()`` (e.g. ``io.BytesIO``).
    """
    with open(filename, mode="wb") as sink:
        payload = bytesio.getbuffer()
        sink.write(payload)
|
|
|
|
def load_video():
    """Show a file-upload widget and persist the uploaded video to disk.

    Returns:
        The path of a temporary file holding the uploaded bytes, or ``None``
        when nothing has been uploaded yet. The file is created with
        ``delete=False`` and is not removed by this function.
    """
    uploaded_file = st.file_uploader(label='Pick a video (mp4) file to test')
    if uploaded_file is None:
        return None

    video_data = uploaded_file.getvalue()
    # Close the handle before handing the path out: this flushes the buffered
    # bytes to disk so readers of the path see the complete video, and avoids
    # leaking an open file descriptor on every call.
    with tempfile.NamedTemporaryFile(delete=False) as tfile:
        tfile.write(video_data)
    return tfile.name
|
|
|
|
def inference(model, test_video):
    """Run the classifier on one video and display the prediction.

    Frames are read through a ``VideoReader`` (32 frames requested per
    video), wrapped in a ``FaceExtractor``, and passed to
    ``predict_on_video_set`` with ``confident_strategy`` and an input size
    of 380. The first returned prediction is written to the Streamlit page.

    Args:
        model: Model passed to ``predict_on_video_set`` as a one-element list.
        test_video: Video to evaluate, passed as the only entry of ``videos``.
    """
    frames_per_video = 32
    input_size = 380
    reader = VideoReader()

    def read_video_frames(video):
        # Frame-reading callback given to the face extractor.
        return reader.read_frames(video, num_frames=frames_per_video)

    extractor = FaceExtractor(read_video_frames)

    test_videos = [test_video]
    print("Predicting {} videos".format(len(test_videos)))
    predictions = predict_on_video_set(
        face_extractor=extractor,
        input_size=input_size,
        models=[model],
        strategy=confident_strategy,
        frames_per_video=frames_per_video,
        videos=test_videos,
        num_workers=6,
        test_dir="test_video",
    )
    st.write("Prediction: ", predictions[0])
|
|
|
|
def main():
    """Render the demo page: upload a video, preview it, run inference.

    The model is loaded and the upload widget is shown on every run of the
    script. When the "Run on video" button is pressed, inference runs on the
    uploaded video and the elapsed time is displayed; if no video is
    available a warning is shown instead.
    """
    st.title('Deepfake video inference demo')
    # NOTE(review): the model is rebuilt each time this function runs;
    # consider caching it if page interactions feel slow.
    model = load_model()
    video_data_path = load_video()

    has_video = (video_data_path is not None
                 and os.path.exists(video_data_path))
    if has_video:
        st.video(video_data_path)

    if st.button('Run on video'):
        if not has_video:
            # Without this guard a click before uploading would pass None
            # into the inference pipeline and crash.
            st.warning("Please upload a video before running inference.")
            return
        st.write("Inference on video...")
        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by system clock adjustments.
        started = time.perf_counter()
        inference(model, video_data_path)
        st.write("Elapsed time: ", time.perf_counter() - started, " seconds")
|
|
|
|
# Script entry point: build the page only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|