| |
import subprocess

# Path to the install.sh script; replace with the correct path for your setup.
install_script_path = "script/install.sh"

# Run the install script through bash. The return code is not inspected.
subprocess.run(["bash", install_script_path])
|
|
| |
| from PIL.ImageOps import colorize, scale |
| import gradio as gr |
| import importlib |
| import sys |
| import os |
| import pdb |
| from matplotlib.pyplot import step |
|
|
| from model_args import segtracker_args,sam_args,aot_args |
| from SegTracker import SegTracker |
| from tool.transfer_tools import draw_outline, draw_points |
| |
| |
|
|
| import cv2 |
| from PIL import Image |
| from skimage.morphology.binary import binary_dilation |
| import argparse |
| import torch |
| import time, math |
| from seg_track_anything import aot_model2ckpt, tracking_objects_in_video, draw_mask |
| import gc |
| import numpy as np |
| import json |
| from tool.transfer_tools import mask2bbox |
|
|
def clean():
    """Return the blank values used to reset the input widgets and state.

    Returns:
        A 7-tuple: six ``None`` entries followed by a fresh, empty click
        stack ``[[], []]``.
    """
    cleared_values = (None,) * 6
    empty_click_stack = [[], []]
    return (*cleared_values, empty_click_stack)
|
|
def get_click_prompt(click_stack, point):
    """Push a click onto the stack and build the point prompt from it.

    Args:
        click_stack: ``[[coordinate], [point_mode]]``; both inner lists are
            appended to in place.
        point: Mapping with a ``"coord"`` entry and a ``"mode"`` entry.

    Returns:
        A dict whose ``"points_coord"`` and ``"points_mode"`` values are the
        inner lists of ``click_stack`` (not copies) and whose ``"multimask"``
        value is the string ``"True"``.
    """
    coords = click_stack[0]
    coords.append(point["coord"])
    modes = click_stack[1]
    modes.append(point["mode"])

    return dict(points_coord=coords, points_mode=modes, multimask="True")
|
|
def get_meta_from_video(input_video):
    """Read the first frame of a video and return it for the UI widgets.

    Args:
        input_video: Path handed to ``cv2.VideoCapture``, or ``None`` when
            no video has been selected.

    Returns:
        ``(first_frame, first_frame, first_frame, "")`` where ``first_frame``
        is the first video frame converted from BGR to RGB. When
        ``input_video`` is ``None`` or the first frame cannot be read, the
        result is ``(None, None, None, "")``.
    """
    if input_video is None:
        return None, None, None, ""

    print("get meta information of input video")
    cap = cv2.VideoCapture(input_video)
    try:
        ok, first_frame = cap.read()
    finally:
        # Release the capture even if reading raises.
        cap.release()

    # cap.read() reports failure through its first return value; passing the
    # resulting None frame to cvtColor would raise.
    if not ok or first_frame is None:
        print("failed to read the first frame of input video")
        return None, None, None, ""

    first_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)

    return first_frame, first_frame, first_frame, ""
|
|
def get_meta_from_img_seq(input_img_seq):
    """Unpack an uploaded image-sequence archive and return its first frame.

    The archive is extracted into ``./assets``; the images are then listed
    from ``./assets/<archive name without extension>``, which is removed and
    recreated before extraction.

    Args:
        input_img_seq: Uploaded file object exposing the archive path as
            ``.name``, or ``None`` when nothing has been uploaded.

    Returns:
        ``(first_frame, first_frame, first_frame, "")`` where ``first_frame``
        is the first image (in sorted path order) converted from BGR to RGB.
        ``(None, None, None, "")`` when there is no input, the archive is
        not a valid zip, it yields no files in the expected directory, or
        the first file cannot be decoded as an image.
    """
    if input_img_seq is None:
        return None, None, None, ""

    # Function-scope imports: these modules are only needed here.
    import shutil
    import zipfile

    print("get meta information of img seq")

    # Recreate the target directory so stale frames never mix with new ones.
    file_name = input_img_seq.name.split('/')[-1].split('.')[0]
    file_path = f'./assets/{file_name}'
    if os.path.isdir(file_path):
        # shutil instead of a shell `rm -r` built from the upload name.
        shutil.rmtree(file_path)
    os.makedirs(file_path)

    # zipfile instead of a shell `unzip` built from the upload name.
    try:
        with zipfile.ZipFile(input_img_seq.name) as archive:
            archive.extractall('./assets')
    except zipfile.BadZipFile:
        print("input image sequence is not a valid zip archive")
        return None, None, None, ""

    imgs_path = sorted(
        os.path.join(file_path, img_name) for img_name in os.listdir(file_path)
    )
    if not imgs_path:
        print("no images found in the extracted image sequence")
        return None, None, None, ""

    first_frame = cv2.imread(imgs_path[0])
    if first_frame is None:
        # cv2.imread returns None when the file cannot be decoded.
        print("failed to read the first image of the sequence")
        return None, None, None, ""
    first_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)

    return first_frame, first_frame, first_frame, ""
|
|
def SegTracker_add_first_frame(Seg_Tracker, origin_frame, predicted_mask):
    """Restart the tracker and register ``predicted_mask`` as frame 0's mask.

    Args:
        Seg_Tracker: Tracker object; its ``restart_tracker`` and
            ``add_reference`` methods are called and its
            ``first_frame_mask`` attribute is overwritten.
        origin_frame: Frame passed to ``add_reference``.
        predicted_mask: Mask passed to ``add_reference`` and stored as
            ``first_frame_mask``.

    Returns:
        The same ``Seg_Tracker`` object.
    """
    first_frame_index = 0
    with torch.cuda.amp.autocast():
        # Drop any previous reference before adding the new first-frame mask.
        Seg_Tracker.restart_tracker()
        Seg_Tracker.add_reference(origin_frame, predicted_mask, first_frame_index)
        Seg_Tracker.first_frame_mask = predicted_mask

    return Seg_Tracker
|
|
def init_SegTracker(aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side, origin_frame):
    """Create a fresh SegTracker configured from the UI settings.

    The module-level ``aot_args``, ``segtracker_args`` and ``sam_args``
    dictionaries are updated in place before the tracker is built.

    Args:
        aot_model: Key into ``aot_model2ckpt``; also stored as the model name.
        long_term_mem: Stored as ``aot_args["long_term_mem_gap"]``.
        max_len_long_term: Stored as ``aot_args["max_len_long_term"]``.
        sam_gap: Stored as ``segtracker_args["sam_gap"]``.
        max_obj_num: Stored as ``segtracker_args["max_obj_num"]``.
        points_per_side: Stored under ``sam_args["generator_args"]``.
        origin_frame: Current frame; when ``None`` no tracker is created.

    Returns:
        ``(tracker, origin_frame, [[], []], "")``; ``tracker`` is ``None``
        when ``origin_frame`` is ``None``.
    """
    if origin_frame is None:
        return None, origin_frame, [[], []], ""

    # AOT settings.
    aot_args["model"] = aot_model
    aot_args["model_path"] = aot_model2ckpt[aot_model]
    aot_args["long_term_mem_gap"] = long_term_mem
    aot_args["max_len_long_term"] = max_len_long_term

    # SegTracker and SAM settings.
    segtracker_args["sam_gap"] = sam_gap
    segtracker_args["max_obj_num"] = max_obj_num
    sam_args["generator_args"]["points_per_side"] = points_per_side

    tracker = SegTracker(segtracker_args, sam_args, aot_args)
    tracker.restart_tracker()

    return tracker, origin_frame, [[], []], ""
|
|
def init_SegTracker_Stroke(aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side, origin_frame):
    """Create a fresh SegTracker for the Stroke tab.

    Same as ``init_SegTracker`` except for the last element of the result,
    which is ``origin_frame`` (wired to the drawing board) instead of an
    empty string. The tracker construction is delegated to
    ``init_SegTracker`` so the two entry points cannot drift apart.

    Args:
        aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num,
        points_per_side: Tracker settings, forwarded unchanged.
        origin_frame: Current frame; when ``None`` no tracker is created.

    Returns:
        ``(tracker, origin_frame, [[], []], origin_frame)``; ``tracker`` is
        ``None`` when ``origin_frame`` is ``None``.
    """
    if origin_frame is None:
        return None, origin_frame, [[], []], origin_frame

    Seg_Tracker, _, click_stack, _ = init_SegTracker(
        aot_model,
        long_term_mem,
        max_len_long_term,
        sam_gap,
        max_obj_num,
        points_per_side,
        origin_frame,
    )
    return Seg_Tracker, origin_frame, click_stack, origin_frame
|
|
def undo_click_stack_and_refine_seg(Seg_Tracker, origin_frame, click_stack, aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side):
    """Drop the most recent click and re-run click segmentation.

    Args:
        Seg_Tracker: Tracker object, or ``None`` when none exists yet.
        origin_frame: Frame the clicks refer to.
        click_stack: ``[[coordinate], [point_mode]]``; its two slots are
            replaced by copies without the last element.
        aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num,
        points_per_side: Accepted for the event wiring; not used here.

    Returns:
        ``(Seg_Tracker, frame, click_stack)``. When clicks remain, ``frame``
        is the result of ``seg_acc_click``; otherwise (or when there is no
        tracker) it is ``origin_frame`` with a fresh empty click stack.
    """
    if Seg_Tracker is None:
        return Seg_Tracker, origin_frame, [[], []]

    print("Undo!")
    if len(click_stack[0]) > 0:
        # Remove the last coordinate and its matching mode.
        for slot in (0, 1):
            click_stack[slot] = click_stack[slot][:-1]

    # Nothing left to segment with: show the untouched frame again.
    if len(click_stack[0]) == 0:
        return Seg_Tracker, origin_frame, [[], []]

    prompt = dict(
        points_coord=click_stack[0],
        points_mode=click_stack[1],
        multimask="True",
    )
    masked_frame = seg_acc_click(Seg_Tracker, prompt, origin_frame)
    return Seg_Tracker, masked_frame, click_stack
|
|
def roll_back_undo_click_stack_and_refine_seg(Seg_Tracker, origin_frame, click_stack, aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side,input_video, input_img_seq, frame_num, refine_idx):
    """Drop the most recent refine click and redo the refinement.

    Args:
        Seg_Tracker: Tracker object, or ``None`` when none exists yet.
        origin_frame: Frame being refined.
        click_stack: ``[[coordinate], [point_mode]]``; its two slots are
            replaced by copies without the last element.
        aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num,
        points_per_side: Accepted for the event wiring; not used here.
        input_video, input_img_seq, frame_num: Forwarded to ``res_by_num``
            to load the stored mask of the chosen frame.
        refine_idx: Object id being refined; written to
            ``Seg_Tracker.curr_idx`` and used as the label in the mask.

    Returns:
        ``(Seg_Tracker, frame, click_stack)``. When clicks remain, ``frame``
        is the masked frame from ``Seg_Tracker.seg_acc_click``; otherwise
        (or when there is no tracker) it is ``origin_frame`` with a fresh
        empty click stack.
    """
    if Seg_Tracker is None:
        return Seg_Tracker, origin_frame, [[], []]

    print("Undo!")
    if len(click_stack[0]) > 0:
        # Remove the last coordinate and its matching mode.
        for slot in (0, 1):
            click_stack[slot] = click_stack[slot][:-1]

    # Nothing left to segment with: show the untouched frame again.
    if len(click_stack[0]) == 0:
        return Seg_Tracker, origin_frame, [[], []]

    prompt = dict(
        points_coord=click_stack[0],
        points_mode=click_stack[1],
        multimask="True",
    )

    _, curr_mask, _ = res_by_num(input_video, input_img_seq, frame_num)
    Seg_Tracker.curr_idx = refine_idx
    predicted_mask, masked_frame = Seg_Tracker.seg_acc_click(
        origin_frame=origin_frame,
        coords=np.array(prompt["points_coord"]),
        modes=np.array(prompt["points_mode"]),
        multimask=prompt["multimask"],
    )

    # Replace the refined object's old region in the stored mask with the
    # newly predicted one, leaving the other labels untouched.
    curr_mask[curr_mask == refine_idx] = 0
    curr_mask[predicted_mask != 0] = refine_idx

    Seg_Tracker = SegTracker_add_first_frame(Seg_Tracker, origin_frame, curr_mask)
    return Seg_Tracker, masked_frame, click_stack
|
|
|
|
def seg_acc_click(Seg_Tracker, prompt, origin_frame):
    """Segment from a click prompt and store the result as frame 0's mask.

    Args:
        Seg_Tracker: Tracker object whose ``seg_acc_click`` method is called.
        prompt: Dict with ``"points_coord"``, ``"points_mode"`` and
            ``"multimask"`` entries (as built by ``get_click_prompt``).
        origin_frame: Frame the clicks refer to.

    Returns:
        The masked frame returned by ``Seg_Tracker.seg_acc_click``.
    """
    click_coords = np.array(prompt["points_coord"])
    click_modes = np.array(prompt["points_mode"])

    predicted_mask, masked_frame = Seg_Tracker.seg_acc_click(
        origin_frame=origin_frame,
        coords=click_coords,
        modes=click_modes,
        multimask=prompt["multimask"],
    )

    # Registers the new mask on the tracker passed in by the caller.
    SegTracker_add_first_frame(Seg_Tracker, origin_frame, predicted_mask)

    return masked_frame
|
|
|
|
def sam_click(Seg_Tracker, origin_frame, point_mode, click_stack, aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side, evt:gr.SelectData):
    """Handle a click on the first-frame image and segment from all clicks.

    Args:
        Seg_Tracker: Tracker object, or ``None`` to create one on demand.
        origin_frame: nd.array
        point_mode: ``"Positive"`` maps to mode 1; anything else to mode 0.
        click_stack: [[coordinate], [point_mode]]
        aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num,
        points_per_side: Settings used only when a tracker must be created.
        evt: Select event; ``evt.index[0]`` and ``evt.index[1]`` give the
            clicked coordinate.

    Returns:
        ``(Seg_Tracker, masked_frame, click_stack)``.
    """
    print("Click")

    mode = 1 if point_mode == "Positive" else 0
    point = {"coord": [evt.index[0], evt.index[1]], "mode": mode}

    if Seg_Tracker is None:
        Seg_Tracker, _, _, _ = init_SegTracker(
            aot_model, long_term_mem, max_len_long_term,
            sam_gap, max_obj_num, points_per_side, origin_frame,
        )

    # Record the click, then segment using every click gathered so far.
    click_prompt = get_click_prompt(click_stack, point)
    masked_frame = seg_acc_click(Seg_Tracker, click_prompt, origin_frame)

    return Seg_Tracker, masked_frame, click_stack
|
|
|
|
def roll_back_sam_click(Seg_Tracker, origin_frame, point_mode, click_stack, aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side, input_video, input_img_seq, frame_num, refine_idx, evt:gr.SelectData):
    """Handle a click on the refine image and update the chosen object's mask.

    Args:
        Seg_Tracker: Tracker object, or ``None`` to create one on demand.
        origin_frame: nd.array
        point_mode: ``"Positive"`` maps to mode 1; anything else to mode 0.
        click_stack: [[coordinate], [point_mode]]
        aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num,
        points_per_side: Settings used only when a tracker must be created.
        input_video, input_img_seq, frame_num: Forwarded to ``res_by_num``
            to load the stored mask of the chosen frame.
        refine_idx: Object id being refined; written to
            ``Seg_Tracker.curr_idx`` and used as the label in the mask.
        evt: Select event; ``evt.index[0]`` and ``evt.index[1]`` give the
            clicked coordinate.

    Returns:
        ``(Seg_Tracker, masked_frame, click_stack)``.
    """
    print("Click")

    mode = 1 if point_mode == "Positive" else 0
    point = {"coord": [evt.index[0], evt.index[1]], "mode": mode}

    if Seg_Tracker is None:
        Seg_Tracker, _, _, _ = init_SegTracker(
            aot_model, long_term_mem, max_len_long_term,
            sam_gap, max_obj_num, points_per_side, origin_frame,
        )

    prompt = get_click_prompt(click_stack, point)

    _, curr_mask, _ = res_by_num(input_video, input_img_seq, frame_num)

    Seg_Tracker.curr_idx = refine_idx

    predicted_mask, masked_frame = Seg_Tracker.seg_acc_click(
        origin_frame=origin_frame,
        coords=np.array(prompt["points_coord"]),
        modes=np.array(prompt["points_mode"]),
        multimask=prompt["multimask"],
    )

    # Replace the refined object's old region in the stored mask with the
    # newly predicted one, leaving the other labels untouched.
    curr_mask[curr_mask == refine_idx] = 0
    curr_mask[predicted_mask != 0] = refine_idx

    Seg_Tracker = SegTracker_add_first_frame(Seg_Tracker, origin_frame, curr_mask)

    return Seg_Tracker, masked_frame, click_stack
|
|
def sam_stroke(Seg_Tracker, origin_frame, drawing_board, aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side):
    """Segment the region covered by a stroke drawn on the drawing board.

    Args:
        Seg_Tracker: Tracker object, or ``None`` to create one on demand.
        origin_frame: Frame to segment.
        drawing_board: Mapping with a ``"mask"`` entry; channel 0 of that
            mask is converted to a bounding box with ``mask2bbox``.
        aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num,
        points_per_side: Settings used only when a tracker must be created.

    Returns:
        ``(Seg_Tracker, masked_frame, origin_frame)``.
    """
    if Seg_Tracker is None:
        Seg_Tracker, _, _, _ = init_SegTracker(
            aot_model, long_term_mem, max_len_long_term,
            sam_gap, max_obj_num, points_per_side, origin_frame,
        )

    print("Stroke")
    stroke_mask = drawing_board["mask"]
    stroke_bbox = mask2bbox(stroke_mask[:, :, 0])
    predicted_mask, masked_frame = Seg_Tracker.seg_acc_bbox(origin_frame, stroke_bbox)

    Seg_Tracker = SegTracker_add_first_frame(Seg_Tracker, origin_frame, predicted_mask)

    return Seg_Tracker, masked_frame, origin_frame
|
|
def gd_detect(Seg_Tracker, origin_frame, grounding_caption, box_threshold, text_threshold, aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side):
    """Detect and segment objects described by a text prompt.

    Args:
        Seg_Tracker: Tracker object, or ``None`` to create one on demand.
        origin_frame: Frame to run detection on.
        grounding_caption: Text prompt passed to ``detect_and_seg``.
        box_threshold: Passed through to ``detect_and_seg``.
        text_threshold: Passed through to ``detect_and_seg``.
        aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num,
        points_per_side: Settings used only when a tracker must be created.

    Returns:
        ``(Seg_Tracker, masked_frame, origin_frame)`` where ``masked_frame``
        is the annotated frame with the predicted mask drawn on it.
    """
    if Seg_Tracker is None:
        Seg_Tracker, _, _, _ = init_SegTracker(
            aot_model, long_term_mem, max_len_long_term,
            sam_gap, max_obj_num, points_per_side, origin_frame,
        )

    print("Detect")
    detection = Seg_Tracker.detect_and_seg(
        origin_frame, grounding_caption, box_threshold, text_threshold
    )
    predicted_mask, annotated_frame = detection

    Seg_Tracker = SegTracker_add_first_frame(Seg_Tracker, origin_frame, predicted_mask)

    masked_frame = draw_mask(annotated_frame, predicted_mask)

    return Seg_Tracker, masked_frame, origin_frame
|
|
def segment_everything(Seg_Tracker, aot_model, long_term_mem, max_len_long_term, origin_frame, sam_gap, max_obj_num, points_per_side):
    """Segment the whole first frame and register the mask with the tracker.

    Args:
        Seg_Tracker: Tracker object, or ``None`` to create one on demand.
        aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num,
        points_per_side: Settings used only when a tracker must be created.
        origin_frame: Frame to segment; a copy is used for drawing.

    Returns:
        ``(Seg_Tracker, masked_frame)`` where ``masked_frame`` is a copy of
        ``origin_frame`` with the predicted mask drawn on it.
    """
    if Seg_Tracker is None:
        Seg_Tracker, _, _, _ = init_SegTracker(
            aot_model, long_term_mem, max_len_long_term,
            sam_gap, max_obj_num, points_per_side, origin_frame,
        )

    print("Everything")

    first_frame_index = 0
    with torch.cuda.amp.autocast():
        pred_mask = Seg_Tracker.seg(origin_frame)
        # Free cached GPU memory and collect garbage before adding the
        # reference.
        torch.cuda.empty_cache()
        gc.collect()
        Seg_Tracker.add_reference(origin_frame, pred_mask, first_frame_index)
        Seg_Tracker.first_frame_mask = pred_mask

    frame_copy = origin_frame.copy()
    masked_frame = draw_mask(frame_copy, pred_mask)

    return Seg_Tracker, masked_frame
|
|
def add_new_object(Seg_Tracker):
    """Commit the current first-frame mask and move on to the next object id.

    Args:
        Seg_Tracker: Tracker object, or ``None`` when no tracker has been
            created yet.

    Returns:
        ``(Seg_Tracker, [[], []])`` — the tracker and a fresh empty click
        stack. When ``Seg_Tracker`` is ``None`` nothing is modified.
    """
    # The tracker is None until a frame has been loaded and a tracker built;
    # without this guard the attribute access below raises AttributeError.
    if Seg_Tracker is None:
        return Seg_Tracker, [[], []]

    prev_mask = Seg_Tracker.first_frame_mask
    Seg_Tracker.update_origin_merged_mask(prev_mask)
    Seg_Tracker.curr_idx += 1

    print("Ready to add new object!")

    return Seg_Tracker, [[], []]
|
|
def tracking_objects(Seg_Tracker, input_video, input_img_seq, fps, frame_num=0):
    """Log the start of tracking and delegate to ``tracking_objects_in_video``.

    Args:
        Seg_Tracker: Tracker object, forwarded unchanged.
        input_video: Video input, forwarded unchanged.
        input_img_seq: Image-sequence input, forwarded unchanged.
        fps: Frame rate setting, forwarded unchanged.
        frame_num: Frame index, forwarded unchanged. Defaults to 0.

    Returns:
        Whatever ``tracking_objects_in_video`` returns.
    """
    print("Start tracking !")
    tracking_outputs = tracking_objects_in_video(
        Seg_Tracker, input_video, input_img_seq, fps, frame_num
    )
    return tracking_outputs
|
|
|
|
def res_by_num(input_video, input_img_seq, frame_num):
    """Load the stored tracking result and the original frame for one index.

    Results are read from ``<this file's dir>/tracking_results/<name>/``,
    using the ``<name>_masked_frames`` and ``<name>_masks`` sub-directories;
    files are indexed in sorted path order.

    Args:
        input_video: Path of the input video, or ``None``.
        input_img_seq: Uploaded archive object (``.name`` is used), consulted
            only when ``input_video`` is ``None``.
        frame_num: Zero-based frame index.

    Returns:
        ``(chosen_frame_show, chosen_mask, ori_frame)``: the masked frame
        (RGB), the mask as an array read through PIL in ``'P'`` mode, and
        the original frame (RGB). ``(None, None, None)`` when there is no
        input, ``frame_num`` is ``None``/negative/out of range, the result
        directories do not exist or are empty, or a frame cannot be read.
    """
    no_result = (None, None, None)

    # frame_num comes from UI state that starts out as None.
    if frame_num is None or frame_num < 0:
        return no_result

    if input_video is not None:
        video_name = os.path.basename(input_video).split('.')[0]

        # Read sequentially up to the requested frame.
        cap = cv2.VideoCapture(input_video)
        ori_frame = None
        try:
            for _ in range(frame_num + 1):
                ok, ori_frame = cap.read()
                if not ok:
                    ori_frame = None
                    break
        finally:
            cap.release()
        if ori_frame is None:
            print("num out of frames range")
            return no_result
        ori_frame = cv2.cvtColor(ori_frame, cv2.COLOR_BGR2RGB)
    elif input_img_seq is not None:
        file_name = input_img_seq.name.split('/')[-1].split('.')[0]
        file_path = f'./assets/{file_name}'
        video_name = file_name

        if not os.path.isdir(file_path):
            return no_result
        imgs_path = sorted(
            os.path.join(file_path, img_name) for img_name in os.listdir(file_path)
        )
        if frame_num >= len(imgs_path):
            print("num out of frames range")
            return no_result
        ori_frame = cv2.imread(imgs_path[frame_num])
        if ori_frame is None:
            return no_result
        ori_frame = cv2.cvtColor(ori_frame, cv2.COLOR_BGR2RGB)
    else:
        return no_result

    tracking_result_dir = os.path.join(
        os.path.dirname(__file__), "tracking_results", f"{video_name}"
    )
    output_masked_frame_dir = f'{tracking_result_dir}/{video_name}_masked_frames'
    output_mask_dir = f'{tracking_result_dir}/{video_name}_masks'

    # The result directories only exist once tracking has been run.
    if not (os.path.isdir(output_masked_frame_dir) and os.path.isdir(output_mask_dir)):
        return no_result

    output_masked_frame_path = sorted(
        os.path.join(output_masked_frame_dir, img_name)
        for img_name in os.listdir(output_masked_frame_dir)
    )
    output_mask_path = sorted(
        os.path.join(output_mask_dir, img_name)
        for img_name in os.listdir(output_mask_dir)
    )

    if len(output_masked_frame_path) == 0:
        return no_result
    if frame_num >= len(output_masked_frame_path) or frame_num >= len(output_mask_path):
        print("num out of frames range")
        return no_result

    print("choose", frame_num, "to refine")
    chosen_frame_show = cv2.imread(output_masked_frame_path[frame_num])
    if chosen_frame_show is None:
        return no_result
    chosen_frame_show = cv2.cvtColor(chosen_frame_show, cv2.COLOR_BGR2RGB)

    # Read the mask through PIL in palette mode.
    with Image.open(output_mask_path[frame_num]) as mask_image:
        chosen_mask = np.array(mask_image.convert('P'))

    return chosen_frame_show, chosen_mask, ori_frame
|
|
def show_res_by_slider(input_video, input_img_seq, frame_per):
    """Map a slider percentage to a frame index and load that masked frame.

    Args:
        input_video: Path of the input video, or ``None``.
        input_img_seq: Uploaded archive object (``.name`` is used), consulted
            only when ``input_video`` is ``None``.
        frame_per: Slider position as a percentage (0-100) of the frames.

    Returns:
        ``(chosen_frame_show, frame_num)``; ``(None, None)`` when there is no
        input or no masked-frame results are found on disk.
    """
    if input_video is not None:
        video_name = os.path.basename(input_video).split('.')[0]
    elif input_img_seq is not None:
        video_name = input_img_seq.name.split('/')[-1].split('.')[0]
    else:
        print("Not find output res")
        return None, None

    tracking_result_dir = os.path.join(
        os.path.dirname(__file__), "tracking_results", f"{video_name}"
    )
    output_masked_frame_dir = f'{tracking_result_dir}/{video_name}_masked_frames'

    # The directory only exists once tracking has been run; listing it
    # unconditionally raises FileNotFoundError.
    if not os.path.isdir(output_masked_frame_dir):
        print("Not find output res")
        return None, None

    total_frames_num = len(os.listdir(output_masked_frame_dir))
    if total_frames_num == 0:
        print("Not find output res")
        return None, None

    # 100% would index one past the end, so clamp to the last frame.
    frame_num = min(
        math.floor(total_frames_num * frame_per / 100), total_frames_num - 1
    )
    chosen_frame_show, _, _ = res_by_num(input_video, input_img_seq, frame_num)
    return chosen_frame_show, frame_num
|
|
def choose_obj_to_refine(input_video, input_img_seq, Seg_Tracker, frame_num, evt:gr.SelectData):
    """Pick the object under the clicked pixel and highlight it.

    Args:
        input_video, input_img_seq, frame_num: Forwarded to ``res_by_num``
            to load the masked frame and mask of the chosen frame.
        Seg_Tracker: Accepted for the event wiring; not used here.
        evt: Select event; ``evt.index[0]`` and ``evt.index[1]`` give the
            clicked coordinate. The mask is indexed as
            ``[evt.index[1], evt.index[0]]``.

    Returns:
        ``(chosen_frame_show, idx)``: the masked frame with the click point
        and the selected object's outline drawn on it, and the mask value
        under the click. ``idx`` is ``None`` (and the frame is returned
        unmodified, possibly ``None``) when no result could be loaded.
    """
    chosen_frame_show, curr_mask, _ = res_by_num(input_video, input_img_seq, frame_num)

    # Bound up front so the return below is valid when nothing was loaded.
    idx = None

    if curr_mask is not None and chosen_frame_show is not None:
        click_x, click_y = evt.index[0], evt.index[1]
        idx = curr_mask[click_y, click_x]
        curr_idx_mask = np.where(curr_mask == idx, 1, 0).astype(np.uint8)
        chosen_frame_show = draw_points(
            points=np.array([[click_x, click_y]]),
            modes=np.array([[1]]),
            frame=chosen_frame_show,
        )
        chosen_frame_show = draw_outline(mask=curr_idx_mask, frame=chosen_frame_show)
        print(idx)

    return chosen_frame_show, idx
|
|
def show_chosen_idx_to_refine(aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num, points_per_side, input_video, input_img_seq, Seg_Tracker, frame_num, idx):
    """Load the chosen frame and reset the tracker state for refinement.

    Args:
        aot_model, long_term_mem, max_len_long_term, sam_gap, max_obj_num,
        points_per_side: Settings used only when a tracker must be created.
        input_video, input_img_seq, frame_num: Forwarded to ``res_by_num``
            to load the original frame of the chosen index.
        Seg_Tracker: Tracker object, or ``None`` to create one on demand.
        idx: Accepted for the event wiring; not used here.

    Returns:
        ``(ori_frame, Seg_Tracker, ori_frame, [[], []], "")``.
    """
    _, _, ori_frame = res_by_num(input_video, input_img_seq, frame_num)

    if Seg_Tracker is None:
        print("reset aot args, new SegTracker")
        Seg_Tracker, _, _, _ = init_SegTracker(
            aot_model, long_term_mem, max_len_long_term,
            sam_gap, max_obj_num, points_per_side, ori_frame,
        )

    # NOTE(review): the reset below is applied to new and existing trackers
    # alike; the nesting was reconstructed from un-indented source - confirm.
    Seg_Tracker.restart_tracker()
    Seg_Tracker.curr_idx = 1
    Seg_Tracker.object_idx = 1
    Seg_Tracker.origin_merged_mask = None
    Seg_Tracker.first_frame_mask = None
    Seg_Tracker.reference_objs_list = []
    Seg_Tracker.everything_points = []
    Seg_Tracker.everything_labels = []
    Seg_Tracker.sam.have_embedded = False
    Seg_Tracker.sam.interactive_predictor.features = None

    return ori_frame, Seg_Tracker, ori_frame, [[], []], ""
| |
|
|
|
|
|
|
def seg_track_app():
    """Build the Gradio Blocks UI, wire its event handlers and launch it.

    The first part of the ``with app:`` block declares the components; the
    second part registers the callbacks. The app is then queued with
    ``concurrency_count=1`` and launched with ``debug=True``,
    ``enable_queue=True`` and ``share=True``.

    NOTE(review): the layout nesting was reconstructed from source that had
    lost its indentation - confirm against the intended layout.
    """
    app = gr.Blocks()

    with app:
        # ------------------------------------------------------------------
        # Layout
        # ------------------------------------------------------------------
        gr.Markdown(
            '''
            <div style="text-align:center;">
            <span style="font-size:3em; font-weight:bold;">Segment and Track Anything(SAM-Track)</span>
            </div>
            '''
        )

        # Per-session state.
        click_stack = gr.State([[],[]])
        origin_frame = gr.State(None)
        Seg_Tracker = gr.State(None)

        current_frame_num = gr.State(None)
        refine_idx = gr.State(None)
        frame_num = gr.State(None)

        # These four names are rebound to the sliders/dropdown created
        # further down; the handlers below therefore read the widgets.
        aot_model = gr.State(None)
        sam_gap = gr.State(None)
        points_per_side = gr.State(None)
        max_obj_num = gr.State(None)

        with gr.Row():
            # Left column: inputs, prompt tabs and tracker settings.
            with gr.Column(scale=0.5):

                tab_video_input = gr.Tab(label="Video type input")
                with tab_video_input:
                    input_video = gr.Video(label='Input video').style(height=550)

                tab_img_seq_input = gr.Tab(label="Image-Seq type input")
                with tab_img_seq_input:
                    with gr.Row():
                        input_img_seq = gr.File(label='Input Image-Seq').style(height=550)
                        with gr.Column(scale=0.25):
                            extract_button = gr.Button(value="extract")
                            fps = gr.Slider(label='fps', minimum=5, maximum=50, value=8, step=1)

                input_first_frame = gr.Image(
                    label='Segment result of first frame',
                    interactive=True,
                ).style(height=550)

                tab_everything = gr.Tab(label="Everything")
                with tab_everything:
                    with gr.Row():
                        seg_every_first_frame = gr.Button(
                            value="Segment everything for first frame",
                            interactive=True,
                        )
                        # Rebound by the Click tab's radio below, so the
                        # click handler reads that one.
                        point_mode = gr.Radio(
                            choices=["Positive"],
                            value="Positive",
                            label="Point Prompt",
                            interactive=True,
                        )
                        every_undo_but = gr.Button(value="Undo", interactive=True)

                tab_click = gr.Tab(label="Click")
                with tab_click:
                    with gr.Row():
                        point_mode = gr.Radio(
                            choices=["Positive", "Negative"],
                            value="Positive",
                            label="Point Prompt",
                            interactive=True,
                        )
                        click_undo_but = gr.Button(value="Undo", interactive=True)

                tab_stroke = gr.Tab(label="Stroke")
                with tab_stroke:
                    drawing_board = gr.Image(
                        label='Drawing Board',
                        tool="sketch",
                        brush_radius=10,
                        interactive=True,
                    )
                    with gr.Row():
                        seg_acc_stroke = gr.Button(value="Segment", interactive=True)

                tab_text = gr.Tab(label="Text")
                with tab_text:
                    grounding_caption = gr.Textbox(label="Detection Prompt")
                    detect_button = gr.Button(value="Detect")
                    with gr.Accordion("Advanced options", open=False):
                        with gr.Row():
                            with gr.Column(scale=0.5):
                                box_threshold = gr.Slider(
                                    label="Box Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
                                )
                            with gr.Column(scale=0.5):
                                text_threshold = gr.Slider(
                                    label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
                                )

                with gr.Row():
                    with gr.Column(scale=0.5):
                        with gr.Tab(label="SegTracker Args"):
                            points_per_side = gr.Slider(
                                label="points_per_side",
                                minimum=1,
                                step=1,
                                maximum=100,
                                value=16,
                                interactive=True,
                            )
                            sam_gap = gr.Slider(
                                label='sam_gap',
                                minimum=1,
                                step=1,
                                maximum=9999,
                                value=100,
                                interactive=True,
                            )
                            max_obj_num = gr.Slider(
                                label='max_obj_num',
                                minimum=50,
                                step=1,
                                maximum=300,
                                value=255,
                                interactive=True,
                            )
                            with gr.Accordion("aot advanced options", open=False):
                                aot_model = gr.Dropdown(
                                    label="aot_model",
                                    choices=["deaotb", "deaotl", "r50_deaotl"],
                                    value="r50_deaotl",
                                    interactive=True,
                                )
                                long_term_mem = gr.Slider(label="long term memory gap", minimum=1, maximum=9999, value=9999, step=1)
                                max_len_long_term = gr.Slider(label="max len of long term memory", minimum=1, maximum=9999, value=9999, step=1)

                    with gr.Column():
                        new_object_button = gr.Button(value="Add new object", interactive=True)
                        reset_button = gr.Button(value="Reset", interactive=True)
                        track_for_video = gr.Button(value="Start Tracking", interactive=True)

            # Right column: outputs and roll-back/refine tools.
            with gr.Column(scale=0.5):
                output_video = gr.File(label="Predicted video")
                output_mask = gr.File(label="Predicted masks")
                with gr.Row():
                    with gr.Column(scale=1):
                        with gr.Accordion("roll back options", open=False):
                            output_res = gr.Image(label='Segment result of all frames').style(height=550)
                            frame_per = gr.Slider(
                                label="Percentage of Frames Viewed",
                                minimum=0.0,
                                maximum=100.0,
                                step=0.01,
                                value=0.0,
                            )
                            frame_per.release(
                                show_res_by_slider,
                                inputs=[input_video, input_img_seq, frame_per],
                                outputs=[output_res, frame_num],
                            )
                            roll_back_button = gr.Button(value="Choose this mask to refine")
                            refine_res = gr.Image(label='Refine masks').style(height=550)

                            tab_roll_back_click = gr.Tab(label="Click")
                            with tab_roll_back_click:
                                with gr.Row():
                                    roll_back_point_mode = gr.Radio(
                                        choices=["Positive", "Negative"],
                                        value="Positive",
                                        label="Point Prompt",
                                        interactive=True,
                                    )
                                    roll_back_click_undo_but = gr.Button(value="Undo", interactive=True)
                                    roll_back_track_for_video = gr.Button(
                                        value="Start tracking to refine",
                                        interactive=True,
                                    )

        # ------------------------------------------------------------------
        # Event wiring
        # ------------------------------------------------------------------
        # Component groups shared by several handlers. Each registration
        # builds its own list from these, so no list object is shared.
        tracker_settings = [
            aot_model,
            long_term_mem,
            max_len_long_term,
            sam_gap,
            max_obj_num,
            points_per_side,
        ]
        first_frame_outputs = [input_first_frame, origin_frame, drawing_board, grounding_caption]
        cleared_outputs = [
            input_video,
            input_img_seq,
            Seg_Tracker,
            input_first_frame,
            origin_frame,
            drawing_board,
            click_stack,
        ]
        init_outputs = [Seg_Tracker, input_first_frame, click_stack, grounding_caption]
        click_state = [Seg_Tracker, origin_frame, click_stack]
        roll_back_source = [input_video, input_img_seq, frame_num, refine_idx]

        # Loading an input fetches its first frame.
        input_video.change(
            fn=get_meta_from_video,
            inputs=[input_video],
            outputs=list(first_frame_outputs),
        )
        input_img_seq.change(
            fn=get_meta_from_img_seq,
            inputs=[input_img_seq],
            outputs=list(first_frame_outputs),
        )

        # Switching the input type clears inputs and state.
        tab_video_input.select(fn=clean, inputs=[], outputs=list(cleared_outputs))
        tab_img_seq_input.select(fn=clean, inputs=[], outputs=list(cleared_outputs))

        extract_button.click(
            fn=get_meta_from_img_seq,
            inputs=[input_img_seq],
            outputs=list(first_frame_outputs),
        )

        # Selecting a prompt tab re-initialises the tracker.
        tab_everything.select(
            fn=init_SegTracker,
            inputs=tracker_settings + [origin_frame],
            outputs=list(init_outputs),
            queue=False,
        )
        tab_click.select(
            fn=init_SegTracker,
            inputs=tracker_settings + [origin_frame],
            outputs=list(init_outputs),
            queue=False,
        )
        tab_stroke.select(
            fn=init_SegTracker_Stroke,
            inputs=tracker_settings + [origin_frame],
            outputs=[Seg_Tracker, input_first_frame, click_stack, drawing_board],
            queue=False,
        )
        tab_text.select(
            fn=init_SegTracker,
            inputs=tracker_settings + [origin_frame],
            outputs=list(init_outputs),
            queue=False,
        )

        # First-frame segmentation by everything / click / stroke / text.
        seg_every_first_frame.click(
            fn=segment_everything,
            inputs=[
                Seg_Tracker,
                aot_model,
                long_term_mem,
                max_len_long_term,
                origin_frame,
                sam_gap,
                max_obj_num,
                points_per_side,
            ],
            outputs=[Seg_Tracker, input_first_frame],
        )
        input_first_frame.select(
            fn=sam_click,
            inputs=[Seg_Tracker, origin_frame, point_mode, click_stack] + tracker_settings,
            outputs=[Seg_Tracker, input_first_frame, click_stack],
        )
        seg_acc_stroke.click(
            fn=sam_stroke,
            inputs=[Seg_Tracker, origin_frame, drawing_board] + tracker_settings,
            outputs=[Seg_Tracker, input_first_frame, drawing_board],
        )
        detect_button.click(
            fn=gd_detect,
            inputs=[Seg_Tracker, origin_frame, grounding_caption, box_threshold, text_threshold] + tracker_settings,
            outputs=[Seg_Tracker, input_first_frame],
        )

        new_object_button.click(
            fn=add_new_object,
            inputs=[Seg_Tracker],
            outputs=[Seg_Tracker, click_stack],
        )

        track_for_video.click(
            fn=tracking_objects,
            inputs=[Seg_Tracker, input_video, input_img_seq, fps],
            outputs=[output_video, output_mask],
        )

        # Roll-back / refine workflow.
        output_res.select(
            fn=choose_obj_to_refine,
            inputs=[input_video, input_img_seq, Seg_Tracker, frame_num],
            outputs=[output_res, refine_idx],
        )
        roll_back_button.click(
            fn=show_chosen_idx_to_refine,
            inputs=tracker_settings + [input_video, input_img_seq, Seg_Tracker, frame_num, refine_idx],
            outputs=[refine_res, Seg_Tracker, origin_frame, click_stack, grounding_caption],
            queue=False,
            show_progress=False,
        )
        roll_back_click_undo_but.click(
            fn=roll_back_undo_click_stack_and_refine_seg,
            inputs=click_state + tracker_settings + roll_back_source,
            outputs=[Seg_Tracker, refine_res, click_stack],
        )
        refine_res.select(
            fn=roll_back_sam_click,
            inputs=[Seg_Tracker, origin_frame, roll_back_point_mode, click_stack] + tracker_settings + roll_back_source,
            outputs=[Seg_Tracker, refine_res, click_stack],
        )
        roll_back_track_for_video.click(
            fn=tracking_objects,
            inputs=[Seg_Tracker, input_video, input_img_seq, fps, frame_num],
            outputs=[output_video, output_mask],
        )

        # Reset and undo.
        reset_button.click(
            fn=init_SegTracker,
            inputs=tracker_settings + [origin_frame],
            outputs=list(init_outputs),
            queue=False,
            show_progress=False,
        )
        click_undo_but.click(
            fn=undo_click_stack_and_refine_seg,
            inputs=click_state + tracker_settings,
            outputs=[Seg_Tracker, input_first_frame, click_stack],
        )
        every_undo_but.click(
            fn=undo_click_stack_and_refine_seg,
            inputs=click_state + tracker_settings,
            outputs=[Seg_Tracker, input_first_frame, click_stack],
        )

        # Bundled examples.
        with gr.Tab(label='Video example'):
            gr.Examples(
                examples=[
                    os.path.join(os.path.dirname(__file__), "assets", "blackswan.mp4"),
                ],
                inputs=[input_video],
            )

        with gr.Tab(label='Image-seq expamle'):
            gr.Examples(
                examples=[
                    os.path.join(os.path.dirname(__file__), "assets", "840_iSXIa0hE8Ek.zip"),
                ],
                inputs=[input_img_seq],
            )

    app.queue(concurrency_count=1)
    app.launch(debug=True, enable_queue=True, share=True)
|
|
|
|
# Start the app only when this file is run as a script, not on import.
if __name__ == "__main__":
    seg_track_app()
|
|