# Source: OmniShotCut / test_code/test_benchmark.py
# (uploaded by HikariDawn; commit 796e051 "feat: initial push"; 8.21 kB)
'''
Large-scale benchmark testing for OmniShotCut: runs inference on every video in the benchmark label file and then evaluates the predictions.
'''
import os, sys, shutil
import argparse
import numpy as np
import math
import subprocess
import cv2
from tqdm import tqdm
import ffmpeg
import time
import torch
import json
import torchvision.transforms as T
from torch.utils.data import DataLoader
import pickle
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
# Import files from the local folder
root_path = os.path.abspath('.')
sys.path.append(root_path)
from config.argument_setting import get_args_parser
from architecture.backbone import build_backbone
from architecture.transformer import build_transformer
from architecture.model import OmniShotCut
from datasets.transforms import Video_Augmentation_Transform
from util.visualization import visualize_concated_frames, concat_image_lists_horizontal
from config.label_correspondence import unique_intra_label_mapping, unique_inter_label_mapping
from test_code.inference import single_video_infernece, dump_list_of_dict
from evaluation.evaluate_SBD import evaluate_metrics
# Validation-mode video transform built at import time.
# NOTE(review): not referenced elsewhere in this script; kept in case other modules import it.
video_transform = Video_Augmentation_Transform(set_type="val")
def parse_args(argv=None):
    """Parse the command-line options of the benchmark runner.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, matching the
            original no-argument behavior.

    Returns:
        argparse.Namespace with ``checkpoint_path``, ``test_dataset_pkl_path``,
        ``result_store_path``, ``num_context_frames``,
        ``visual_store_folder_path`` and ``merge_sudden_jump``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        default="/scratch/usy5km/Cut_Anything/cut_anything_checkpoints/results_training_v12/ckpt_epoch82.pth",
        help="Path to checkpoint file.",
    )
    parser.add_argument(
        "--test_dataset_pkl_path",
        type=str,
        default="/scratch/usy5km/Cut_Anything/cut_anything_benchmark/labels_5round.pkl",
        help="Path to test dataset pkl file.",
    )
    parser.add_argument(
        "--result_store_path",
        type=str,
        default="results.json",
        help="Path to save result json.",
    )
    parser.add_argument(
        "--num_context_frames",
        type=int,
        default=0,
        # The previous help text was copy-pasted from --result_store_path.
        help="Number of context frames passed to inference.",
    )
    parser.add_argument(
        "--visual_store_folder_path",
        type=str,
        default=None,
        # A CLI user cannot pass Python's None; omitting the flag keeps the default.
        help="Folder to save visualization results. Omit to disable visualization.",
    )
    parser.add_argument(
        "--merge_sudden_jump",
        action="store_true",
        default=False,
        help="Whether to merge sudden jump.",
    )
    return parser.parse_args(argv)
def _read_video_fps(video_path):
    """Return the frame rate OpenCV reports for ``video_path``.

    ``cv2.VideoCapture.get(cv2.CAP_PROP_FPS)`` returns 0.0 when the value is
    unavailable (e.g. the file cannot be opened), so callers may see 0.0.
    """
    capture = cv2.VideoCapture(video_path)
    try:
        return capture.get(cv2.CAP_PROP_FPS)
    finally:
        # Release the decoder handle even if the query raises.
        capture.release()


def _reset_dir(path):
    """Remove ``path`` recursively if it exists, so stale outputs don't linger."""
    if os.path.exists(path):
        shutil.rmtree(path)


def _visualize_instance(visual_store_folder_path, instance_idx, video_np_full,
                        pred_ranges_full, gt_ranges_full, fps):
    """Save prediction, GT and side-by-side (pred | GT) visualizations for one video.

    Outputs go to ``instance<idx>_pred``, ``instance<idx>_gt`` and
    ``instance<idx>_merged`` under ``visual_store_folder_path``. Any existing
    folder with the same name is deleted first.
    """
    base_path = os.path.join(visual_store_folder_path, "instance" + str(instance_idx))

    # Visualize predictions
    prediction_visual_store_path = base_path + "_pred"
    _reset_dir(prediction_visual_store_path)
    pred_saved_paths = visualize_concated_frames(
        video_np_full, prediction_visual_store_path, pred_ranges_full,
        max_frames_per_img=264, end_range_exclusive=True, fps=fps, start_index=0,
    )

    # Visualize the GT results
    gt_visual_store_path = base_path + "_gt"
    _reset_dir(gt_visual_store_path)
    gt_saved_paths = visualize_concated_frames(
        video_np_full, gt_visual_store_path, gt_ranges_full,
        max_frames_per_img=264, end_range_exclusive=True, fps=fps, start_index=0,
    )

    # Merge pred and GT side by side for easier visual comparison
    merged_visual_store_path = base_path + "_merged"
    _reset_dir(merged_visual_store_path)
    concat_image_lists_horizontal(  # Left: our predictions; Right: GT
        list1=pred_saved_paths,
        list2=gt_saved_paths,
        out_dir=merged_visual_store_path,
        bar_width=80,
        bar_color=(0, 255, 0),  # (0, 255, 0) is green
    )


if __name__ == '__main__':
    # Setting
    inference_args = parse_args()
    checkpoint_path = inference_args.checkpoint_path
    test_dataset_pkl_path = inference_args.test_dataset_pkl_path
    result_store_path = inference_args.result_store_path
    visual_store_folder_path = inference_args.visual_store_folder_path
    merge_sudden_jump = inference_args.merge_sudden_jump

    # Prepare the folder
    if visual_store_folder_path is not None:
        os.makedirs(visual_store_folder_path, exist_ok=True)

    # Load Checkpoint & Model Config
    # Raise explicitly instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    # The checkpoint carries a non-tensor 'args' object (presumably an
    # argparse.Namespace) that torch's weights-only loader rejects, so full
    # unpickling is required. Only load checkpoints from a trusted source.
    state_dict = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    model_args = state_dict['args']
    print("Checkpoint stored args are", model_args)

    # Init the Model
    print("Load OmniShotCut Model!")
    backbone = build_backbone(model_args)
    transformer = build_transformer(model_args)
    model = OmniShotCut(
        backbone,
        transformer,
        num_intra_relation_classes=model_args.num_intra_relation_classes,
        num_inter_relation_classes=model_args.num_inter_relation_classes,
        num_frames=model_args.max_process_window_length,
        num_queries=model_args.num_queries,
        aux_loss=model_args.aux_loss,
    )
    model.load_state_dict(state_dict['model'], strict=True)
    model.to("cuda")
    model.eval()

    # Read the pkl file.
    # NOTE: pickle can execute arbitrary code on load; only use trusted label files.
    with open(test_dataset_pkl_path, "rb") as f:
        test_data = pickle.load(f)

    # Iterate all cases
    print("Start Inference!")
    pred_results = []
    start_time = time.time()
    for instance_idx, info_dict in enumerate(tqdm(test_data, desc="Testing")):
        # Fetch info
        video_path = info_dict["video_path"]
        if not os.path.exists(video_path):
            raise FileNotFoundError(f"We cannot find {video_path}")

        # Init result log with the GT fields; the evaluator reads GT and
        # predictions from the same JSON entry.
        pred_result = {
            "video_path": video_path,
            "gt_ranges": info_dict["ranges"],
            "gt_intra_labels": info_dict["intra_labels"],
            "gt_inter_labels": info_dict["inter_labels"],
            "gt_confidences": info_dict["confidences"],
        }

        # Do the single inference
        pred_ranges_full, pred_intra_labels_full, pred_inter_labels_full, video_np_full = \
            single_video_infernece(video_path, model, model_args, inference_args)

        # Append prediction results
        pred_result.update(
            pred_ranges=pred_ranges_full,
            pred_intra_labels=pred_intra_labels_full,
            pred_inter_labels=pred_inter_labels_full,
        )
        pred_results.append(pred_result)

        # Visualize. `fps` was previously referenced without ever being
        # defined (NameError); read it from the video file instead.
        if visual_store_folder_path is not None:
            _visualize_instance(
                visual_store_folder_path,
                instance_idx,
                video_np_full,
                pred_ranges_full,
                info_dict["ranges"],
                _read_video_fps(video_path),
            )

    # Store the result as json
    dump_list_of_dict(pred_results, result_store_path)

    # Do the evaluation here automatically
    evaluate_metrics(result_store_path, last_frame_exclusive=True)

    # Final Log
    print("Total time spent is", int(time.time() - start_time), "s!")
    print("Finished!")