Commit Β·
77b7c2e
1
Parent(s): d6e7745
integrate MatAnyone 1 & 2
Browse files- README.md +1 -1
- hugging_face/app.py +159 -49
- hugging_face/{matanyone_wrapper.py β matanyone2_wrapper.py} +10 -6
- matanyone/utils/__init__.py +0 -0
- matanyone2/__init__.py +2 -0
- {matanyone β matanyone2/config}/__init__.py +0 -0
- {matanyone β matanyone2}/config/eval_matanyone_config.yaml +1 -1
- {matanyone β matanyone2}/config/hydra/job_logging/custom-no-rank.yaml +0 -0
- {matanyone β matanyone2}/config/hydra/job_logging/custom.yaml +0 -0
- {matanyone β matanyone2}/config/model/base.yaml +0 -0
- {matanyone/config β matanyone2/inference}/__init__.py +0 -0
- {matanyone β matanyone2}/inference/image_feature_store.py +2 -2
- {matanyone β matanyone2}/inference/inference_core.py +155 -11
- {matanyone β matanyone2}/inference/kv_memory_store.py +0 -0
- {matanyone β matanyone2}/inference/memory_manager.py +6 -6
- {matanyone β matanyone2}/inference/object_info.py +0 -0
- {matanyone β matanyone2}/inference/object_manager.py +1 -1
- {matanyone/inference β matanyone2/inference/utils}/__init__.py +0 -0
- {matanyone β matanyone2}/inference/utils/args_utils.py +0 -0
- {matanyone/inference/utils β matanyone2/model}/__init__.py +0 -0
- {matanyone β matanyone2}/model/aux_modules.py +2 -2
- {matanyone β matanyone2}/model/big_modules.py +5 -4
- {matanyone β matanyone2}/model/channel_attn.py +0 -0
- {matanyone β matanyone2}/model/group_modules.py +1 -1
- matanyone/model/matanyone.py β matanyone2/model/matanyone2.py +22 -17
- {matanyone β matanyone2}/model/modules.py +6 -5
- {matanyone/model β matanyone2/model/transformer}/__init__.py +0 -0
- {matanyone β matanyone2}/model/transformer/object_summarizer.py +4 -3
- {matanyone β matanyone2}/model/transformer/object_transformer.py +4 -4
- {matanyone β matanyone2}/model/transformer/positional_encoding.py +4 -2
- {matanyone β matanyone2}/model/transformer/transformer_layers.py +1 -1
- {matanyone/model/transformer β matanyone2/model/utils}/__init__.py +0 -0
- {matanyone β matanyone2}/model/utils/memory_utils.py +0 -0
- {matanyone β matanyone2}/model/utils/parameter_groups.py +0 -0
- {matanyone β matanyone2}/model/utils/resnet.py +0 -0
- {matanyone/model β matanyone2}/utils/__init__.py +0 -0
- matanyone2/utils/device.py +33 -0
- {matanyone β matanyone2}/utils/get_default_model.py +6 -6
- matanyone2/utils/inference_utils.py +54 -0
- {matanyone β matanyone2}/utils/tensor_utils.py +3 -3
README.md
CHANGED
|
@@ -8,7 +8,7 @@ sdk_version: 5.16.0
|
|
| 8 |
app_file: hugging_face/app.py
|
| 9 |
pinned: false
|
| 10 |
license: other
|
| 11 |
-
short_description: Gradio demo for MatAnyone
|
| 12 |
---
|
| 13 |
|
| 14 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 8 |
app_file: hugging_face/app.py
|
| 9 |
pinned: false
|
| 10 |
license: other
|
| 11 |
+
short_description: Gradio demo for MatAnyone 1 & 2
|
| 12 |
---
|
| 13 |
|
| 14 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
hugging_face/app.py
CHANGED
|
@@ -21,9 +21,13 @@ from tools.interact_tools import SamControler
|
|
| 21 |
from tools.misc import get_device
|
| 22 |
from tools.download_util import load_file_from_url
|
| 23 |
|
| 24 |
-
from
|
| 25 |
-
from
|
| 26 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
def parse_augment():
|
| 29 |
parser = argparse.ArgumentParser()
|
|
@@ -121,7 +125,6 @@ def get_frames_from_video(video_input, video_state):
|
|
| 121 |
except Exception as e:
|
| 122 |
print(f"Audio extraction error: {str(e)}")
|
| 123 |
audio_path = "" # Set to "" if extraction fails
|
| 124 |
-
# print(f'audio_path: {audio_path}')
|
| 125 |
|
| 126 |
# extract frames
|
| 127 |
try:
|
|
@@ -140,8 +143,8 @@ def get_frames_from_video(video_input, video_state):
|
|
| 140 |
print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
|
| 141 |
image_size = (frames[0].shape[0],frames[0].shape[1])
|
| 142 |
|
| 143 |
-
# resize if resolution too big
|
| 144 |
-
if image_size[0]>=
|
| 145 |
scale = 1080 / min(image_size)
|
| 146 |
new_w = int(image_size[1] * scale)
|
| 147 |
new_h = int(image_size[0] * scale)
|
|
@@ -165,8 +168,7 @@ def get_frames_from_video(video_input, video_state):
|
|
| 165 |
video_info = "Video Name: {},\nFPS: {},\nTotal Frames: {},\nImage Size:{}".format(video_state["video_name"], round(video_state["fps"], 0), len(frames), image_size)
|
| 166 |
model.samcontroler.sam_controler.reset_image()
|
| 167 |
model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
|
| 168 |
-
return video_state, video_info, video_state["origin_images"][0], \
|
| 169 |
-
gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=False, maximum=len(frames), value=len(frames)), \
|
| 170 |
gr.update(visible=True), gr.update(visible=True), \
|
| 171 |
gr.update(visible=True), gr.update(visible=True),\
|
| 172 |
gr.update(visible=True), gr.update(visible=True), \
|
|
@@ -267,8 +269,18 @@ def show_mask(video_state, interactive_state, mask_dropdown):
|
|
| 267 |
return select_frame
|
| 268 |
|
| 269 |
# image matting
|
| 270 |
-
def image_matting(video_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, refine_iter):
|
| 271 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
if interactive_state["track_end_number"]:
|
| 273 |
following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
|
| 274 |
else:
|
|
@@ -289,14 +301,25 @@ def image_matting(video_state, interactive_state, mask_dropdown, erode_kernel_si
|
|
| 289 |
# operation error
|
| 290 |
if len(np.unique(template_mask))==1:
|
| 291 |
template_mask[0][0]=1
|
| 292 |
-
foreground, alpha =
|
| 293 |
foreground_output = Image.fromarray(foreground[-1])
|
| 294 |
alpha_output = Image.fromarray(alpha[-1][:,:,0])
|
|
|
|
| 295 |
return foreground_output, alpha_output
|
| 296 |
|
| 297 |
# video matting
|
| 298 |
-
def video_matting(video_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size):
|
| 299 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
if interactive_state["track_end_number"]:
|
| 301 |
following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
|
| 302 |
else:
|
|
@@ -320,11 +343,11 @@ def video_matting(video_state, interactive_state, mask_dropdown, erode_kernel_si
|
|
| 320 |
# operation error
|
| 321 |
if len(np.unique(template_mask))==1:
|
| 322 |
template_mask[0][0]=1
|
| 323 |
-
foreground, alpha =
|
| 324 |
|
| 325 |
foreground_output = generate_video_from_frames(foreground, output_path="./results/{}_fg.mp4".format(video_state["video_name"]), fps=fps, audio_path=audio_path) # import video_input to name the output video
|
| 326 |
alpha_output = generate_video_from_frames(alpha, output_path="./results/{}_alpha.mp4".format(video_state["video_name"]), fps=fps, gray2rgb=True, audio_path=audio_path) # import video_input to name the output video
|
| 327 |
-
|
| 328 |
return foreground_output, alpha_output
|
| 329 |
|
| 330 |
|
|
@@ -415,47 +438,113 @@ sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[args.sam_model_type]
|
|
| 415 |
# initialize sams
|
| 416 |
model = MaskGenerator(sam_checkpoint, args)
|
| 417 |
|
| 418 |
-
# initialize matanyone
|
| 419 |
-
#
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 426 |
|
| 427 |
-
|
| 428 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 429 |
|
| 430 |
# download test samples
|
| 431 |
-
media_url = "https://github.com/pq-yang/MatAnyone/releases/download/media/"
|
| 432 |
test_sample_path = os.path.join('/home/user/app/hugging_face/', "test_sample/")
|
| 433 |
-
load_file_from_url(
|
| 434 |
-
load_file_from_url(
|
| 435 |
-
load_file_from_url(
|
| 436 |
-
load_file_from_url(
|
| 437 |
-
load_file_from_url(
|
| 438 |
-
load_file_from_url(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 439 |
|
| 440 |
# download assets
|
| 441 |
assets_path = os.path.join('/home/user/app/hugging_face/', "assets/")
|
| 442 |
-
load_file_from_url(
|
| 443 |
-
load_file_from_url(
|
| 444 |
|
| 445 |
# documents
|
| 446 |
-
title = r"""<div class="multi-layer" align="center"><span>MatAnyone</span></div>
|
| 447 |
"""
|
| 448 |
description = r"""
|
| 449 |
-
<b>Official Gradio demo</b> for <a href='https://github.com/pq-yang/
|
| 450 |
-
π₯ MatAnyone
|
| 451 |
-
|
|
|
|
| 452 |
|
| 453 |
*Note: Due to the online GPU memory constraints, any input with too big resolution will be resized to 1080p.<br>*
|
| 454 |
-
π <b> If you encounter any issue (e.g., frozen video output) or wish to run on higher resolution inputs, please consider
|
| 455 |
-
|
| 456 |
"""
|
| 457 |
article = r"""<h3>
|
| 458 |
-
<b>If
|
| 459 |
|
| 460 |
---
|
| 461 |
|
|
@@ -463,6 +552,13 @@ article = r"""<h3>
|
|
| 463 |
<br>
|
| 464 |
If our work is useful for your research, please consider citing:
|
| 465 |
```bibtex
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 466 |
@InProceedings{yang2025matanyone,
|
| 467 |
title = {{MatAnyone}: Stable Video Matting with Consistent Memory Propagation},
|
| 468 |
author = {Yang, Peiqing and Zhou, Shangchen and Zhao, Jixin and Tao, Qingyi and Loy, Chen Change},
|
|
@@ -558,10 +654,10 @@ with gr.Blocks(theme=gr.themes.Monochrome(), css=my_custom_css) as demo:
|
|
| 558 |
<div class="title-container">
|
| 559 |
<h1 class="title is-2 publication-title"
|
| 560 |
style="font-size:50px; font-family: 'Sarpanch', serif;
|
| 561 |
-
background: linear-gradient(to right, #
|
| 562 |
display: inline-block; -webkit-background-clip: text;
|
| 563 |
-webkit-text-fill-color: transparent;">
|
| 564 |
-
MatAnyone
|
| 565 |
</h1>
|
| 566 |
</div>
|
| 567 |
''')
|
|
@@ -614,7 +710,14 @@ with gr.Blocks(theme=gr.themes.Monochrome(), css=my_custom_css) as demo:
|
|
| 614 |
|
| 615 |
with gr.Group(elem_classes="gr-monochrome-group", visible=True):
|
| 616 |
with gr.Row():
|
| 617 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 618 |
with gr.Row():
|
| 619 |
erode_kernel_size = gr.Slider(label='Erode Kernel Size',
|
| 620 |
minimum=0,
|
|
@@ -722,7 +825,7 @@ with gr.Blocks(theme=gr.themes.Monochrome(), css=my_custom_css) as demo:
|
|
| 722 |
# video matting
|
| 723 |
matting_button.click(
|
| 724 |
fn=video_matting,
|
| 725 |
-
inputs=[video_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size],
|
| 726 |
outputs=[foreground_video_output, alpha_video_output]
|
| 727 |
)
|
| 728 |
|
|
@@ -775,7 +878,7 @@ with gr.Blocks(theme=gr.themes.Monochrome(), css=my_custom_css) as demo:
|
|
| 775 |
gr.Markdown("---")
|
| 776 |
gr.Markdown("## Examples")
|
| 777 |
gr.Examples(
|
| 778 |
-
examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["test-
|
| 779 |
inputs=[video_input],
|
| 780 |
)
|
| 781 |
|
|
@@ -811,7 +914,14 @@ with gr.Blocks(theme=gr.themes.Monochrome(), css=my_custom_css) as demo:
|
|
| 811 |
|
| 812 |
with gr.Group(elem_classes="gr-monochrome-group", visible=True):
|
| 813 |
with gr.Row():
|
| 814 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 815 |
with gr.Row():
|
| 816 |
erode_kernel_size = gr.Slider(label='Erode Kernel Size',
|
| 817 |
minimum=0,
|
|
@@ -918,7 +1028,7 @@ with gr.Blocks(theme=gr.themes.Monochrome(), css=my_custom_css) as demo:
|
|
| 918 |
# image matting
|
| 919 |
matting_button.click(
|
| 920 |
fn=image_matting,
|
| 921 |
-
inputs=[image_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, image_selection_slider],
|
| 922 |
outputs=[foreground_image_output, alpha_image_output]
|
| 923 |
)
|
| 924 |
|
|
@@ -971,7 +1081,7 @@ with gr.Blocks(theme=gr.themes.Monochrome(), css=my_custom_css) as demo:
|
|
| 971 |
gr.Markdown("---")
|
| 972 |
gr.Markdown("## Examples")
|
| 973 |
gr.Examples(
|
| 974 |
-
examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["test-
|
| 975 |
inputs=[image_input],
|
| 976 |
)
|
| 977 |
|
|
|
|
| 21 |
from tools.misc import get_device
|
| 22 |
from tools.download_util import load_file_from_url
|
| 23 |
|
| 24 |
+
from matanyone2_wrapper import matanyone2
|
| 25 |
+
from matanyone2.utils.get_default_model import get_matanyone2_model
|
| 26 |
+
from matanyone2.inference.inference_core import InferenceCore
|
| 27 |
+
from hydra.core.global_hydra import GlobalHydra
|
| 28 |
+
|
| 29 |
+
import warnings
|
| 30 |
+
warnings.filterwarnings("ignore")
|
| 31 |
|
| 32 |
def parse_augment():
|
| 33 |
parser = argparse.ArgumentParser()
|
|
|
|
| 125 |
except Exception as e:
|
| 126 |
print(f"Audio extraction error: {str(e)}")
|
| 127 |
audio_path = "" # Set to "" if extraction fails
|
|
|
|
| 128 |
|
| 129 |
# extract frames
|
| 130 |
try:
|
|
|
|
| 143 |
print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
|
| 144 |
image_size = (frames[0].shape[0],frames[0].shape[1])
|
| 145 |
|
| 146 |
+
# [remove for local demo] resize if resolution too big
|
| 147 |
+
if image_size[0]>=1080 and image_size[0]>=1080:
|
| 148 |
scale = 1080 / min(image_size)
|
| 149 |
new_w = int(image_size[1] * scale)
|
| 150 |
new_h = int(image_size[0] * scale)
|
|
|
|
| 168 |
video_info = "Video Name: {},\nFPS: {},\nTotal Frames: {},\nImage Size:{}".format(video_state["video_name"], round(video_state["fps"], 0), len(frames), image_size)
|
| 169 |
model.samcontroler.sam_controler.reset_image()
|
| 170 |
model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
|
| 171 |
+
return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=False, maximum=len(frames), value=len(frames)), \
|
|
|
|
| 172 |
gr.update(visible=True), gr.update(visible=True), \
|
| 173 |
gr.update(visible=True), gr.update(visible=True),\
|
| 174 |
gr.update(visible=True), gr.update(visible=True), \
|
|
|
|
| 269 |
return select_frame
|
| 270 |
|
| 271 |
# image matting
|
| 272 |
+
def image_matting(video_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, refine_iter, model_selection):
|
| 273 |
+
# Load model if not already loaded
|
| 274 |
+
try:
|
| 275 |
+
selected_model = load_model(model_selection)
|
| 276 |
+
except (FileNotFoundError, ValueError) as e:
|
| 277 |
+
# Fallback to first available model
|
| 278 |
+
if available_models:
|
| 279 |
+
print(f"Warning: {str(e)}. Using {available_models[0]} instead.")
|
| 280 |
+
selected_model = load_model(available_models[0])
|
| 281 |
+
else:
|
| 282 |
+
raise ValueError("No models are available! Please check if the model files exist.")
|
| 283 |
+
matanyone_processor = InferenceCore(selected_model, cfg=selected_model.cfg)
|
| 284 |
if interactive_state["track_end_number"]:
|
| 285 |
following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
|
| 286 |
else:
|
|
|
|
| 301 |
# operation error
|
| 302 |
if len(np.unique(template_mask))==1:
|
| 303 |
template_mask[0][0]=1
|
| 304 |
+
foreground, alpha = matanyone2(matanyone_processor, following_frames, template_mask*255, r_erode=erode_kernel_size, r_dilate=dilate_kernel_size, n_warmup=refine_iter)
|
| 305 |
foreground_output = Image.fromarray(foreground[-1])
|
| 306 |
alpha_output = Image.fromarray(alpha[-1][:,:,0])
|
| 307 |
+
|
| 308 |
return foreground_output, alpha_output
|
| 309 |
|
| 310 |
# video matting
|
| 311 |
+
def video_matting(video_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, model_selection):
|
| 312 |
+
# Load model if not already loaded
|
| 313 |
+
try:
|
| 314 |
+
selected_model = load_model(model_selection)
|
| 315 |
+
except (FileNotFoundError, ValueError) as e:
|
| 316 |
+
# Fallback to first available model
|
| 317 |
+
if available_models:
|
| 318 |
+
print(f"Warning: {str(e)}. Using {available_models[0]} instead.")
|
| 319 |
+
selected_model = load_model(available_models[0])
|
| 320 |
+
else:
|
| 321 |
+
raise ValueError("No models are available! Please check if the model files exist.")
|
| 322 |
+
matanyone_processor = InferenceCore(selected_model, cfg=selected_model.cfg)
|
| 323 |
if interactive_state["track_end_number"]:
|
| 324 |
following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
|
| 325 |
else:
|
|
|
|
| 343 |
# operation error
|
| 344 |
if len(np.unique(template_mask))==1:
|
| 345 |
template_mask[0][0]=1
|
| 346 |
+
foreground, alpha = matanyone2(matanyone_processor, following_frames, template_mask*255, r_erode=erode_kernel_size, r_dilate=dilate_kernel_size)
|
| 347 |
|
| 348 |
foreground_output = generate_video_from_frames(foreground, output_path="./results/{}_fg.mp4".format(video_state["video_name"]), fps=fps, audio_path=audio_path) # import video_input to name the output video
|
| 349 |
alpha_output = generate_video_from_frames(alpha, output_path="./results/{}_alpha.mp4".format(video_state["video_name"]), fps=fps, gray2rgb=True, audio_path=audio_path) # import video_input to name the output video
|
| 350 |
+
|
| 351 |
return foreground_output, alpha_output
|
| 352 |
|
| 353 |
|
|
|
|
| 438 |
# initialize sams
|
| 439 |
model = MaskGenerator(sam_checkpoint, args)
|
| 440 |
|
| 441 |
+
# initialize matanyone - lazy loading
|
| 442 |
+
# Model display names to file names mapping
|
| 443 |
+
model_display_to_file = {
|
| 444 |
+
"MatAnyone": "matanyone.pth",
|
| 445 |
+
"MatAnyone 2": "matanyone2.pth"
|
| 446 |
+
}
|
| 447 |
+
|
| 448 |
+
# Model URLs
|
| 449 |
+
model_urls = {
|
| 450 |
+
"matanyone.pth": "https://github.com/pq-yang/MatAnyone/releases/download/v1.0.0/matanyone.pth",
|
| 451 |
+
"matanyone2.pth": "https://github.com/pq-yang/MatAnyone2/releases/download/v1.0.0/matanyone2.pth"
|
| 452 |
+
}
|
| 453 |
+
|
| 454 |
+
# Model paths - download models using load_file_from_url
|
| 455 |
+
model_paths = {
|
| 456 |
+
"matanyone.pth": load_file_from_url(model_urls["matanyone.pth"], checkpoint_folder),
|
| 457 |
+
"matanyone2.pth": load_file_from_url(model_urls["matanyone2.pth"], checkpoint_folder)
|
| 458 |
+
}
|
| 459 |
|
| 460 |
+
# Cache for loaded models (lazy loading)
|
| 461 |
+
loaded_models = {}
|
| 462 |
+
|
| 463 |
+
def load_model(display_name):
|
| 464 |
+
"""Load a model if not already loaded"""
|
| 465 |
+
# Convert display name to file name
|
| 466 |
+
if display_name in model_display_to_file:
|
| 467 |
+
model_file = model_display_to_file[display_name]
|
| 468 |
+
elif display_name in model_paths:
|
| 469 |
+
# Also support direct file name for backward compatibility
|
| 470 |
+
model_file = display_name
|
| 471 |
+
else:
|
| 472 |
+
raise ValueError(f"Unknown model: {display_name}")
|
| 473 |
+
|
| 474 |
+
if model_file in loaded_models:
|
| 475 |
+
return loaded_models[model_file]
|
| 476 |
+
|
| 477 |
+
if model_file not in model_paths:
|
| 478 |
+
raise ValueError(f"Unknown model file: {model_file}")
|
| 479 |
+
|
| 480 |
+
ckpt_path = model_paths[model_file]
|
| 481 |
+
if not os.path.exists(ckpt_path):
|
| 482 |
+
raise FileNotFoundError(f"Model file not found: {ckpt_path}")
|
| 483 |
+
|
| 484 |
+
# Clear Hydra instance if already initialized (to allow loading different models)
|
| 485 |
+
try:
|
| 486 |
+
GlobalHydra.instance().clear()
|
| 487 |
+
except:
|
| 488 |
+
pass # If Hydra is not initialized, this is fine
|
| 489 |
+
|
| 490 |
+
print(f"Loading model: {display_name} ({model_file})...")
|
| 491 |
+
model = get_matanyone2_model(ckpt_path, args.device)
|
| 492 |
+
model = model.to(args.device).eval()
|
| 493 |
+
loaded_models[model_file] = model
|
| 494 |
+
print(f"Model {display_name} loaded successfully.")
|
| 495 |
+
return model
|
| 496 |
+
|
| 497 |
+
# Get available model choices for the UI (check if files exist)
|
| 498 |
+
# Order: MatAnyone 2 first, then MatAnyone
|
| 499 |
+
available_models = []
|
| 500 |
+
# Check MatAnyone 2 first
|
| 501 |
+
if "MatAnyone 2" in model_display_to_file:
|
| 502 |
+
file_name = model_display_to_file["MatAnyone 2"]
|
| 503 |
+
if file_name in model_paths and os.path.exists(model_paths[file_name]):
|
| 504 |
+
available_models.append("MatAnyone 2")
|
| 505 |
+
# Then check MatAnyone
|
| 506 |
+
if "MatAnyone" in model_display_to_file:
|
| 507 |
+
file_name = model_display_to_file["MatAnyone"]
|
| 508 |
+
if file_name in model_paths and os.path.exists(model_paths[file_name]):
|
| 509 |
+
available_models.append("MatAnyone")
|
| 510 |
+
|
| 511 |
+
if not available_models:
|
| 512 |
+
raise RuntimeError("No models are available! Please ensure at least one model file exists in ../pretrained_models/")
|
| 513 |
+
default_model = "MatAnyone 2" if "MatAnyone 2" in available_models else available_models[0]
|
| 514 |
|
| 515 |
# download test samples
|
|
|
|
| 516 |
test_sample_path = os.path.join('/home/user/app/hugging_face/', "test_sample/")
|
| 517 |
+
load_file_from_url('https://github.com/pq-yang/MatAnyone2/releases/download/media/test-sample-0-1080p.mp4', test_sample_path)
|
| 518 |
+
load_file_from_url('https://github.com/pq-yang/MatAnyone2/releases/download/media/test-sample-1-1080p.mp4', test_sample_path)
|
| 519 |
+
load_file_from_url('https://github.com/pq-yang/MatAnyone2/releases/download/media/test-sample-2-720p.mp4', test_sample_path)
|
| 520 |
+
load_file_from_url('https://github.com/pq-yang/MatAnyone2/releases/download/media/test-sample-3-720p.mp4', test_sample_path)
|
| 521 |
+
load_file_from_url('https://github.com/pq-yang/MatAnyone2/releases/download/media/test-sample-4-720p.mp4', test_sample_path)
|
| 522 |
+
load_file_from_url('https://github.com/pq-yang/MatAnyone2/releases/download/media/test-sample-5-720p.mp4', test_sample_path)
|
| 523 |
+
load_file_from_url('https://github.com/pq-yang/MatAnyone2/releases/download/media/test-sample-0.jpg', test_sample_path)
|
| 524 |
+
load_file_from_url('https://github.com/pq-yang/MatAnyone2/releases/download/media/test-sample-1.jpg', test_sample_path)
|
| 525 |
+
load_file_from_url('https://github.com/pq-yang/MatAnyone2/releases/download/media/test-sample-2.jpg', test_sample_path)
|
| 526 |
+
load_file_from_url('https://github.com/pq-yang/MatAnyone2/releases/download/media/test-sample-3.jpg', test_sample_path)
|
| 527 |
|
| 528 |
# download assets
|
| 529 |
assets_path = os.path.join('/home/user/app/hugging_face/', "assets/")
|
| 530 |
+
load_file_from_url('https://github.com/pq-yang/MatAnyone/releases/download/media/tutorial_single_target.mp4', assets_path)
|
| 531 |
+
load_file_from_url('https://github.com/pq-yang/MatAnyone/releases/download/media/tutorial_multi_targets.mp4', assets_path)
|
| 532 |
|
| 533 |
# documents
|
| 534 |
+
title = r"""<div class="multi-layer" align="center"><span>MatAnyone Series</span></div>
|
| 535 |
"""
|
| 536 |
description = r"""
|
| 537 |
+
<b>Official Gradio demo</b> for <a href='https://github.com/pq-yang/MatAnyone2' target='_blank'><b>MatAnyone 2</b></a> and <a href='https://github.com/pq-yang/MatAnyone' target='_blank'><b>MatAnyone</b></a>.<br>
|
| 538 |
+
π₯ MatAnyone series provide practical human video matting framework supporting target assignment.<br>
|
| 539 |
+
π§ <b>We use <u>MatAnyone 2</u> as the default model. You can also choose <u>MatAnyone</u> in "Model Selection".</b><br>
|
| 540 |
+
πͺ Try to drop your video/image, assign the target masks with a few clicks, and get the the matting results!<br>
|
| 541 |
|
| 542 |
*Note: Due to the online GPU memory constraints, any input with too big resolution will be resized to 1080p.<br>*
|
| 543 |
+
π <b> If you encounter any issue (e.g., frozen video output) or wish to run on higher resolution inputs, please consider duplicating this space or
|
| 544 |
+
launching the demo locally following the <a href='https://github.com/pq-yang/MatAnyone2?tab=readme-ov-file#-interactive-demo' target='_blank'>GitHub instructions</a>.</b>
|
| 545 |
"""
|
| 546 |
article = r"""<h3>
|
| 547 |
+
<b>If our projects are helpful, please help to π the Github Repo for <a href='https://github.com/pq-yang/MatAnyone2' target='_blank'>MatAnyone 2</a> and <a href='https://github.com/pq-yang/MatAnyone' target='_blank'>MatAnyone</a>. Thanks!</b></h3>
|
| 548 |
|
| 549 |
---
|
| 550 |
|
|
|
|
| 552 |
<br>
|
| 553 |
If our work is useful for your research, please consider citing:
|
| 554 |
```bibtex
|
| 555 |
+
@InProceedings{yang2026matanyone2,
|
| 556 |
+
title = {{MatAnyone 2}: Scaling Video Matting via a Learned Quality Evaluator},
|
| 557 |
+
author = {Yang, Peiqing and Zhou, Shangchen and Hao, Kai and Tao, Qingyi},
|
| 558 |
+
booktitle = {CVPR},
|
| 559 |
+
year = {2026}
|
| 560 |
+
}
|
| 561 |
+
|
| 562 |
@InProceedings{yang2025matanyone,
|
| 563 |
title = {{MatAnyone}: Stable Video Matting with Consistent Memory Propagation},
|
| 564 |
author = {Yang, Peiqing and Zhou, Shangchen and Zhao, Jixin and Tao, Qingyi and Loy, Chen Change},
|
|
|
|
| 654 |
<div class="title-container">
|
| 655 |
<h1 class="title is-2 publication-title"
|
| 656 |
style="font-size:50px; font-family: 'Sarpanch', serif;
|
| 657 |
+
background: linear-gradient(to right, #000000, #2dc464);
|
| 658 |
display: inline-block; -webkit-background-clip: text;
|
| 659 |
-webkit-text-fill-color: transparent;">
|
| 660 |
+
MatAnyone Series
|
| 661 |
</h1>
|
| 662 |
</div>
|
| 663 |
''')
|
|
|
|
| 710 |
|
| 711 |
with gr.Group(elem_classes="gr-monochrome-group", visible=True):
|
| 712 |
with gr.Row():
|
| 713 |
+
model_selection = gr.Radio(
|
| 714 |
+
choices=available_models,
|
| 715 |
+
value=default_model,
|
| 716 |
+
label="Model Selection",
|
| 717 |
+
info="Choose the model to use for matting",
|
| 718 |
+
interactive=True)
|
| 719 |
+
with gr.Row():
|
| 720 |
+
with gr.Accordion('Model Settings (click to expand)', open=False):
|
| 721 |
with gr.Row():
|
| 722 |
erode_kernel_size = gr.Slider(label='Erode Kernel Size',
|
| 723 |
minimum=0,
|
|
|
|
| 825 |
# video matting
|
| 826 |
matting_button.click(
|
| 827 |
fn=video_matting,
|
| 828 |
+
inputs=[video_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, model_selection],
|
| 829 |
outputs=[foreground_video_output, alpha_video_output]
|
| 830 |
)
|
| 831 |
|
|
|
|
| 878 |
gr.Markdown("---")
|
| 879 |
gr.Markdown("## Examples")
|
| 880 |
gr.Examples(
|
| 881 |
+
examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["test-sample-0-1080p.mp4", "test-sample-1-1080p.mp4", "test-sample-2-720p.mp4", "test-sample-3-720p.mp4", "test-sample-4-720p.mp4", "test-sample-5-720p.mp4"]],
|
| 882 |
inputs=[video_input],
|
| 883 |
)
|
| 884 |
|
|
|
|
| 914 |
|
| 915 |
with gr.Group(elem_classes="gr-monochrome-group", visible=True):
|
| 916 |
with gr.Row():
|
| 917 |
+
model_selection = gr.Radio(
|
| 918 |
+
choices=available_models,
|
| 919 |
+
value=default_model,
|
| 920 |
+
label="Model Selection",
|
| 921 |
+
info="Choose the model to use for matting",
|
| 922 |
+
interactive=True)
|
| 923 |
+
with gr.Row():
|
| 924 |
+
with gr.Accordion('Model Settings (click to expand)', open=False):
|
| 925 |
with gr.Row():
|
| 926 |
erode_kernel_size = gr.Slider(label='Erode Kernel Size',
|
| 927 |
minimum=0,
|
|
|
|
| 1028 |
# image matting
|
| 1029 |
matting_button.click(
|
| 1030 |
fn=image_matting,
|
| 1031 |
+
inputs=[image_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, image_selection_slider, model_selection],
|
| 1032 |
outputs=[foreground_image_output, alpha_image_output]
|
| 1033 |
)
|
| 1034 |
|
|
|
|
| 1081 |
gr.Markdown("---")
|
| 1082 |
gr.Markdown("## Examples")
|
| 1083 |
gr.Examples(
|
| 1084 |
+
examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["test-sample-0.jpg", "test-sample-1.jpg", "test-sample-2.jpg", "test-sample-3.jpg"]],
|
| 1085 |
inputs=[image_input],
|
| 1086 |
)
|
| 1087 |
|
hugging_face/{matanyone_wrapper.py β matanyone2_wrapper.py}
RENAMED
|
@@ -1,9 +1,13 @@
|
|
|
|
|
| 1 |
import tqdm
|
| 2 |
import torch
|
| 3 |
from torchvision.transforms.functional import to_tensor
|
| 4 |
import numpy as np
|
| 5 |
import random
|
| 6 |
import cv2
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
def gen_dilate(alpha, min_kernel_size, max_kernel_size):
|
| 9 |
kernel_size = random.randint(min_kernel_size, max_kernel_size)
|
|
@@ -20,8 +24,8 @@ def gen_erosion(alpha, min_kernel_size, max_kernel_size):
|
|
| 20 |
return erode.astype(np.float32)
|
| 21 |
|
| 22 |
@torch.inference_mode()
|
| 23 |
-
@
|
| 24 |
-
def
|
| 25 |
"""
|
| 26 |
Args:
|
| 27 |
frames_np: [(H,W,C)]*n, uint8
|
|
@@ -41,14 +45,14 @@ def matanyone(processor, frames_np, mask, r_erode=0, r_dilate=0, n_warmup=10):
|
|
| 41 |
if r_erode > 0:
|
| 42 |
mask = gen_erosion(mask, r_erode, r_erode)
|
| 43 |
|
| 44 |
-
mask = torch.from_numpy(mask).
|
| 45 |
|
| 46 |
frames_np = [frames_np[0]]* n_warmup + frames_np
|
| 47 |
|
| 48 |
frames = []
|
| 49 |
phas = []
|
| 50 |
for ti, frame_single in tqdm.tqdm(enumerate(frames_np)):
|
| 51 |
-
image = to_tensor(frame_single).
|
| 52 |
|
| 53 |
if ti == 0:
|
| 54 |
output_prob = processor.step(image, mask, objects=objects) # encode given mask
|
|
@@ -62,7 +66,7 @@ def matanyone(processor, frames_np, mask, r_erode=0, r_dilate=0, n_warmup=10):
|
|
| 62 |
# convert output probabilities to an object mask
|
| 63 |
mask = processor.output_prob_to_mask(output_prob)
|
| 64 |
|
| 65 |
-
pha = mask.unsqueeze(2).
|
| 66 |
com_np = frame_single / 255. * pha + bgr * (1 - pha)
|
| 67 |
|
| 68 |
# DONOT save the warmup frames
|
|
@@ -70,4 +74,4 @@ def matanyone(processor, frames_np, mask, r_erode=0, r_dilate=0, n_warmup=10):
|
|
| 70 |
frames.append((com_np*255).astype(np.uint8))
|
| 71 |
phas.append((pha*255).astype(np.uint8))
|
| 72 |
|
| 73 |
-
return frames, phas
|
|
|
|
| 1 |
+
|
| 2 |
import tqdm
|
| 3 |
import torch
|
| 4 |
from torchvision.transforms.functional import to_tensor
|
| 5 |
import numpy as np
|
| 6 |
import random
|
| 7 |
import cv2
|
| 8 |
+
from matanyone2.utils.device import get_default_device, safe_autocast_decorator
|
| 9 |
+
|
| 10 |
+
device = get_default_device()
|
| 11 |
|
| 12 |
def gen_dilate(alpha, min_kernel_size, max_kernel_size):
|
| 13 |
kernel_size = random.randint(min_kernel_size, max_kernel_size)
|
|
|
|
| 24 |
return erode.astype(np.float32)
|
| 25 |
|
| 26 |
@torch.inference_mode()
|
| 27 |
+
@safe_autocast_decorator()
|
| 28 |
+
def matanyone2(processor, frames_np, mask, r_erode=0, r_dilate=0, n_warmup=10):
|
| 29 |
"""
|
| 30 |
Args:
|
| 31 |
frames_np: [(H,W,C)]*n, uint8
|
|
|
|
| 45 |
if r_erode > 0:
|
| 46 |
mask = gen_erosion(mask, r_erode, r_erode)
|
| 47 |
|
| 48 |
+
mask = torch.from_numpy(mask).to(device)
|
| 49 |
|
| 50 |
frames_np = [frames_np[0]]* n_warmup + frames_np
|
| 51 |
|
| 52 |
frames = []
|
| 53 |
phas = []
|
| 54 |
for ti, frame_single in tqdm.tqdm(enumerate(frames_np)):
|
| 55 |
+
image = to_tensor(frame_single).float().to(device)
|
| 56 |
|
| 57 |
if ti == 0:
|
| 58 |
output_prob = processor.step(image, mask, objects=objects) # encode given mask
|
|
|
|
| 66 |
# convert output probabilities to an object mask
|
| 67 |
mask = processor.output_prob_to_mask(output_prob)
|
| 68 |
|
| 69 |
+
pha = mask.unsqueeze(2).detach().to("cpu").numpy()
|
| 70 |
com_np = frame_single / 255. * pha + bgr * (1 - pha)
|
| 71 |
|
| 72 |
# DONOT save the warmup frames
|
|
|
|
| 74 |
frames.append((com_np*255).astype(np.uint8))
|
| 75 |
phas.append((pha*255).astype(np.uint8))
|
| 76 |
|
| 77 |
+
return frames, phas
|
matanyone/utils/__init__.py
DELETED
|
File without changes
|
matanyone2/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from matanyone2.inference.inference_core import InferenceCore
|
| 2 |
+
from matanyone2.model.matanyone2 import MatAnyone2
|
{matanyone β matanyone2/config}/__init__.py
RENAMED
|
File without changes
|
{matanyone β matanyone2}/config/eval_matanyone_config.yaml
RENAMED
|
@@ -9,7 +9,7 @@ hydra:
|
|
| 9 |
output_subdir: ${now:%Y-%m-%d_%H-%M-%S}-hydra
|
| 10 |
|
| 11 |
amp: False
|
| 12 |
-
weights: pretrained_models/
|
| 13 |
output_dir: null # defaults to run_dir; specify this to override
|
| 14 |
flip_aug: False
|
| 15 |
|
|
|
|
| 9 |
output_subdir: ${now:%Y-%m-%d_%H-%M-%S}-hydra
|
| 10 |
|
| 11 |
amp: False
|
| 12 |
+
weights: pretrained_models/matanyone2.pth # default (can be modified from outside)
|
| 13 |
output_dir: null # defaults to run_dir; specify this to override
|
| 14 |
flip_aug: False
|
| 15 |
|
{matanyone β matanyone2}/config/hydra/job_logging/custom-no-rank.yaml
RENAMED
|
File without changes
|
{matanyone β matanyone2}/config/hydra/job_logging/custom.yaml
RENAMED
|
File without changes
|
{matanyone β matanyone2}/config/model/base.yaml
RENAMED
|
File without changes
|
{matanyone/config β matanyone2/inference}/__init__.py
RENAMED
|
File without changes
|
{matanyone β matanyone2}/inference/image_feature_store.py
RENAMED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
import warnings
|
| 2 |
from typing import Iterable
|
| 3 |
import torch
|
| 4 |
-
from
|
| 5 |
|
| 6 |
|
| 7 |
class ImageFeatureStore:
|
|
@@ -13,7 +13,7 @@ class ImageFeatureStore:
|
|
| 13 |
|
| 14 |
Feature of a frame should be associated with a unique index -- typically the frame id.
|
| 15 |
"""
|
| 16 |
-
def __init__(self, network:
|
| 17 |
self.network = network
|
| 18 |
self._store = {}
|
| 19 |
self.no_warning = no_warning
|
|
|
|
| 1 |
import warnings
|
| 2 |
from typing import Iterable
|
| 3 |
import torch
|
| 4 |
+
from matanyone2.model.matanyone2 import MatAnyone2
|
| 5 |
|
| 6 |
|
| 7 |
class ImageFeatureStore:
|
|
|
|
| 13 |
|
| 14 |
Feature of a frame should be associated with a unique index -- typically the frame id.
|
| 15 |
"""
|
| 16 |
+
def __init__(self, network: MatAnyone2, no_warning: bool = False):
|
| 17 |
self.network = network
|
| 18 |
self._store = {}
|
| 19 |
self.no_warning = no_warning
|
{matanyone β matanyone2}/inference/inference_core.py
RENAMED
|
@@ -1,16 +1,25 @@
|
|
| 1 |
-
from typing import List, Optional, Iterable
|
| 2 |
import logging
|
| 3 |
from omegaconf import DictConfig
|
|
|
|
| 4 |
|
| 5 |
-
import
|
|
|
|
| 6 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
import torch.nn.functional as F
|
| 8 |
|
| 9 |
-
from
|
| 10 |
-
from
|
| 11 |
-
from
|
| 12 |
-
from
|
| 13 |
-
from
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
log = logging.getLogger()
|
| 16 |
|
|
@@ -18,11 +27,21 @@ log = logging.getLogger()
|
|
| 18 |
class InferenceCore:
|
| 19 |
|
| 20 |
def __init__(self,
|
| 21 |
-
network:
|
| 22 |
-
cfg: DictConfig,
|
| 23 |
*,
|
| 24 |
-
image_feature_store: ImageFeatureStore = None
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
self.cfg = cfg
|
| 27 |
self.mem_every = cfg.mem_every
|
| 28 |
stagger_updates = cfg.stagger_updates
|
|
@@ -404,3 +423,128 @@ class InferenceCore:
|
|
| 404 |
new_mask[mask == tmp_id] = obj.id
|
| 405 |
|
| 406 |
return new_mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import logging
|
| 2 |
from omegaconf import DictConfig
|
| 3 |
+
from typing import List, Optional, Iterable, Union,Tuple
|
| 4 |
|
| 5 |
+
import os
|
| 6 |
+
import cv2
|
| 7 |
import torch
|
| 8 |
+
import imageio
|
| 9 |
+
import tempfile
|
| 10 |
+
import numpy as np
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
from PIL import Image
|
| 13 |
import torch.nn.functional as F
|
| 14 |
|
| 15 |
+
from matanyone2.inference.memory_manager import MemoryManager
|
| 16 |
+
from matanyone2.inference.object_manager import ObjectManager
|
| 17 |
+
from matanyone2.inference.image_feature_store import ImageFeatureStore
|
| 18 |
+
from matanyone2.model.matanyone2 import MatAnyone2
|
| 19 |
+
from matanyone2.utils.tensor_utils import pad_divide_by, unpad, aggregate
|
| 20 |
+
from matanyone2.utils.inference_utils import gen_dilate, gen_erosion, read_frame_from_videos
|
| 21 |
+
from matanyone2.utils.device import get_default_device, safe_autocast
|
| 22 |
+
|
| 23 |
|
| 24 |
log = logging.getLogger()
|
| 25 |
|
|
|
|
| 27 |
class InferenceCore:
|
| 28 |
|
| 29 |
def __init__(self,
|
| 30 |
+
network: Union[MatAnyone2,str],
|
| 31 |
+
cfg: DictConfig = None,
|
| 32 |
*,
|
| 33 |
+
image_feature_store: ImageFeatureStore = None,
|
| 34 |
+
device: Optional[Union[str, torch.device]] = None
|
| 35 |
+
):
|
| 36 |
+
if device is None:
|
| 37 |
+
device = get_default_device()
|
| 38 |
+
self.device = device
|
| 39 |
+
if isinstance(network, str):
|
| 40 |
+
network = MatAnyone2.from_pretrained(network)
|
| 41 |
+
network.to(device)
|
| 42 |
+
network.eval()
|
| 43 |
+
self.network = network
|
| 44 |
+
cfg = cfg if cfg is not None else network.cfg
|
| 45 |
self.cfg = cfg
|
| 46 |
self.mem_every = cfg.mem_every
|
| 47 |
stagger_updates = cfg.stagger_updates
|
|
|
|
| 423 |
new_mask[mask == tmp_id] = obj.id
|
| 424 |
|
| 425 |
return new_mask
|
| 426 |
+
|
| 427 |
+
@torch.inference_mode()
|
| 428 |
+
@safe_autocast()
|
| 429 |
+
def process_video(
|
| 430 |
+
self,
|
| 431 |
+
input_path: str,
|
| 432 |
+
mask_path: str,
|
| 433 |
+
output_path: str = None,
|
| 434 |
+
n_warmup: int = 10,
|
| 435 |
+
r_erode: int = 10,
|
| 436 |
+
r_dilate: int = 10,
|
| 437 |
+
suffix: str = "",
|
| 438 |
+
save_image: bool = False,
|
| 439 |
+
max_size: int = -1,
|
| 440 |
+
) -> Tuple:
|
| 441 |
+
"""
|
| 442 |
+
Process a video for object segmentation and matting.
|
| 443 |
+
This method processes a video file by performing object segmentation and matting on each frame.
|
| 444 |
+
It supports warmup frames, mask erosion/dilation, and various output options.
|
| 445 |
+
Args:
|
| 446 |
+
input_path (str): Path to the input video file
|
| 447 |
+
mask_path (str): Path to the mask image file used for initial segmentation
|
| 448 |
+
output_path (str, optional): Directory path where output files will be saved. Defaults to a temporary directory
|
| 449 |
+
n_warmup (int, optional): Number of warmup frames to use. Defaults to 10
|
| 450 |
+
r_erode (int, optional): Erosion radius for mask processing. Defaults to 10
|
| 451 |
+
r_dilate (int, optional): Dilation radius for mask processing. Defaults to 10
|
| 452 |
+
suffix (str, optional): Suffix to append to output filename. Defaults to ""
|
| 453 |
+
save_image (bool, optional): Whether to save individual frames. Defaults to False
|
| 454 |
+
max_size (int, optional): Maximum size for frame dimension. Use -1 for no limit. Defaults to -1
|
| 455 |
+
Returns:
|
| 456 |
+
Tuple[str, str]: A tuple containing:
|
| 457 |
+
- Path to the output foreground video file (str)
|
| 458 |
+
- Path to the output alpha matte video file (str)
|
| 459 |
+
Output:
|
| 460 |
+
- Saves processed video files with foreground (_fgr) and alpha matte (_pha)
|
| 461 |
+
- If save_image=True, saves individual frames in separate directories
|
| 462 |
+
"""
|
| 463 |
+
output_path = output_path if output_path is not None else tempfile.TemporaryDirectory().name
|
| 464 |
+
r_erode = int(r_erode)
|
| 465 |
+
r_dilate = int(r_dilate)
|
| 466 |
+
n_warmup = int(n_warmup)
|
| 467 |
+
max_size = int(max_size)
|
| 468 |
+
|
| 469 |
+
vframes, fps, length, video_name = read_frame_from_videos(input_path)
|
| 470 |
+
repeated_frames = vframes[0].unsqueeze(0).repeat(n_warmup, 1, 1, 1)
|
| 471 |
+
vframes = torch.cat([repeated_frames, vframes], dim=0).float()
|
| 472 |
+
length += n_warmup
|
| 473 |
+
|
| 474 |
+
new_h, new_w = vframes.shape[-2:]
|
| 475 |
+
if max_size > 0:
|
| 476 |
+
h, w = new_h, new_w
|
| 477 |
+
min_side = min(h, w)
|
| 478 |
+
if min_side > max_size:
|
| 479 |
+
new_h = int(h / min_side * max_size)
|
| 480 |
+
new_w = int(w / min_side * max_size)
|
| 481 |
+
vframes = F.interpolate(vframes, size=(new_h, new_w), mode="area")
|
| 482 |
+
|
| 483 |
+
os.makedirs(output_path, exist_ok=True)
|
| 484 |
+
if suffix:
|
| 485 |
+
video_name = f"{video_name}_{suffix}"
|
| 486 |
+
if save_image:
|
| 487 |
+
os.makedirs(f"{output_path}/{video_name}", exist_ok=True)
|
| 488 |
+
os.makedirs(f"{output_path}/{video_name}/pha", exist_ok=True)
|
| 489 |
+
os.makedirs(f"{output_path}/{video_name}/fgr", exist_ok=True)
|
| 490 |
+
|
| 491 |
+
mask = np.array(Image.open(mask_path).convert("L"))
|
| 492 |
+
if r_dilate > 0:
|
| 493 |
+
mask = gen_dilate(mask, r_dilate, r_dilate)
|
| 494 |
+
if r_erode > 0:
|
| 495 |
+
mask = gen_erosion(mask, r_erode, r_erode)
|
| 496 |
+
|
| 497 |
+
mask = torch.from_numpy(mask).float().to(self.device)
|
| 498 |
+
if max_size > 0:
|
| 499 |
+
mask = F.interpolate(
|
| 500 |
+
mask.unsqueeze(0).unsqueeze(0), size=(new_h, new_w), mode="nearest"
|
| 501 |
+
)[0, 0]
|
| 502 |
+
|
| 503 |
+
bgr = (np.array([120, 255, 155], dtype=np.float32) / 255).reshape((1, 1, 3))
|
| 504 |
+
objects = [1]
|
| 505 |
+
|
| 506 |
+
phas = []
|
| 507 |
+
fgrs = []
|
| 508 |
+
for ti in tqdm(range(length)):
|
| 509 |
+
image = vframes[ti]
|
| 510 |
+
image_np = np.array(image.permute(1, 2, 0))
|
| 511 |
+
image = (image / 255.0).float().to(self.device)
|
| 512 |
+
|
| 513 |
+
if ti == 0:
|
| 514 |
+
output_prob = self.step(image, mask, objects=objects)
|
| 515 |
+
output_prob = self.step(image, first_frame_pred=True)
|
| 516 |
+
else:
|
| 517 |
+
if ti <= n_warmup:
|
| 518 |
+
output_prob = self.step(image, first_frame_pred=True)
|
| 519 |
+
else:
|
| 520 |
+
output_prob = self.step(image)
|
| 521 |
+
|
| 522 |
+
mask = self.output_prob_to_mask(output_prob)
|
| 523 |
+
pha = mask.unsqueeze(2).cpu().numpy()
|
| 524 |
+
com_np = image_np / 255.0 * pha + bgr * (1 - pha)
|
| 525 |
+
|
| 526 |
+
if ti > (n_warmup - 1):
|
| 527 |
+
com_np = (com_np * 255).astype(np.uint8)
|
| 528 |
+
pha = (pha * 255).astype(np.uint8)
|
| 529 |
+
fgrs.append(com_np)
|
| 530 |
+
phas.append(pha)
|
| 531 |
+
if save_image:
|
| 532 |
+
cv2.imwrite(
|
| 533 |
+
f"{output_path}/{video_name}/pha/{str(ti - n_warmup).zfill(5)}.png",
|
| 534 |
+
pha,
|
| 535 |
+
)
|
| 536 |
+
cv2.imwrite(
|
| 537 |
+
f"{output_path}/{video_name}/fgr/{str(ti - n_warmup).zfill(5)}.png",
|
| 538 |
+
com_np[..., [2, 1, 0]],
|
| 539 |
+
)
|
| 540 |
+
|
| 541 |
+
fgrs = np.array(fgrs)
|
| 542 |
+
phas = np.array(phas)
|
| 543 |
+
|
| 544 |
+
fgr_filename = f"{output_path}/{video_name}_fgr.mp4"
|
| 545 |
+
alpha_filename = f"{output_path}/{video_name}_pha.mp4"
|
| 546 |
+
|
| 547 |
+
imageio.mimwrite(fgr_filename, fgrs, fps=fps, quality=7)
|
| 548 |
+
imageio.mimwrite(alpha_filename, phas, fps=fps, quality=7)
|
| 549 |
+
|
| 550 |
+
return (fgr_filename,alpha_filename)
|
{matanyone β matanyone2}/inference/kv_memory_store.py
RENAMED
|
File without changes
|
{matanyone β matanyone2}/inference/memory_manager.py
RENAMED
|
@@ -3,10 +3,10 @@ from omegaconf import DictConfig
|
|
| 3 |
from typing import List, Dict
|
| 4 |
import torch
|
| 5 |
|
| 6 |
-
from
|
| 7 |
-
from
|
| 8 |
-
from
|
| 9 |
-
from
|
| 10 |
|
| 11 |
log = logging.getLogger()
|
| 12 |
|
|
@@ -113,7 +113,7 @@ class MemoryManager:
|
|
| 113 |
return value
|
| 114 |
|
| 115 |
def read_first_frame(self, last_msk_value, pix_feat: torch.Tensor,
|
| 116 |
-
last_mask: torch.Tensor, network:
|
| 117 |
"""
|
| 118 |
Read from all memory stores and returns a single memory readout tensor for each object
|
| 119 |
|
|
@@ -166,7 +166,7 @@ class MemoryManager:
|
|
| 166 |
return all_readout_mem
|
| 167 |
|
| 168 |
def read(self, pix_feat: torch.Tensor, query_key: torch.Tensor, selection: torch.Tensor,
|
| 169 |
-
last_mask: torch.Tensor, network:
|
| 170 |
last_pix_feat=None, last_pred_mask=None) -> Dict[int, torch.Tensor]:
|
| 171 |
"""
|
| 172 |
Read from all memory stores and returns a single memory readout tensor for each object
|
|
|
|
| 3 |
from typing import List, Dict
|
| 4 |
import torch
|
| 5 |
|
| 6 |
+
from matanyone2.inference.object_manager import ObjectManager
|
| 7 |
+
from matanyone2.inference.kv_memory_store import KeyValueMemoryStore
|
| 8 |
+
from matanyone2.model.matanyone2 import MatAnyone2
|
| 9 |
+
from matanyone2.model.utils.memory_utils import get_similarity, do_softmax
|
| 10 |
|
| 11 |
log = logging.getLogger()
|
| 12 |
|
|
|
|
| 113 |
return value
|
| 114 |
|
| 115 |
def read_first_frame(self, last_msk_value, pix_feat: torch.Tensor,
|
| 116 |
+
last_mask: torch.Tensor, network: MatAnyone2, uncert_output=None) -> Dict[int, torch.Tensor]:
|
| 117 |
"""
|
| 118 |
Read from all memory stores and returns a single memory readout tensor for each object
|
| 119 |
|
|
|
|
| 166 |
return all_readout_mem
|
| 167 |
|
| 168 |
def read(self, pix_feat: torch.Tensor, query_key: torch.Tensor, selection: torch.Tensor,
|
| 169 |
+
last_mask: torch.Tensor, network: MatAnyone2, uncert_output=None, last_msk_value=None, ti=None,
|
| 170 |
last_pix_feat=None, last_pred_mask=None) -> Dict[int, torch.Tensor]:
|
| 171 |
"""
|
| 172 |
Read from all memory stores and returns a single memory readout tensor for each object
|
{matanyone β matanyone2}/inference/object_info.py
RENAMED
|
File without changes
|
{matanyone β matanyone2}/inference/object_manager.py
RENAMED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
from typing import Union, List, Dict
|
| 2 |
|
| 3 |
import torch
|
| 4 |
-
from
|
| 5 |
|
| 6 |
|
| 7 |
class ObjectManager:
|
|
|
|
| 1 |
from typing import Union, List, Dict
|
| 2 |
|
| 3 |
import torch
|
| 4 |
+
from matanyone2.inference.object_info import ObjectInfo
|
| 5 |
|
| 6 |
|
| 7 |
class ObjectManager:
|
{matanyone/inference β matanyone2/inference/utils}/__init__.py
RENAMED
|
File without changes
|
{matanyone β matanyone2}/inference/utils/args_utils.py
RENAMED
|
File without changes
|
{matanyone/inference/utils β matanyone2/model}/__init__.py
RENAMED
|
File without changes
|
{matanyone β matanyone2}/model/aux_modules.py
RENAMED
|
@@ -6,8 +6,8 @@ from omegaconf import DictConfig
|
|
| 6 |
import torch
|
| 7 |
import torch.nn as nn
|
| 8 |
|
| 9 |
-
from
|
| 10 |
-
from
|
| 11 |
|
| 12 |
|
| 13 |
class LinearPredictor(nn.Module):
|
|
|
|
| 6 |
import torch
|
| 7 |
import torch.nn as nn
|
| 8 |
|
| 9 |
+
from matanyone2.model.group_modules import GConv2d
|
| 10 |
+
from matanyone2.utils.tensor_utils import aggregate
|
| 11 |
|
| 12 |
|
| 13 |
class LinearPredictor(nn.Module):
|
{matanyone β matanyone2}/model/big_modules.py
RENAMED
|
@@ -14,9 +14,10 @@ import torch
|
|
| 14 |
import torch.nn as nn
|
| 15 |
import torch.nn.functional as F
|
| 16 |
|
| 17 |
-
from
|
| 18 |
-
from
|
| 19 |
-
from
|
|
|
|
| 20 |
|
| 21 |
class UncertPred(nn.Module):
|
| 22 |
def __init__(self, model_cfg: DictConfig):
|
|
@@ -330,7 +331,7 @@ class MaskDecoder(nn.Module):
|
|
| 330 |
p4 = self.up_8_4(p8, f4)
|
| 331 |
p2 = self.up_4_2(p4, f2)
|
| 332 |
p1 = self.up_2_1(p2, f1)
|
| 333 |
-
with
|
| 334 |
if seg_pass:
|
| 335 |
if last_mask is not None:
|
| 336 |
res = self.pred_seg(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
|
|
|
|
| 14 |
import torch.nn as nn
|
| 15 |
import torch.nn.functional as F
|
| 16 |
|
| 17 |
+
from matanyone2.model.group_modules import MainToGroupDistributor, GroupFeatureFusionBlock, GConv2d
|
| 18 |
+
from matanyone2.model.utils import resnet
|
| 19 |
+
from matanyone2.model.modules import SensoryDeepUpdater, SensoryUpdater_fullscale, DecoderFeatureProcessor, MaskUpsampleBlock
|
| 20 |
+
from matanyone2.utils.device import safe_autocast
|
| 21 |
|
| 22 |
class UncertPred(nn.Module):
|
| 23 |
def __init__(self, model_cfg: DictConfig):
|
|
|
|
| 331 |
p4 = self.up_8_4(p8, f4)
|
| 332 |
p2 = self.up_4_2(p4, f2)
|
| 333 |
p1 = self.up_2_1(p2, f1)
|
| 334 |
+
with safe_autocast(enabled=False):
|
| 335 |
if seg_pass:
|
| 336 |
if last_mask is not None:
|
| 337 |
res = self.pred_seg(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
|
{matanyone β matanyone2}/model/channel_attn.py
RENAMED
|
File without changes
|
{matanyone β matanyone2}/model/group_modules.py
RENAMED
|
@@ -2,7 +2,7 @@ from typing import Optional
|
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
import torch.nn.functional as F
|
| 5 |
-
from
|
| 6 |
|
| 7 |
def interpolate_groups(g: torch.Tensor, ratio: float, mode: str,
|
| 8 |
align_corners: bool) -> torch.Tensor:
|
|
|
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
import torch.nn.functional as F
|
| 5 |
+
from matanyone2.model.channel_attn import CAResBlock
|
| 6 |
|
| 7 |
def interpolate_groups(g: torch.Tensor, ratio: float, mode: str,
|
| 8 |
align_corners: bool) -> torch.Tensor:
|
matanyone/model/matanyone.py β matanyone2/model/matanyone2.py
RENAMED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from typing import List, Dict, Iterable
|
| 2 |
import logging
|
| 3 |
from omegaconf import DictConfig
|
| 4 |
import torch
|
|
@@ -7,18 +7,21 @@ import torch.nn.functional as F
|
|
| 7 |
from omegaconf import OmegaConf
|
| 8 |
from huggingface_hub import PyTorchModelHubMixin
|
| 9 |
|
| 10 |
-
from
|
| 11 |
-
from
|
| 12 |
-
from
|
| 13 |
-
from
|
| 14 |
-
from
|
| 15 |
-
from
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
log = logging.getLogger()
|
| 18 |
-
class
|
| 19 |
PyTorchModelHubMixin,
|
| 20 |
-
library_name="
|
| 21 |
-
repo_url="https://github.com/pq-yang/
|
| 22 |
coders={
|
| 23 |
DictConfig: (
|
| 24 |
lambda x: OmegaConf.to_container(x),
|
|
@@ -83,6 +86,8 @@ class MatAnyone(nn.Module,
|
|
| 83 |
return uncert_output
|
| 84 |
|
| 85 |
def encode_image(self, image: torch.Tensor, seq_length=None, last_feats=None) -> (Iterable[torch.Tensor], torch.Tensor): # type: ignore
|
|
|
|
|
|
|
| 86 |
image = (image - self.pixel_mean) / self.pixel_std
|
| 87 |
ms_image_feat = self.pixel_encoder(image, seq_length) # f16, f8, f4, f2, f1
|
| 88 |
return ms_image_feat, self.pix_feat_proj(ms_image_feat[0])
|
|
@@ -96,7 +101,7 @@ class MatAnyone(nn.Module,
|
|
| 96 |
*,
|
| 97 |
deep_update: bool = True,
|
| 98 |
chunk_size: int = -1,
|
| 99 |
-
need_weights: bool = False) ->
|
| 100 |
image = (image - self.pixel_mean) / self.pixel_std
|
| 101 |
others = self._get_others(masks)
|
| 102 |
mask_value, new_sensory = self.mask_encoder(image,
|
|
@@ -113,7 +118,7 @@ class MatAnyone(nn.Module,
|
|
| 113 |
final_pix_feat: torch.Tensor,
|
| 114 |
*,
|
| 115 |
need_sk: bool = True,
|
| 116 |
-
need_ek: bool = True) ->
|
| 117 |
key, shrinkage, selection = self.key_proj(final_pix_feat, need_s=need_sk, need_e=need_ek)
|
| 118 |
return key, shrinkage, selection
|
| 119 |
|
|
@@ -124,7 +129,7 @@ class MatAnyone(nn.Module,
|
|
| 124 |
msk_value: torch.Tensor, obj_memory: torch.Tensor, pix_feat: torch.Tensor,
|
| 125 |
sensory: torch.Tensor, last_mask: torch.Tensor,
|
| 126 |
selector: torch.Tensor, uncert_output=None, seg_pass=False,
|
| 127 |
-
last_pix_feat=None, last_pred_mask=None) ->
|
| 128 |
"""
|
| 129 |
query_key : B * CK * H * W
|
| 130 |
query_selection : B * CK * H * W
|
|
@@ -139,7 +144,7 @@ class MatAnyone(nn.Module,
|
|
| 139 |
uncert_mask = uncert_output["mask"] if uncert_output is not None else None
|
| 140 |
|
| 141 |
# read using visual attention
|
| 142 |
-
with
|
| 143 |
affinity = get_affinity(memory_key.float(), memory_shrinkage.float(), query_key.float(),
|
| 144 |
query_selection.float(), uncert_mask=uncert_mask)
|
| 145 |
|
|
@@ -171,7 +176,7 @@ class MatAnyone(nn.Module,
|
|
| 171 |
def read_first_frame_memory(self, pixel_readout,
|
| 172 |
obj_memory: torch.Tensor, pix_feat: torch.Tensor,
|
| 173 |
sensory: torch.Tensor, last_mask: torch.Tensor,
|
| 174 |
-
selector: torch.Tensor, seg_pass=False) ->
|
| 175 |
"""
|
| 176 |
query_key : B * CK * H * W
|
| 177 |
query_selection : B * CK * H * W
|
|
@@ -218,7 +223,7 @@ class MatAnyone(nn.Module,
|
|
| 218 |
*,
|
| 219 |
selector=None,
|
| 220 |
need_weights=False,
|
| 221 |
-
seg_pass=False) ->
|
| 222 |
return self.object_transformer(pixel_readout,
|
| 223 |
obj_memory,
|
| 224 |
selector=selector,
|
|
@@ -237,7 +242,7 @@ class MatAnyone(nn.Module,
|
|
| 237 |
clamp_mat: bool = True,
|
| 238 |
last_mask=None,
|
| 239 |
sigmoid_residual=False,
|
| 240 |
-
seg_mat=False) ->
|
| 241 |
"""
|
| 242 |
multi_scale_features is from the key encoder for skip-connection
|
| 243 |
memory_readout is from working/long-term memory
|
|
|
|
| 1 |
+
from typing import List, Dict, Iterable, Tuple
|
| 2 |
import logging
|
| 3 |
from omegaconf import DictConfig
|
| 4 |
import torch
|
|
|
|
| 7 |
from omegaconf import OmegaConf
|
| 8 |
from huggingface_hub import PyTorchModelHubMixin
|
| 9 |
|
| 10 |
+
from matanyone2.model.big_modules import PixelEncoder, UncertPred, KeyProjection, MaskEncoder, PixelFeatureFuser, MaskDecoder
|
| 11 |
+
from matanyone2.model.aux_modules import AuxComputer
|
| 12 |
+
from matanyone2.model.utils.memory_utils import get_affinity, readout
|
| 13 |
+
from matanyone2.model.transformer.object_transformer import QueryTransformer
|
| 14 |
+
from matanyone2.model.transformer.object_summarizer import ObjectSummarizer
|
| 15 |
+
from matanyone2.utils.tensor_utils import aggregate
|
| 16 |
+
from matanyone2.utils.device import get_default_device, safe_autocast
|
| 17 |
+
|
| 18 |
+
device = get_default_device()
|
| 19 |
|
| 20 |
log = logging.getLogger()
|
| 21 |
+
class MatAnyone2(nn.Module,
|
| 22 |
PyTorchModelHubMixin,
|
| 23 |
+
library_name="matanyone2",
|
| 24 |
+
repo_url="https://github.com/pq-yang/MatAnyone2",
|
| 25 |
coders={
|
| 26 |
DictConfig: (
|
| 27 |
lambda x: OmegaConf.to_container(x),
|
|
|
|
| 86 |
return uncert_output
|
| 87 |
|
| 88 |
def encode_image(self, image: torch.Tensor, seq_length=None, last_feats=None) -> (Iterable[torch.Tensor], torch.Tensor): # type: ignore
|
| 89 |
+
self.pixel_mean = self.pixel_mean.to(device)
|
| 90 |
+
self.pixel_std = self.pixel_std.to(device)
|
| 91 |
image = (image - self.pixel_mean) / self.pixel_std
|
| 92 |
ms_image_feat = self.pixel_encoder(image, seq_length) # f16, f8, f4, f2, f1
|
| 93 |
return ms_image_feat, self.pix_feat_proj(ms_image_feat[0])
|
|
|
|
| 101 |
*,
|
| 102 |
deep_update: bool = True,
|
| 103 |
chunk_size: int = -1,
|
| 104 |
+
need_weights: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 105 |
image = (image - self.pixel_mean) / self.pixel_std
|
| 106 |
others = self._get_others(masks)
|
| 107 |
mask_value, new_sensory = self.mask_encoder(image,
|
|
|
|
| 118 |
final_pix_feat: torch.Tensor,
|
| 119 |
*,
|
| 120 |
need_sk: bool = True,
|
| 121 |
+
need_ek: bool = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 122 |
key, shrinkage, selection = self.key_proj(final_pix_feat, need_s=need_sk, need_e=need_ek)
|
| 123 |
return key, shrinkage, selection
|
| 124 |
|
|
|
|
| 129 |
msk_value: torch.Tensor, obj_memory: torch.Tensor, pix_feat: torch.Tensor,
|
| 130 |
sensory: torch.Tensor, last_mask: torch.Tensor,
|
| 131 |
selector: torch.Tensor, uncert_output=None, seg_pass=False,
|
| 132 |
+
last_pix_feat=None, last_pred_mask=None) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
| 133 |
"""
|
| 134 |
query_key : B * CK * H * W
|
| 135 |
query_selection : B * CK * H * W
|
|
|
|
| 144 |
uncert_mask = uncert_output["mask"] if uncert_output is not None else None
|
| 145 |
|
| 146 |
# read using visual attention
|
| 147 |
+
with safe_autocast(enabled=False):
|
| 148 |
affinity = get_affinity(memory_key.float(), memory_shrinkage.float(), query_key.float(),
|
| 149 |
query_selection.float(), uncert_mask=uncert_mask)
|
| 150 |
|
|
|
|
| 176 |
def read_first_frame_memory(self, pixel_readout,
|
| 177 |
obj_memory: torch.Tensor, pix_feat: torch.Tensor,
|
| 178 |
sensory: torch.Tensor, last_mask: torch.Tensor,
|
| 179 |
+
selector: torch.Tensor, seg_pass=False) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
| 180 |
"""
|
| 181 |
query_key : B * CK * H * W
|
| 182 |
query_selection : B * CK * H * W
|
|
|
|
| 223 |
*,
|
| 224 |
selector=None,
|
| 225 |
need_weights=False,
|
| 226 |
+
seg_pass=False) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
| 227 |
return self.object_transformer(pixel_readout,
|
| 228 |
obj_memory,
|
| 229 |
selector=selector,
|
|
|
|
| 242 |
clamp_mat: bool = True,
|
| 243 |
last_mask=None,
|
| 244 |
sigmoid_residual=False,
|
| 245 |
+
seg_mat=False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 246 |
"""
|
| 247 |
multi_scale_features is from the key encoder for skip-connection
|
| 248 |
memory_readout is from working/long-term memory
|
{matanyone β matanyone2}/model/modules.py
RENAMED
|
@@ -3,7 +3,8 @@ import torch
|
|
| 3 |
import torch.nn as nn
|
| 4 |
import torch.nn.functional as F
|
| 5 |
|
| 6 |
-
from
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
class UpsampleBlock(nn.Module):
|
|
@@ -78,7 +79,7 @@ class SensoryUpdater_fullscale(nn.Module):
|
|
| 78 |
self.g2_conv(downsample_groups(g[3], ratio=1/8)) + \
|
| 79 |
self.g1_conv(downsample_groups(g[4], ratio=1/16))
|
| 80 |
|
| 81 |
-
with
|
| 82 |
g = g.float()
|
| 83 |
h = h.float()
|
| 84 |
values = self.transform(torch.cat([g, h], dim=2))
|
|
@@ -102,7 +103,7 @@ class SensoryUpdater(nn.Module):
|
|
| 102 |
g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \
|
| 103 |
self.g4_conv(downsample_groups(g[2], ratio=1/4))
|
| 104 |
|
| 105 |
-
with
|
| 106 |
g = g.float()
|
| 107 |
h = h.float()
|
| 108 |
values = self.transform(torch.cat([g, h], dim=2))
|
|
@@ -119,7 +120,7 @@ class SensoryDeepUpdater(nn.Module):
|
|
| 119 |
nn.init.xavier_normal_(self.transform.weight)
|
| 120 |
|
| 121 |
def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
|
| 122 |
-
with
|
| 123 |
g = g.float()
|
| 124 |
h = h.float()
|
| 125 |
values = self.transform(torch.cat([g, h], dim=2))
|
|
@@ -146,4 +147,4 @@ class ResBlock(nn.Module):
|
|
| 146 |
|
| 147 |
g = self.downsample(g)
|
| 148 |
|
| 149 |
-
return out_g + g
|
|
|
|
| 3 |
import torch.nn as nn
|
| 4 |
import torch.nn.functional as F
|
| 5 |
|
| 6 |
+
from matanyone2.model.group_modules import MainToGroupDistributor, GroupResBlock, upsample_groups, GConv2d, downsample_groups
|
| 7 |
+
from matanyone2.utils.device import safe_autocast
|
| 8 |
|
| 9 |
|
| 10 |
class UpsampleBlock(nn.Module):
|
|
|
|
| 79 |
self.g2_conv(downsample_groups(g[3], ratio=1/8)) + \
|
| 80 |
self.g1_conv(downsample_groups(g[4], ratio=1/16))
|
| 81 |
|
| 82 |
+
with safe_autocast(enabled=False):
|
| 83 |
g = g.float()
|
| 84 |
h = h.float()
|
| 85 |
values = self.transform(torch.cat([g, h], dim=2))
|
|
|
|
| 103 |
g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \
|
| 104 |
self.g4_conv(downsample_groups(g[2], ratio=1/4))
|
| 105 |
|
| 106 |
+
with safe_autocast(enabled=False):
|
| 107 |
g = g.float()
|
| 108 |
h = h.float()
|
| 109 |
values = self.transform(torch.cat([g, h], dim=2))
|
|
|
|
| 120 |
nn.init.xavier_normal_(self.transform.weight)
|
| 121 |
|
| 122 |
def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
|
| 123 |
+
with safe_autocast(enabled=False):
|
| 124 |
g = g.float()
|
| 125 |
h = h.float()
|
| 126 |
values = self.transform(torch.cat([g, h], dim=2))
|
|
|
|
| 147 |
|
| 148 |
g = self.downsample(g)
|
| 149 |
|
| 150 |
+
return out_g + g
|
{matanyone/model β matanyone2/model/transformer}/__init__.py
RENAMED
|
File without changes
|
{matanyone β matanyone2}/model/transformer/object_summarizer.py
RENAMED
|
@@ -4,7 +4,8 @@ from omegaconf import DictConfig
|
|
| 4 |
import torch
|
| 5 |
import torch.nn as nn
|
| 6 |
import torch.nn.functional as F
|
| 7 |
-
from
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
# @torch.jit.script
|
|
@@ -75,7 +76,7 @@ class ObjectSummarizer(nn.Module):
|
|
| 75 |
pe = self.pos_enc(value)
|
| 76 |
value = value + pe
|
| 77 |
|
| 78 |
-
with
|
| 79 |
value = value.float()
|
| 80 |
feature = self.feature_pred(value)
|
| 81 |
logits = self.weights_pred(value)
|
|
@@ -86,4 +87,4 @@ class ObjectSummarizer(nn.Module):
|
|
| 86 |
if need_weights:
|
| 87 |
return summaries, logits
|
| 88 |
else:
|
| 89 |
-
return summaries, None
|
|
|
|
| 4 |
import torch
|
| 5 |
import torch.nn as nn
|
| 6 |
import torch.nn.functional as F
|
| 7 |
+
from matanyone2.model.transformer.positional_encoding import PositionalEncoding
|
| 8 |
+
from matanyone2.utils.device import safe_autocast
|
| 9 |
|
| 10 |
|
| 11 |
# @torch.jit.script
|
|
|
|
| 76 |
pe = self.pos_enc(value)
|
| 77 |
value = value + pe
|
| 78 |
|
| 79 |
+
with safe_autocast(enabled=False): # autocast disabled intentionally
|
| 80 |
value = value.float()
|
| 81 |
feature = self.feature_pred(value)
|
| 82 |
logits = self.weights_pred(value)
|
|
|
|
| 87 |
if need_weights:
|
| 88 |
return summaries, logits
|
| 89 |
else:
|
| 90 |
+
return summaries, None
|
{matanyone β matanyone2}/model/transformer/object_transformer.py
RENAMED
|
@@ -3,10 +3,10 @@ from omegaconf import DictConfig
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import torch.nn as nn
|
| 6 |
-
from
|
| 7 |
-
from
|
| 8 |
-
from
|
| 9 |
-
from
|
| 10 |
|
| 11 |
|
| 12 |
class QueryTransformerBlock(nn.Module):
|
|
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import torch.nn as nn
|
| 6 |
+
from matanyone2.model.group_modules import GConv2d
|
| 7 |
+
from matanyone2.utils.tensor_utils import aggregate
|
| 8 |
+
from matanyone2.model.transformer.positional_encoding import PositionalEncoding
|
| 9 |
+
from matanyone2.model.transformer.transformer_layers import CrossAttention, SelfAttention, FFN, PixelFFN
|
| 10 |
|
| 11 |
|
| 12 |
class QueryTransformerBlock(nn.Module):
|
{matanyone β matanyone2}/model/transformer/positional_encoding.py
RENAMED
|
@@ -7,6 +7,7 @@ import math
|
|
| 7 |
import numpy as np
|
| 8 |
import torch
|
| 9 |
from torch import nn
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
def get_emb(sin_inp: torch.Tensor) -> torch.Tensor:
|
|
@@ -98,8 +99,9 @@ class PositionalEncoding(nn.Module):
|
|
| 98 |
|
| 99 |
|
| 100 |
if __name__ == '__main__':
|
| 101 |
-
|
| 102 |
-
|
|
|
|
| 103 |
output = pe(input)
|
| 104 |
# print(output)
|
| 105 |
print(output[0, :, 0, 0])
|
|
|
|
| 7 |
import numpy as np
|
| 8 |
import torch
|
| 9 |
from torch import nn
|
| 10 |
+
from matanyone2.utils.device import get_default_device
|
| 11 |
|
| 12 |
|
| 13 |
def get_emb(sin_inp: torch.Tensor) -> torch.Tensor:
|
|
|
|
| 99 |
|
| 100 |
|
| 101 |
if __name__ == '__main__':
|
| 102 |
+
device = get_default_device()
|
| 103 |
+
pe = PositionalEncoding(8).to(device)
|
| 104 |
+
input = torch.ones((1, 8, 8, 8), device=device)
|
| 105 |
output = pe(input)
|
| 106 |
# print(output)
|
| 107 |
print(output[0, :, 0, 0])
|
{matanyone β matanyone2}/model/transformer/transformer_layers.py
RENAMED
|
@@ -6,7 +6,7 @@ import torch
|
|
| 6 |
from torch import Tensor
|
| 7 |
import torch.nn as nn
|
| 8 |
import torch.nn.functional as F
|
| 9 |
-
from
|
| 10 |
|
| 11 |
|
| 12 |
class SelfAttention(nn.Module):
|
|
|
|
| 6 |
from torch import Tensor
|
| 7 |
import torch.nn as nn
|
| 8 |
import torch.nn.functional as F
|
| 9 |
+
from matanyone2.model.channel_attn import CAResBlock
|
| 10 |
|
| 11 |
|
| 12 |
class SelfAttention(nn.Module):
|
{matanyone/model/transformer β matanyone2/model/utils}/__init__.py
RENAMED
|
File without changes
|
{matanyone β matanyone2}/model/utils/memory_utils.py
RENAMED
|
File without changes
|
{matanyone β matanyone2}/model/utils/parameter_groups.py
RENAMED
|
File without changes
|
{matanyone β matanyone2}/model/utils/resnet.py
RENAMED
|
File without changes
|
{matanyone/model β matanyone2}/utils/__init__.py
RENAMED
|
File without changes
|
matanyone2/utils/device.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import contextlib
import functools

import torch
|
| 3 |
+
|
| 4 |
+
def get_default_device():
    """Pick the best available torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    mps_ok = torch.backends.mps.is_built() and torch.backends.mps.is_available()
    return torch.device("mps" if mps_ok else "cpu")
| 11 |
+
|
| 12 |
+
def safe_autocast_decorator(enabled=True):
    """Build a decorator that runs the wrapped function under torch.amp.autocast.

    Autocast is entered only on backends that support it ("cuda" / "cpu");
    on MPS and any other backend the function is called unchanged.

    Args:
        enabled: forwarded to ``torch.amp.autocast`` when autocast is used.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            dev_type = get_default_device().type
            if dev_type not in ("cuda", "cpu"):
                # torch.amp.autocast does not support this backend (e.g. MPS)
                return func(*args, **kwargs)
            with torch.amp.autocast(device_type=dev_type, enabled=enabled):
                return func(*args, **kwargs)
        return wrapper
    return decorator
| 24 |
+
|
| 25 |
+
@contextlib.contextmanager
def safe_autocast(enabled=True):
    """Context manager around torch.amp.autocast that is safe on all backends.

    On CUDA and CPU the body executes inside ``torch.amp.autocast``; on MPS
    and other backends, where autocast is unsupported, the body runs as-is.

    Args:
        enabled: forwarded to ``torch.amp.autocast`` when autocast is used.
    """
    # NOTE: `import contextlib` previously sat here, mid-file; it now lives in
    # the top-of-file import block per PEP 8.
    device = get_default_device()
    if device.type in ("cuda", "cpu"):
        with torch.amp.autocast(device_type=device.type, enabled=enabled):
            yield
    else:
        yield  # MPS or other unsupported backends skip autocast
|
{matanyone β matanyone2}/utils/get_default_model.py
RENAMED
|
@@ -5,9 +5,9 @@ from omegaconf import open_dict
|
|
| 5 |
from hydra import compose, initialize
|
| 6 |
|
| 7 |
import torch
|
| 8 |
-
from
|
| 9 |
|
| 10 |
-
def
|
| 11 |
initialize(version_base='1.3.2', config_path="../config", job_name="eval_our_config")
|
| 12 |
cfg = compose(config_name="eval_matanyone_config")
|
| 13 |
|
|
@@ -16,12 +16,12 @@ def get_matanyone_model(ckpt_path, device=None) -> MatAnyone:
|
|
| 16 |
|
| 17 |
# Load the network weights
|
| 18 |
if device is not None:
|
| 19 |
-
|
| 20 |
model_weights = torch.load(cfg.weights, map_location=device)
|
| 21 |
else: # if device is not specified, `.cuda()` by default
|
| 22 |
-
|
| 23 |
model_weights = torch.load(cfg.weights)
|
| 24 |
|
| 25 |
-
|
| 26 |
|
| 27 |
-
return
|
|
|
|
| 5 |
from hydra import compose, initialize
|
| 6 |
|
| 7 |
import torch
|
| 8 |
+
from matanyone2.model.matanyone2 import MatAnyone2
|
| 9 |
|
| 10 |
+
def get_matanyone2_model(ckpt_path, device=None) -> MatAnyone2:
|
| 11 |
initialize(version_base='1.3.2', config_path="../config", job_name="eval_our_config")
|
| 12 |
cfg = compose(config_name="eval_matanyone_config")
|
| 13 |
|
|
|
|
| 16 |
|
| 17 |
# Load the network weights
|
| 18 |
if device is not None:
|
| 19 |
+
matanyone2 = MatAnyone2(cfg, single_object=True).to(device).eval()
|
| 20 |
model_weights = torch.load(cfg.weights, map_location=device)
|
| 21 |
else: # if device is not specified, `.cuda()` by default
|
| 22 |
+
matanyone2 = MatAnyone2(cfg, single_object=True).cuda().eval()
|
| 23 |
model_weights = torch.load(cfg.weights)
|
| 24 |
|
| 25 |
+
matanyone2.load_weights(model_weights)
|
| 26 |
|
| 27 |
+
return matanyone2
|
matanyone2/utils/inference_utils.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import random
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torchvision
|
| 8 |
+
|
| 9 |
+
# File extensions treated as still images (lower- and upper-case variants).
# NOTE(review): not referenced by the functions visible in this module —
# presumably used by callers elsewhere; confirm before removing.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG')
# File extensions treated as video containers (lower- and upper-case variants).
VIDEO_EXTENSIONS = ('.mp4', '.mov', '.avi', '.MP4', '.MOV', '.AVI')
|
| 11 |
+
|
| 12 |
+
def read_frame_from_videos(frame_root):
    """Load RGB frames from a video file or a directory of per-frame images.

    Args:
        frame_root: path to a video file (one of VIDEO_EXTENSIONS) or to a
            directory containing the frames as individual image files.

    Returns:
        Tuple ``(frames, fps, length, video_name)``:
        - frames: float tensor in TCHW layout, RGB channel order
        - fps: source frame rate (24 is assumed for image directories)
        - length: number of frames
        - video_name: basename (video extension stripped for video files)

    Raises:
        ValueError: if an image in the frame directory cannot be decoded.
    """
    if frame_root.endswith(VIDEO_EXTENSIONS):  # Video file path
        # All listed video extensions are 4 chars (".mp4" etc.), so [:-4] strips them.
        video_name = os.path.basename(frame_root)[:-4]
        frames, _, info = torchvision.io.read_video(filename=frame_root, pts_unit='sec', output_format='TCHW')  # RGB
        fps = info['video_fps']
    else:
        video_name = os.path.basename(frame_root)
        frames = []
        fr_lst = sorted(os.listdir(frame_root))
        for fr in fr_lst:
            frame_path = os.path.join(frame_root, fr)
            frame = cv2.imread(frame_path)
            # cv2.imread returns None for unreadable/non-image files; fail with
            # a clear message instead of an opaque TypeError on indexing.
            if frame is None:
                raise ValueError(f'Could not read frame image: {frame_path}')
            frames.append(frame[..., [2, 1, 0]])  # BGR -> RGB, HWC
        fps = 24  # default
        frames = torch.Tensor(np.array(frames)).permute(0, 3, 1, 2).contiguous()  # TCHW

    length = frames.shape[0]

    return frames, fps, length, video_name
|
| 30 |
+
|
| 31 |
+
def get_video_paths(input_root):
    """Recursively collect every video file under input_root, sorted by path."""
    found = [
        os.path.join(dirpath, name)
        for dirpath, _, names in os.walk(input_root)
        for name in names
        if name.lower().endswith(VIDEO_EXTENSIONS)
    ]
    return sorted(found)
|
| 38 |
+
|
| 39 |
+
def str_to_list(value):
    """Parse a comma-separated string of integers into a list of ints."""
    return [int(item) for item in value.split(',')]
|
| 41 |
+
|
| 42 |
+
def gen_dilate(alpha, min_kernel_size, max_kernel_size):
    """Dilate the non-background region of an alpha matte.

    An elliptical kernel of random size in [min_kernel_size, max_kernel_size]
    is applied to the mask of pixels where alpha != 0 (foreground + unknown);
    the dilated mask is returned scaled to {0, 255} as float32.
    """
    size = random.randint(min_kernel_size, max_kernel_size)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))
    fg_and_unknown = np.not_equal(alpha, 0).astype(np.float32)
    dilated = cv2.dilate(fg_and_unknown, kernel, iterations=1) * 255
    return dilated.astype(np.float32)
|
| 48 |
+
|
| 49 |
+
def gen_erosion(alpha, min_kernel_size, max_kernel_size):
    """Erode the definite-foreground region of an alpha matte.

    An elliptical kernel of random size in [min_kernel_size, max_kernel_size]
    is applied to the mask of pixels where alpha == 255 (certain foreground);
    the eroded mask is returned scaled to {0, 255} as float32.
    """
    size = random.randint(min_kernel_size, max_kernel_size)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))
    fg = np.equal(alpha, 255).astype(np.float32)
    eroded = cv2.erode(fg, kernel, iterations=1) * 255
    return eroded.astype(np.float32)
|
{matanyone β matanyone2}/utils/tensor_utils.py
RENAMED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
from typing import List, Iterable
|
| 2 |
import torch
|
| 3 |
import torch.nn.functional as F
|
| 4 |
-
|
| 5 |
|
| 6 |
# STM
|
| 7 |
def pad_divide_by(in_img: torch.Tensor, d: int) -> (torch.Tensor, Iterable[int]):
|
|
@@ -45,7 +45,7 @@ def unpad(img: torch.Tensor, pad: Iterable[int]) -> torch.Tensor:
|
|
| 45 |
|
| 46 |
# @torch.jit.script
|
| 47 |
def aggregate(prob: torch.Tensor, dim: int) -> torch.Tensor:
|
| 48 |
-
with
|
| 49 |
prob = prob.float()
|
| 50 |
new_prob = torch.cat([torch.prod(1 - prob, dim=dim, keepdim=True), prob],
|
| 51 |
dim).clamp(1e-7, 1 - 1e-7)
|
|
@@ -59,4 +59,4 @@ def cls_to_one_hot(cls_gt: torch.Tensor, num_objects: int) -> torch.Tensor:
|
|
| 59 |
# cls_gt: B*1*H*W
|
| 60 |
B, _, H, W = cls_gt.shape
|
| 61 |
one_hot = torch.zeros(B, num_objects + 1, H, W, device=cls_gt.device).scatter_(1, cls_gt, 1)
|
| 62 |
-
return one_hot
|
|
|
|
| 1 |
from typing import List, Iterable
|
| 2 |
import torch
|
| 3 |
import torch.nn.functional as F
|
| 4 |
+
from matanyone2.utils.device import safe_autocast
|
| 5 |
|
| 6 |
# STM
|
| 7 |
def pad_divide_by(in_img: torch.Tensor, d: int) -> (torch.Tensor, Iterable[int]):
|
|
|
|
| 45 |
|
| 46 |
# @torch.jit.script
|
| 47 |
def aggregate(prob: torch.Tensor, dim: int) -> torch.Tensor:
|
| 48 |
+
with safe_autocast(enabled=False):
|
| 49 |
prob = prob.float()
|
| 50 |
new_prob = torch.cat([torch.prod(1 - prob, dim=dim, keepdim=True), prob],
|
| 51 |
dim).clamp(1e-7, 1 - 1e-7)
|
|
|
|
| 59 |
# cls_gt: B*1*H*W
|
| 60 |
B, _, H, W = cls_gt.shape
|
| 61 |
one_hot = torch.zeros(B, num_objects + 1, H, W, device=cls_gt.device).scatter_(1, cls_gt, 1)
|
| 62 |
+
return one_hot
|