'''
This file runs large-scale benchmark testing for OmniShotCut.
'''
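# Example invocation (script name and paths are placeholders, shown for illustration):
#   python run_benchmark.py \
#       --checkpoint_path /path/to/ckpt_epoch82.pth \
#       --test_dataset_pkl_path /path/to/labels_5round.pkl \
#       --result_store_path results.json \
#       --visual_store_folder_path ./visual_out \
#       --merge_sudden_jump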
import os, sys, shutil
import argparse
import time
import pickle
import warnings

import cv2
import torch
from tqdm import tqdm

warnings.filterwarnings("ignore", category=UserWarning)

# Import files from the local folder
root_path = os.path.abspath('.')
sys.path.append(root_path)
from architecture.backbone import build_backbone
from architecture.transformer import build_transformer
from architecture.model import OmniShotCut
from datasets.transforms import Video_Augmentation_Transform
from util.visualization import visualize_concated_frames, concat_image_lists_horizontal
from test_code.inference import single_video_infernece, dump_list_of_dict
from evaluation.evaluate_SBD import evaluate_metrics

# Video transform (validation-time preprocessing)
video_transform = Video_Augmentation_Transform(set_type="val")
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        default="/scratch/usy5km/Cut_Anything/cut_anything_checkpoints/results_training_v12/ckpt_epoch82.pth",
        help="Path to the checkpoint file.",
    )
    parser.add_argument(
        "--test_dataset_pkl_path",
        type=str,
        default="/scratch/usy5km/Cut_Anything/cut_anything_benchmark/labels_5round.pkl",
        help="Path to the test dataset pkl file.",
    )
    parser.add_argument(
        "--result_store_path",
        type=str,
        default="results.json",
        help="Path to save the result json.",
    )
    parser.add_argument(
        "--num_context_frames",
        type=int,
        default=0,
        help="Number of context frames to use during inference.",
    )
    parser.add_argument(
        "--visual_store_folder_path",
        type=str,
        default=None,
        help="Path to save visualization results. Set to None to disable.",
    )
    parser.add_argument(
        "--merge_sudden_jump",
        action="store_true",
        help="Whether to merge sudden jumps.",
    )
    return parser.parse_args()
if __name__ == '__main__':
    # Setting
    inference_args = parse_args()
    checkpoint_path = inference_args.checkpoint_path
    test_dataset_pkl_path = inference_args.test_dataset_pkl_path
    result_store_path = inference_args.result_store_path
    visual_store_folder_path = inference_args.visual_store_folder_path
    merge_sudden_jump = inference_args.merge_sudden_jump

    # Prepare the visualization folder
    if visual_store_folder_path is not None:
        os.makedirs(visual_store_folder_path, exist_ok=True)

    # Load Checkpoint & Model Config
    assert os.path.exists(checkpoint_path), f"Cannot find checkpoint at {checkpoint_path}"
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    model_args = state_dict['args']
    print("Checkpoint stored args are", model_args)
    # Init the Model
    print("Load OmniShotCut Model!")
    backbone = build_backbone(model_args)
    transformer = build_transformer(model_args)
    model = OmniShotCut(
        backbone,
        transformer,
        num_intra_relation_classes=model_args.num_intra_relation_classes,
        num_inter_relation_classes=model_args.num_inter_relation_classes,
        num_frames=model_args.max_process_window_length,
        num_queries=model_args.num_queries,
        aux_loss=model_args.aux_loss,
    )
    model.load_state_dict(state_dict['model'], strict=True)
    model.to("cuda")
    model.eval()
    # Read the pkl file
    with open(test_dataset_pkl_path, "rb") as f:
        test_data = pickle.load(f)
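    # Assumed schema, inferred from the accesses below: test_data is a list of dicts,
    # each with at least "video_path", "ranges", "intra_labels", "inter_labels",
    # and "confidences".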
    # Iterate all cases
    print("Start Inference!")
    pred_results = []
    start_time = time.time()
    for instance_idx, info_dict in enumerate(tqdm(test_data, desc="Testing")):
        # Fetch info
        video_path = info_dict["video_path"]
        if not os.path.exists(video_path):
            raise FileNotFoundError(f"We cannot find {video_path}")
        # print("video path is", video_path, "for instance", instance_idx)

        # Init the result log with the ground-truth annotations
        pred_result = {}
        pred_result["video_path"] = video_path
        pred_result["gt_ranges"] = info_dict["ranges"]
        pred_result["gt_intra_labels"] = info_dict["intra_labels"]
        pred_result["gt_inter_labels"] = info_dict["inter_labels"]
        pred_result["gt_confidences"] = info_dict["confidences"]
        # Run single-video inference
        pred_ranges_full, pred_intra_labels_full, pred_inter_labels_full, video_np_full = \
            single_video_infernece(video_path, model, model_args, inference_args)

        # Append prediction results
        pred_result["pred_ranges"] = pred_ranges_full
        pred_result["pred_intra_labels"] = pred_intra_labels_full
        pred_result["pred_inter_labels"] = pred_inter_labels_full
        pred_results.append(pred_result)
        # Visualize
        if visual_store_folder_path is not None:
            # Read the video FPS with OpenCV so timestamps render correctly in the visualizations
            cap = cv2.VideoCapture(video_path)
            fps = cap.get(cv2.CAP_PROP_FPS)
            cap.release()

            # Visualize predictions
            prediction_visual_store_path = os.path.join(visual_store_folder_path, f"instance{instance_idx}_pred")
            if os.path.exists(prediction_visual_store_path):
                shutil.rmtree(prediction_visual_store_path)
            pred_saved_paths = visualize_concated_frames(video_np_full, prediction_visual_store_path, pred_ranges_full, max_frames_per_img=264, end_range_exclusive=True, fps=fps, start_index=0)

            # Visualize the GT results
            gt_visual_store_path = os.path.join(visual_store_folder_path, f"instance{instance_idx}_gt")
            if os.path.exists(gt_visual_store_path):
                shutil.rmtree(gt_visual_store_path)
            gt_ranges_full = info_dict['ranges']
            gt_saved_paths = visualize_concated_frames(video_np_full, gt_visual_store_path, gt_ranges_full, max_frames_per_img=264, end_range_exclusive=True, fps=fps, start_index=0)

            # Merge Pred and GT into one image for easier visual comparison
            merged_visual_store_path = os.path.join(visual_store_folder_path, f"instance{instance_idx}_merged")
            if os.path.exists(merged_visual_store_path):
                shutil.rmtree(merged_visual_store_path)
            merged_paths = concat_image_lists_horizontal(  # Left: our predictions; Right: GT
                list1=pred_saved_paths,
                list2=gt_saved_paths,
                out_dir=merged_visual_store_path,
                bar_width=80,
                bar_color=(0, 255, 0),  # green separator bar
            )
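    # Each entry in pred_results now pairs the GT annotations (gt_ranges, gt_intra_labels,
    # gt_inter_labels, gt_confidences) with the predictions (pred_ranges, pred_intra_labels,
    # pred_inter_labels); presumably this is the schema that evaluate_metrics consumes below.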
    # Store the results as json
    dump_list_of_dict(pred_results, result_store_path)

    # Run the evaluation here automatically
    evaluate_metrics(result_store_path, last_frame_exclusive=True)

    # Final Log
    print("Total time spent is", int(time.time() - start_time), "s!")
    print("Finished!")