'''
This file is for OmniShotCut benchmark massive testing
'''
import os, sys, shutil
import argparse
import numpy as np
import math
import subprocess
import cv2
from tqdm import tqdm
import ffmpeg
import time
import torch
import json
import torchvision.transforms as T
from torch.utils.data import DataLoader
import pickle
import warnings
# Silence noisy third-party UserWarnings during batch inference runs.
warnings.filterwarnings("ignore", category=UserWarning)
# Import files from the local folder: make the repo root importable so the
# project-local packages below (config, architecture, ...) resolve when this
# script is run directly from the repository root.
root_path = os.path.abspath('.')
sys.path.append(root_path)
from config.argument_setting import get_args_parser
from architecture.backbone import build_backbone
from architecture.transformer import build_transformer
from architecture.model import OmniShotCut
from datasets.transforms import Video_Augmentation_Transform
from util.visualization import visualize_concated_frames, concat_image_lists_horizontal
from config.label_correspondence import unique_intra_label_mapping, unique_inter_label_mapping
from test_code.inference import single_video_infernece, dump_list_of_dict
from evaluation.evaluate_SBD import evaluate_metrics
# Video transform built in validation mode (set_type="val") — presumably the
# deterministic, non-augmenting preprocessing path; confirm in
# datasets.transforms. NOTE(review): `video_transform` is not referenced
# anywhere in this file — it may be consumed by an imported helper, or dead.
video_transform = Video_Augmentation_Transform(set_type = "val")
def parse_args():
    """Parse command-line arguments for the OmniShotCut benchmark run.

    Returns:
        argparse.Namespace with fields:
            checkpoint_path (str): Path to the model checkpoint (.pth).
            test_dataset_pkl_path (str): Path to the benchmark label pkl.
            result_store_path (str): Where the result json is written.
            num_context_frames (int): Number of context frames (default 0).
            visual_store_folder_path (str | None): Folder for visualization
                output; None disables visualization.
            merge_sudden_jump (bool): Whether to merge sudden jumps.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--checkpoint_path",
        type = str,
        default = "/scratch/usy5km/Cut_Anything/cut_anything_checkpoints/results_training_v12/ckpt_epoch82.pth",
        help = "Path to checkpoint file."
    )
    parser.add_argument(
        "--test_dataset_pkl_path",
        type = str,
        default = "/scratch/usy5km/Cut_Anything/cut_anything_benchmark/labels_5round.pkl",
        help = "Path to test dataset pkl file."
    )
    parser.add_argument(
        "--result_store_path",
        type = str,
        default = "results.json",
        help = "Path to save result json."
    )
    parser.add_argument(
        "--num_context_frames",
        type = int,
        default = 0,
        # BUG FIX: help text was copy-pasted from --result_store_path
        # ("Path to save result json."); it now describes this option.
        help = "Number of context frames to use during inference."
    )
    parser.add_argument(
        "--visual_store_folder_path",
        type = str,
        default = None,
        help = "Path to save visualization results. Set to None to disable."
    )
    parser.add_argument(
        "--merge_sudden_jump",
        action = "store_true",
        default = False,
        help = "Whether to merge sudden jump."
    )
    return parser.parse_args()
if __name__ == '__main__':
    # ---- CLI settings --------------------------------------------------
    inference_args = parse_args()
    checkpoint_path = inference_args.checkpoint_path
    test_dataset_pkl_path = inference_args.test_dataset_pkl_path
    result_store_path = inference_args.result_store_path
    visual_store_folder_path = inference_args.visual_store_folder_path
    merge_sudden_jump = inference_args.merge_sudden_jump
    # Prepare the visualization output folder (visualization is optional).
    if visual_store_folder_path is not None:
        os.makedirs(visual_store_folder_path, exist_ok = True)
    # ---- Load checkpoint & model config --------------------------------
    # NOTE(review): `assert` is stripped under `python -O`; raising
    # FileNotFoundError would be more robust for input validation.
    assert(os.path.exists(checkpoint_path))
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    # The training-time args were stored in the checkpoint under 'args'.
    model_args = state_dict['args']
    print("Checkpoint stored args are", model_args)
    # ---- Build the model and load weights ------------------------------
    print("Load OmniShotCut Model!")
    backbone = build_backbone(model_args)
    transformer = build_transformer(model_args)
    model = OmniShotCut(
        backbone,
        transformer,
        num_intra_relation_classes = model_args.num_intra_relation_classes,
        num_inter_relation_classes = model_args.num_inter_relation_classes,
        num_frames = model_args.max_process_window_length,
        num_queries = model_args.num_queries,
        aux_loss = model_args.aux_loss,
    )
    model.load_state_dict(state_dict['model'], strict=True)
    # Inference is hard-coded to GPU; this will fail on CPU-only machines.
    model.to("cuda")
    model.eval()
    # ---- Read the benchmark pkl (a list of per-video info dicts) -------
    with open(test_dataset_pkl_path, "rb") as f:
        test_data = pickle.load(f)
    # ---- Iterate all benchmark cases -----------------------------------
    print("Start Inference!")
    pred_results = []
    start_time = time.time()
    for instance_idx, info_dict in enumerate(tqdm(test_data, desc="Testing")):
        # Fetch info
        video_path = info_dict["video_path"]
        if not os.path.exists(video_path):
            # NOTE(review): assert(False) aborts the whole run on one
            # missing file and disappears under -O; consider raising
            # FileNotFoundError (or skipping the instance) instead.
            print("We cannot find", video_path)
            assert(False)
        # print("video path is", video_path, "for instance", instance_idx)
        # Init result log: carry the ground-truth annotations through so the
        # evaluator can compare them against the predictions below.
        pred_result = {}
        pred_result["video_path"] = video_path
        pred_result["gt_ranges"] = info_dict["ranges"]
        pred_result["gt_intra_labels"] = info_dict["intra_labels"]
        pred_result["gt_inter_labels"] = info_dict["inter_labels"]
        pred_result["gt_confidences"] = info_dict["confidences"]
        # Do the single-video inference ("infernece" typo is in the imported
        # helper's name, so it cannot be fixed here).
        pred_ranges_full, pred_intra_labels_full, pred_inter_labels_full, video_np_full = single_video_infernece(video_path, model, model_args, inference_args)
        # Append prediction results
        pred_result["pred_ranges"] = pred_ranges_full
        pred_result["pred_intra_labels"] = pred_intra_labels_full
        pred_result["pred_inter_labels"] = pred_inter_labels_full
        pred_results.append(pred_result)
        # ---- Optional visualization ------------------------------------
        # NOTE(review): `fps` used below is never defined anywhere in this
        # file, so enabling --visual_store_folder_path raises NameError.
        # It presumably should come from the video metadata — confirm and fix.
        if visual_store_folder_path is not None:
            # Visualize predictions (folder is wiped first for a clean rerun).
            prediction_visual_store_path = os.path.join(visual_store_folder_path, "instance" + str(instance_idx) + "_pred")
            if os.path.exists(prediction_visual_store_path):
                shutil.rmtree(prediction_visual_store_path)
            pred_saved_paths = visualize_concated_frames(video_np_full, prediction_visual_store_path, pred_ranges_full, max_frames_per_img=264, end_range_exclusive=True, fps=fps, start_index = 0)
            # Visualize the GT results
            gt_visual_store_path = os.path.join(visual_store_folder_path, "instance" + str(instance_idx) + "_gt")
            if os.path.exists(gt_visual_store_path):
                shutil.rmtree(gt_visual_store_path)
            gt_ranges_full = info_dict['ranges']
            gt_saved_paths = visualize_concated_frames(video_np_full, gt_visual_store_path, gt_ranges_full, max_frames_per_img=264, end_range_exclusive=True, fps=fps, start_index = 0)
            # Merge Pred and GT into one image for easier visual comparison.
            merged_visual_store_path = os.path.join(visual_store_folder_path, "instance" + str(instance_idx) + "_merged")
            if os.path.exists(merged_visual_store_path):
                shutil.rmtree(merged_visual_store_path)
            merged_paths = concat_image_lists_horizontal( # Left: our predictions; Right: GT
                list1 = pred_saved_paths,
                list2 = gt_saved_paths,
                out_dir = merged_visual_store_path,
                bar_width = 80,
                bar_color = (0, 255, 0), # (0, 255, 0) is green
            )
    # Store the result as json
    dump_list_of_dict(pred_results, result_store_path)
    # Run the evaluation automatically on the freshly written result file.
    evaluate_metrics(result_store_path, last_frame_exclusive=True)
    # Final Log
    print("Total time spent is", int(time.time() - start_time), "s!")
    print("Finished!")