'''
    This file runs large-scale benchmark testing for OmniShotCut.
'''
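# Example usage (a hypothetical invocation; the script name and all paths are placeholders):
#   python benchmark_massive_test.py \
#       --checkpoint_path /path/to/ckpt_epochXX.pth \
#       --test_dataset_pkl_path /path/to/labels.pkl \
#       --result_store_path results.json \
#       --visual_store_folder_path visual_results/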
import os, sys, shutil
import argparse
import numpy as np
import math
import subprocess
import cv2
from tqdm import tqdm
import ffmpeg
import time
import torch
import json
import torchvision.transforms as T
from torch.utils.data import DataLoader
import pickle
import warnings
warnings.filterwarnings("ignore", category=UserWarning)


# Import files from the local folder
root_path = os.path.abspath('.')
sys.path.append(root_path)
from config.argument_setting import get_args_parser
from architecture.backbone import build_backbone
from architecture.transformer import build_transformer
from architecture.model import OmniShotCut
from datasets.transforms import Video_Augmentation_Transform
from util.visualization import visualize_concated_frames, concat_image_lists_horizontal
from config.label_correspondence import unique_intra_label_mapping, unique_inter_label_mapping
from test_code.inference import single_video_infernece, dump_list_of_dict
from evaluation.evaluate_SBD import evaluate_metrics

# Video Transform (validation-mode preprocessing, without training-time augmentation)
video_transform = Video_Augmentation_Transform(set_type = "val")





def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
                            "--checkpoint_path",
                            type = str,
                            default = "/scratch/usy5km/Cut_Anything/cut_anything_checkpoints/results_training_v12/ckpt_epoch82.pth",
                            help = "Path to checkpoint file."
                        )
    parser.add_argument(
                            "--test_dataset_pkl_path",
                            type = str,
                            default = "/scratch/usy5km/Cut_Anything/cut_anything_benchmark/labels_5round.pkl",
                            help = "Path to test dataset pkl file."
                        )
    parser.add_argument(
                            "--result_store_path",
                            type = str,
                            default = "results.json",
                            help="Path to save result json."
                        )
    parser.add_argument(
                            "--num_context_frames",
                            type = int,
                            default = 0,
                            help = "Path to save result json."
                        )
    parser.add_argument(
                            "--visual_store_folder_path",
                            type = str,
                            default = None,
                            help = "Path to save visualization results. Set to None to disable."
                        )
    parser.add_argument(
                            "--merge_sudden_jump",
                            action = "store_true",
                            default = False,
                            help = "Whether to merge sudden jump."
                        )

    return parser.parse_args()




if __name__ == '__main__':

    # Setting
    inference_args = parse_args()
    checkpoint_path = inference_args.checkpoint_path
    test_dataset_pkl_path = inference_args.test_dataset_pkl_path
    result_store_path = inference_args.result_store_path
    visual_store_folder_path = inference_args.visual_store_folder_path
    merge_sudden_jump = inference_args.merge_sudden_jump
    


    # Prepare the folder
    if visual_store_folder_path is not None:
        os.makedirs(visual_store_folder_path, exist_ok = True)


    # Load Checkpoint & Model Config
    assert os.path.exists(checkpoint_path), "Cannot find checkpoint: " + checkpoint_path
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    model_args = state_dict['args']
    print("Checkpoint stored args are", model_args)


    # Init the Model
    print("Load OmniShotCut Model!")
    backbone = build_backbone(model_args)
    transformer = build_transformer(model_args)
    model = OmniShotCut(
                            backbone,
                            transformer,
                            num_intra_relation_classes = model_args.num_intra_relation_classes,
                            num_inter_relation_classes = model_args.num_inter_relation_classes,
                            num_frames = model_args.max_process_window_length, 
                            num_queries = model_args.num_queries,
                            aux_loss = model_args.aux_loss,
                        )
    model.load_state_dict(state_dict['model'], strict=True)
    model.to("cuda")
    model.eval()
    


    # Read the pkl file
    with open(test_dataset_pkl_path, "rb") as f:
        test_data = pickle.load(f)
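
    # test_data is expected to be a list of dicts; the loop below reads the keys
    # "video_path", "ranges", "intra_labels", "inter_labels", and "confidences".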



    # Iterate all cases
    print("Start Inference!")
    pred_results = []
    start_time = time.time()
    for instance_idx, info_dict in enumerate(tqdm(test_data, desc="Testing")):
        
        # Fetch info
        video_path = info_dict["video_path"]
        assert os.path.exists(video_path), "We cannot find " + video_path
        # print("video path is", video_path, "for instance", instance_idx)
        

        # Init result log
        pred_result = {}
        pred_result["video_path"] = video_path
        pred_result["gt_ranges"] = info_dict["ranges"]
        pred_result["gt_intra_labels"] = info_dict["intra_labels"]
        pred_result["gt_inter_labels"] = info_dict["inter_labels"]
        pred_result["gt_confidences"] = info_dict["confidences"]


        # Do the single inference
        pred_ranges_full, pred_intra_labels_full, pred_inter_labels_full, video_np_full = single_video_infernece(video_path, model, model_args, inference_args)
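        # pred_ranges_full / pred_intra_labels_full / pred_inter_labels_full hold the predicted
        # shot ranges and their intra/inter relation labels for the whole video; video_np_full
        # is the decoded frame array reused below for visualization.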


        # Append prediction results
        pred_result["pred_ranges"] = pred_ranges_full
        pred_result["pred_intra_labels"] = pred_intra_labels_full
        pred_result["pred_inter_labels"] = pred_inter_labels_full
        pred_results.append(pred_result)



        # Visualize
        if visual_store_folder_path is not None:

            # NOTE: fps is not defined elsewhere in this script; read it from the source video (assumed fix).
            cap = cv2.VideoCapture(video_path)
            fps = cap.get(cv2.CAP_PROP_FPS)
            cap.release()

            # Visualize predictions
            prediction_visual_store_path = os.path.join(visual_store_folder_path, "instance" + str(instance_idx) + "_pred")
            if os.path.exists(prediction_visual_store_path):
                shutil.rmtree(prediction_visual_store_path)
            pred_saved_paths = visualize_concated_frames(video_np_full, prediction_visual_store_path, pred_ranges_full, max_frames_per_img=264, end_range_exclusive=True, fps=fps, start_index = 0)


            # Visualize the GT results
            gt_visual_store_path = os.path.join(visual_store_folder_path, "instance" + str(instance_idx) + "_gt") 
            if os.path.exists(gt_visual_store_path):
                shutil.rmtree(gt_visual_store_path)
            gt_ranges_full = info_dict['ranges']
            gt_saved_paths = visualize_concated_frames(video_np_full, gt_visual_store_path, gt_ranges_full, max_frames_per_img=264, end_range_exclusive=True, fps=fps, start_index = 0)


            # Merge Pred and GT on One for easier visual
            merged_visual_store_path = os.path.join(visual_store_folder_path, "instance" + str(instance_idx) + "_merged") 
            if os.path.exists(merged_visual_store_path):
                shutil.rmtree(merged_visual_store_path)
            merged_paths = concat_image_lists_horizontal(           # Left: our predictions; Right: GT
                                                            list1 = pred_saved_paths,
                                                            list2 = gt_saved_paths,
                                                            out_dir = merged_visual_store_path,
                                                            bar_width = 80,                     
                                                            bar_color = (0, 255, 0),            # (0, 255, 0) is green
                                                        )
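
            # Each test instance thus yields three visualization folders:
            # "instance{idx}_pred", "instance{idx}_gt", and "instance{idx}_merged".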


    # Store the result as json
    dump_list_of_dict(pred_results, result_store_path)
    


    # Run the evaluation here automatically
    evaluate_metrics(result_store_path, last_frame_exclusive=True)



    # Final Log
    print("Total time spent is", int(time.time() - start_time), "s!")
    print("Finished!")