PeiqingYang commited on
Commit
77b7c2e
Β·
1 Parent(s): d6e7745

integrate MatAnyone 1 & 2

Browse files
Files changed (40) hide show
  1. README.md +1 -1
  2. hugging_face/app.py +159 -49
  3. hugging_face/{matanyone_wrapper.py β†’ matanyone2_wrapper.py} +10 -6
  4. matanyone/utils/__init__.py +0 -0
  5. matanyone2/__init__.py +2 -0
  6. {matanyone β†’ matanyone2/config}/__init__.py +0 -0
  7. {matanyone β†’ matanyone2}/config/eval_matanyone_config.yaml +1 -1
  8. {matanyone β†’ matanyone2}/config/hydra/job_logging/custom-no-rank.yaml +0 -0
  9. {matanyone β†’ matanyone2}/config/hydra/job_logging/custom.yaml +0 -0
  10. {matanyone β†’ matanyone2}/config/model/base.yaml +0 -0
  11. {matanyone/config β†’ matanyone2/inference}/__init__.py +0 -0
  12. {matanyone β†’ matanyone2}/inference/image_feature_store.py +2 -2
  13. {matanyone β†’ matanyone2}/inference/inference_core.py +155 -11
  14. {matanyone β†’ matanyone2}/inference/kv_memory_store.py +0 -0
  15. {matanyone β†’ matanyone2}/inference/memory_manager.py +6 -6
  16. {matanyone β†’ matanyone2}/inference/object_info.py +0 -0
  17. {matanyone β†’ matanyone2}/inference/object_manager.py +1 -1
  18. {matanyone/inference β†’ matanyone2/inference/utils}/__init__.py +0 -0
  19. {matanyone β†’ matanyone2}/inference/utils/args_utils.py +0 -0
  20. {matanyone/inference/utils β†’ matanyone2/model}/__init__.py +0 -0
  21. {matanyone β†’ matanyone2}/model/aux_modules.py +2 -2
  22. {matanyone β†’ matanyone2}/model/big_modules.py +5 -4
  23. {matanyone β†’ matanyone2}/model/channel_attn.py +0 -0
  24. {matanyone β†’ matanyone2}/model/group_modules.py +1 -1
  25. matanyone/model/matanyone.py β†’ matanyone2/model/matanyone2.py +22 -17
  26. {matanyone β†’ matanyone2}/model/modules.py +6 -5
  27. {matanyone/model β†’ matanyone2/model/transformer}/__init__.py +0 -0
  28. {matanyone β†’ matanyone2}/model/transformer/object_summarizer.py +4 -3
  29. {matanyone β†’ matanyone2}/model/transformer/object_transformer.py +4 -4
  30. {matanyone β†’ matanyone2}/model/transformer/positional_encoding.py +4 -2
  31. {matanyone β†’ matanyone2}/model/transformer/transformer_layers.py +1 -1
  32. {matanyone/model/transformer β†’ matanyone2/model/utils}/__init__.py +0 -0
  33. {matanyone β†’ matanyone2}/model/utils/memory_utils.py +0 -0
  34. {matanyone β†’ matanyone2}/model/utils/parameter_groups.py +0 -0
  35. {matanyone β†’ matanyone2}/model/utils/resnet.py +0 -0
  36. {matanyone/model β†’ matanyone2}/utils/__init__.py +0 -0
  37. matanyone2/utils/device.py +33 -0
  38. {matanyone β†’ matanyone2}/utils/get_default_model.py +6 -6
  39. matanyone2/utils/inference_utils.py +54 -0
  40. {matanyone β†’ matanyone2}/utils/tensor_utils.py +3 -3
README.md CHANGED
@@ -8,7 +8,7 @@ sdk_version: 5.16.0
8
  app_file: hugging_face/app.py
9
  pinned: false
10
  license: other
11
- short_description: Gradio demo for MatAnyone
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
8
  app_file: hugging_face/app.py
9
  pinned: false
10
  license: other
11
+ short_description: Gradio demo for MatAnyone 1 & 2
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
hugging_face/app.py CHANGED
@@ -21,9 +21,13 @@ from tools.interact_tools import SamControler
21
  from tools.misc import get_device
22
  from tools.download_util import load_file_from_url
23
 
24
- from matanyone_wrapper import matanyone
25
- from matanyone.utils.get_default_model import get_matanyone_model
26
- from matanyone.inference.inference_core import InferenceCore
 
 
 
 
27
 
28
  def parse_augment():
29
  parser = argparse.ArgumentParser()
@@ -121,7 +125,6 @@ def get_frames_from_video(video_input, video_state):
121
  except Exception as e:
122
  print(f"Audio extraction error: {str(e)}")
123
  audio_path = "" # Set to "" if extraction fails
124
- # print(f'audio_path: {audio_path}')
125
 
126
  # extract frames
127
  try:
@@ -140,8 +143,8 @@ def get_frames_from_video(video_input, video_state):
140
  print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
141
  image_size = (frames[0].shape[0],frames[0].shape[1])
142
 
143
- # resize if resolution too big
144
- if image_size[0]>=1280 and image_size[0]>=1280:
145
  scale = 1080 / min(image_size)
146
  new_w = int(image_size[1] * scale)
147
  new_h = int(image_size[0] * scale)
@@ -165,8 +168,7 @@ def get_frames_from_video(video_input, video_state):
165
  video_info = "Video Name: {},\nFPS: {},\nTotal Frames: {},\nImage Size:{}".format(video_state["video_name"], round(video_state["fps"], 0), len(frames), image_size)
166
  model.samcontroler.sam_controler.reset_image()
167
  model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
168
- return video_state, video_info, video_state["origin_images"][0], \
169
- gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=False, maximum=len(frames), value=len(frames)), \
170
  gr.update(visible=True), gr.update(visible=True), \
171
  gr.update(visible=True), gr.update(visible=True),\
172
  gr.update(visible=True), gr.update(visible=True), \
@@ -267,8 +269,18 @@ def show_mask(video_state, interactive_state, mask_dropdown):
267
  return select_frame
268
 
269
  # image matting
270
- def image_matting(video_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, refine_iter):
271
- matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)
 
 
 
 
 
 
 
 
 
 
272
  if interactive_state["track_end_number"]:
273
  following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
274
  else:
@@ -289,14 +301,25 @@ def image_matting(video_state, interactive_state, mask_dropdown, erode_kernel_si
289
  # operation error
290
  if len(np.unique(template_mask))==1:
291
  template_mask[0][0]=1
292
- foreground, alpha = matanyone(matanyone_processor, following_frames, template_mask*255, r_erode=erode_kernel_size, r_dilate=dilate_kernel_size, n_warmup=refine_iter)
293
  foreground_output = Image.fromarray(foreground[-1])
294
  alpha_output = Image.fromarray(alpha[-1][:,:,0])
 
295
  return foreground_output, alpha_output
296
 
297
  # video matting
298
- def video_matting(video_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size):
299
- matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)
 
 
 
 
 
 
 
 
 
 
300
  if interactive_state["track_end_number"]:
301
  following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
302
  else:
@@ -320,11 +343,11 @@ def video_matting(video_state, interactive_state, mask_dropdown, erode_kernel_si
320
  # operation error
321
  if len(np.unique(template_mask))==1:
322
  template_mask[0][0]=1
323
- foreground, alpha = matanyone(matanyone_processor, following_frames, template_mask*255, r_erode=erode_kernel_size, r_dilate=dilate_kernel_size)
324
 
325
  foreground_output = generate_video_from_frames(foreground, output_path="./results/{}_fg.mp4".format(video_state["video_name"]), fps=fps, audio_path=audio_path) # import video_input to name the output video
326
  alpha_output = generate_video_from_frames(alpha, output_path="./results/{}_alpha.mp4".format(video_state["video_name"]), fps=fps, gray2rgb=True, audio_path=audio_path) # import video_input to name the output video
327
-
328
  return foreground_output, alpha_output
329
 
330
 
@@ -415,47 +438,113 @@ sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[args.sam_model_type]
415
  # initialize sams
416
  model = MaskGenerator(sam_checkpoint, args)
417
 
418
- # initialize matanyone
419
- # load from ckpt
420
- # pretrain_model_url = "https://github.com/pq-yang/MatAnyone/releases/download/v1.0.0"
421
- # ckpt_path = load_file_from_url(os.path.join(pretrain_model_url, 'matanyone.pth'), checkpoint_folder)
422
- # matanyone_model = get_matanyone_model(ckpt_path, args.device)
423
- # load from Hugging Face
424
- from matanyone.model.matanyone import MatAnyone
425
- matanyone_model = MatAnyone.from_pretrained("PeiqingYang/MatAnyone")
 
 
 
 
 
 
 
 
 
 
426
 
427
- matanyone_model = matanyone_model.to(args.device).eval()
428
- matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
429
 
430
  # download test samples
431
- media_url = "https://github.com/pq-yang/MatAnyone/releases/download/media/"
432
  test_sample_path = os.path.join('/home/user/app/hugging_face/', "test_sample/")
433
- load_file_from_url(os.path.join(media_url, 'test-sample0-720p.mp4'), test_sample_path)
434
- load_file_from_url(os.path.join(media_url, 'test-sample1-720p.mp4'), test_sample_path)
435
- load_file_from_url(os.path.join(media_url, 'test-sample2-720p.mp4'), test_sample_path)
436
- load_file_from_url(os.path.join(media_url, 'test-sample3-720p.mp4'), test_sample_path)
437
- load_file_from_url(os.path.join(media_url, 'test-sample0.jpg'), test_sample_path)
438
- load_file_from_url(os.path.join(media_url, 'test-sample1.jpg'), test_sample_path)
 
 
 
 
439
 
440
  # download assets
441
  assets_path = os.path.join('/home/user/app/hugging_face/', "assets/")
442
- load_file_from_url(os.path.join(media_url, 'tutorial_single_target.mp4'), assets_path)
443
- load_file_from_url(os.path.join(media_url, 'tutorial_multi_targets.mp4'), assets_path)
444
 
445
  # documents
446
- title = r"""<div class="multi-layer" align="center"><span>MatAnyone</span></div>
447
  """
448
  description = r"""
449
- <b>Official Gradio demo</b> for <a href='https://github.com/pq-yang/MatAnyone' target='_blank'><b>MatAnyone: Stable Video Matting with Consistent Memory Propagation</b></a>.<br>
450
- πŸ”₯ MatAnyone is a practical human video matting framework supporting target assignment 🎯.<br>
451
- πŸŽͺ Try to drop your video/image, assign the target masks with a few clicks, and get the the matting results 🀑!<br>
 
452
 
453
  *Note: Due to the online GPU memory constraints, any input with too big resolution will be resized to 1080p.<br>*
454
- πŸš€ <b> If you encounter any issue (e.g., frozen video output) or wish to run on higher resolution inputs, please consider <u>duplicating this space</u> or
455
- <u>launching the <a href='https://github.com/pq-yang/MatAnyone?tab=readme-ov-file#-interactive-demo' target='_blank'>demo</a> locally</u> following the GitHub instructions.</b>
456
  """
457
  article = r"""<h3>
458
- <b>If MatAnyone is helpful, please help to 🌟 the <a href='https://github.com/pq-yang/MatAnyone' target='_blank'>Github Repo</a>. Thanks!</b></h3>
459
 
460
  ---
461
 
@@ -463,6 +552,13 @@ article = r"""<h3>
463
  <br>
464
  If our work is useful for your research, please consider citing:
465
  ```bibtex
 
 
 
 
 
 
 
466
  @InProceedings{yang2025matanyone,
467
  title = {{MatAnyone}: Stable Video Matting with Consistent Memory Propagation},
468
  author = {Yang, Peiqing and Zhou, Shangchen and Zhao, Jixin and Tao, Qingyi and Loy, Chen Change},
@@ -558,10 +654,10 @@ with gr.Blocks(theme=gr.themes.Monochrome(), css=my_custom_css) as demo:
558
  <div class="title-container">
559
  <h1 class="title is-2 publication-title"
560
  style="font-size:50px; font-family: 'Sarpanch', serif;
561
- background: linear-gradient(to right, #d231d8, #2dc464);
562
  display: inline-block; -webkit-background-clip: text;
563
  -webkit-text-fill-color: transparent;">
564
- MatAnyone
565
  </h1>
566
  </div>
567
  ''')
@@ -614,7 +710,14 @@ with gr.Blocks(theme=gr.themes.Monochrome(), css=my_custom_css) as demo:
614
 
615
  with gr.Group(elem_classes="gr-monochrome-group", visible=True):
616
  with gr.Row():
617
- with gr.Accordion('MatAnyone Settings (click to expand)', open=False):
 
 
 
 
 
 
 
618
  with gr.Row():
619
  erode_kernel_size = gr.Slider(label='Erode Kernel Size',
620
  minimum=0,
@@ -722,7 +825,7 @@ with gr.Blocks(theme=gr.themes.Monochrome(), css=my_custom_css) as demo:
722
  # video matting
723
  matting_button.click(
724
  fn=video_matting,
725
- inputs=[video_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size],
726
  outputs=[foreground_video_output, alpha_video_output]
727
  )
728
 
@@ -775,7 +878,7 @@ with gr.Blocks(theme=gr.themes.Monochrome(), css=my_custom_css) as demo:
775
  gr.Markdown("---")
776
  gr.Markdown("## Examples")
777
  gr.Examples(
778
- examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["test-sample0-720p.mp4", "test-sample1-720p.mp4", "test-sample2-720p.mp4", "test-sample3-720p.mp4"]],
779
  inputs=[video_input],
780
  )
781
 
@@ -811,7 +914,14 @@ with gr.Blocks(theme=gr.themes.Monochrome(), css=my_custom_css) as demo:
811
 
812
  with gr.Group(elem_classes="gr-monochrome-group", visible=True):
813
  with gr.Row():
814
- with gr.Accordion('MatAnyone Settings (click to expand)', open=False):
 
 
 
 
 
 
 
815
  with gr.Row():
816
  erode_kernel_size = gr.Slider(label='Erode Kernel Size',
817
  minimum=0,
@@ -918,7 +1028,7 @@ with gr.Blocks(theme=gr.themes.Monochrome(), css=my_custom_css) as demo:
918
  # image matting
919
  matting_button.click(
920
  fn=image_matting,
921
- inputs=[image_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, image_selection_slider],
922
  outputs=[foreground_image_output, alpha_image_output]
923
  )
924
 
@@ -971,7 +1081,7 @@ with gr.Blocks(theme=gr.themes.Monochrome(), css=my_custom_css) as demo:
971
  gr.Markdown("---")
972
  gr.Markdown("## Examples")
973
  gr.Examples(
974
- examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["test-sample0.jpg", "test-sample1.jpg"]],
975
  inputs=[image_input],
976
  )
977
 
 
21
  from tools.misc import get_device
22
  from tools.download_util import load_file_from_url
23
 
24
+ from matanyone2_wrapper import matanyone2
25
+ from matanyone2.utils.get_default_model import get_matanyone2_model
26
+ from matanyone2.inference.inference_core import InferenceCore
27
+ from hydra.core.global_hydra import GlobalHydra
28
+
29
+ import warnings
30
+ warnings.filterwarnings("ignore")
31
 
32
  def parse_augment():
33
  parser = argparse.ArgumentParser()
 
125
  except Exception as e:
126
  print(f"Audio extraction error: {str(e)}")
127
  audio_path = "" # Set to "" if extraction fails
 
128
 
129
  # extract frames
130
  try:
 
143
  print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
144
  image_size = (frames[0].shape[0],frames[0].shape[1])
145
 
146
+ # [remove for local demo] resize if resolution too big
147
+ if image_size[0]>=1080 and image_size[0]>=1080:
148
  scale = 1080 / min(image_size)
149
  new_w = int(image_size[1] * scale)
150
  new_h = int(image_size[0] * scale)
 
168
  video_info = "Video Name: {},\nFPS: {},\nTotal Frames: {},\nImage Size:{}".format(video_state["video_name"], round(video_state["fps"], 0), len(frames), image_size)
169
  model.samcontroler.sam_controler.reset_image()
170
  model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
171
+ return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=False, maximum=len(frames), value=len(frames)), \
 
172
  gr.update(visible=True), gr.update(visible=True), \
173
  gr.update(visible=True), gr.update(visible=True),\
174
  gr.update(visible=True), gr.update(visible=True), \
 
269
  return select_frame
270
 
271
  # image matting
272
+ def image_matting(video_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, refine_iter, model_selection):
273
+ # Load model if not already loaded
274
+ try:
275
+ selected_model = load_model(model_selection)
276
+ except (FileNotFoundError, ValueError) as e:
277
+ # Fallback to first available model
278
+ if available_models:
279
+ print(f"Warning: {str(e)}. Using {available_models[0]} instead.")
280
+ selected_model = load_model(available_models[0])
281
+ else:
282
+ raise ValueError("No models are available! Please check if the model files exist.")
283
+ matanyone_processor = InferenceCore(selected_model, cfg=selected_model.cfg)
284
  if interactive_state["track_end_number"]:
285
  following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
286
  else:
 
301
  # operation error
302
  if len(np.unique(template_mask))==1:
303
  template_mask[0][0]=1
304
+ foreground, alpha = matanyone2(matanyone_processor, following_frames, template_mask*255, r_erode=erode_kernel_size, r_dilate=dilate_kernel_size, n_warmup=refine_iter)
305
  foreground_output = Image.fromarray(foreground[-1])
306
  alpha_output = Image.fromarray(alpha[-1][:,:,0])
307
+
308
  return foreground_output, alpha_output
309
 
310
  # video matting
311
+ def video_matting(video_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, model_selection):
312
+ # Load model if not already loaded
313
+ try:
314
+ selected_model = load_model(model_selection)
315
+ except (FileNotFoundError, ValueError) as e:
316
+ # Fallback to first available model
317
+ if available_models:
318
+ print(f"Warning: {str(e)}. Using {available_models[0]} instead.")
319
+ selected_model = load_model(available_models[0])
320
+ else:
321
+ raise ValueError("No models are available! Please check if the model files exist.")
322
+ matanyone_processor = InferenceCore(selected_model, cfg=selected_model.cfg)
323
  if interactive_state["track_end_number"]:
324
  following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
325
  else:
 
343
  # operation error
344
  if len(np.unique(template_mask))==1:
345
  template_mask[0][0]=1
346
+ foreground, alpha = matanyone2(matanyone_processor, following_frames, template_mask*255, r_erode=erode_kernel_size, r_dilate=dilate_kernel_size)
347
 
348
  foreground_output = generate_video_from_frames(foreground, output_path="./results/{}_fg.mp4".format(video_state["video_name"]), fps=fps, audio_path=audio_path) # import video_input to name the output video
349
  alpha_output = generate_video_from_frames(alpha, output_path="./results/{}_alpha.mp4".format(video_state["video_name"]), fps=fps, gray2rgb=True, audio_path=audio_path) # import video_input to name the output video
350
+
351
  return foreground_output, alpha_output
352
 
353
 
 
438
  # initialize sams
439
  model = MaskGenerator(sam_checkpoint, args)
440
 
441
+ # initialize matanyone - lazy loading
442
+ # Model display names to file names mapping
443
+ model_display_to_file = {
444
+ "MatAnyone": "matanyone.pth",
445
+ "MatAnyone 2": "matanyone2.pth"
446
+ }
447
+
448
+ # Model URLs
449
+ model_urls = {
450
+ "matanyone.pth": "https://github.com/pq-yang/MatAnyone/releases/download/v1.0.0/matanyone.pth",
451
+ "matanyone2.pth": "https://github.com/pq-yang/MatAnyone2/releases/download/v1.0.0/matanyone2.pth"
452
+ }
453
+
454
+ # Model paths - download models using load_file_from_url
455
+ model_paths = {
456
+ "matanyone.pth": load_file_from_url(model_urls["matanyone.pth"], checkpoint_folder),
457
+ "matanyone2.pth": load_file_from_url(model_urls["matanyone2.pth"], checkpoint_folder)
458
+ }
459
 
460
+ # Cache for loaded models (lazy loading)
461
+ loaded_models = {}
462
+
463
+ def load_model(display_name):
464
+ """Load a model if not already loaded"""
465
+ # Convert display name to file name
466
+ if display_name in model_display_to_file:
467
+ model_file = model_display_to_file[display_name]
468
+ elif display_name in model_paths:
469
+ # Also support direct file name for backward compatibility
470
+ model_file = display_name
471
+ else:
472
+ raise ValueError(f"Unknown model: {display_name}")
473
+
474
+ if model_file in loaded_models:
475
+ return loaded_models[model_file]
476
+
477
+ if model_file not in model_paths:
478
+ raise ValueError(f"Unknown model file: {model_file}")
479
+
480
+ ckpt_path = model_paths[model_file]
481
+ if not os.path.exists(ckpt_path):
482
+ raise FileNotFoundError(f"Model file not found: {ckpt_path}")
483
+
484
+ # Clear Hydra instance if already initialized (to allow loading different models)
485
+ try:
486
+ GlobalHydra.instance().clear()
487
+ except:
488
+ pass # If Hydra is not initialized, this is fine
489
+
490
+ print(f"Loading model: {display_name} ({model_file})...")
491
+ model = get_matanyone2_model(ckpt_path, args.device)
492
+ model = model.to(args.device).eval()
493
+ loaded_models[model_file] = model
494
+ print(f"Model {display_name} loaded successfully.")
495
+ return model
496
+
497
+ # Get available model choices for the UI (check if files exist)
498
+ # Order: MatAnyone 2 first, then MatAnyone
499
+ available_models = []
500
+ # Check MatAnyone 2 first
501
+ if "MatAnyone 2" in model_display_to_file:
502
+ file_name = model_display_to_file["MatAnyone 2"]
503
+ if file_name in model_paths and os.path.exists(model_paths[file_name]):
504
+ available_models.append("MatAnyone 2")
505
+ # Then check MatAnyone
506
+ if "MatAnyone" in model_display_to_file:
507
+ file_name = model_display_to_file["MatAnyone"]
508
+ if file_name in model_paths and os.path.exists(model_paths[file_name]):
509
+ available_models.append("MatAnyone")
510
+
511
+ if not available_models:
512
+ raise RuntimeError("No models are available! Please ensure at least one model file exists in ../pretrained_models/")
513
+ default_model = "MatAnyone 2" if "MatAnyone 2" in available_models else available_models[0]
514
 
515
  # download test samples
 
516
  test_sample_path = os.path.join('/home/user/app/hugging_face/', "test_sample/")
517
+ load_file_from_url('https://github.com/pq-yang/MatAnyone2/releases/download/media/test-sample-0-1080p.mp4', test_sample_path)
518
+ load_file_from_url('https://github.com/pq-yang/MatAnyone2/releases/download/media/test-sample-1-1080p.mp4', test_sample_path)
519
+ load_file_from_url('https://github.com/pq-yang/MatAnyone2/releases/download/media/test-sample-2-720p.mp4', test_sample_path)
520
+ load_file_from_url('https://github.com/pq-yang/MatAnyone2/releases/download/media/test-sample-3-720p.mp4', test_sample_path)
521
+ load_file_from_url('https://github.com/pq-yang/MatAnyone2/releases/download/media/test-sample-4-720p.mp4', test_sample_path)
522
+ load_file_from_url('https://github.com/pq-yang/MatAnyone2/releases/download/media/test-sample-5-720p.mp4', test_sample_path)
523
+ load_file_from_url('https://github.com/pq-yang/MatAnyone2/releases/download/media/test-sample-0.jpg', test_sample_path)
524
+ load_file_from_url('https://github.com/pq-yang/MatAnyone2/releases/download/media/test-sample-1.jpg', test_sample_path)
525
+ load_file_from_url('https://github.com/pq-yang/MatAnyone2/releases/download/media/test-sample-2.jpg', test_sample_path)
526
+ load_file_from_url('https://github.com/pq-yang/MatAnyone2/releases/download/media/test-sample-3.jpg', test_sample_path)
527
 
528
  # download assets
529
  assets_path = os.path.join('/home/user/app/hugging_face/', "assets/")
530
+ load_file_from_url('https://github.com/pq-yang/MatAnyone/releases/download/media/tutorial_single_target.mp4', assets_path)
531
+ load_file_from_url('https://github.com/pq-yang/MatAnyone/releases/download/media/tutorial_multi_targets.mp4', assets_path)
532
 
533
  # documents
534
+ title = r"""<div class="multi-layer" align="center"><span>MatAnyone Series</span></div>
535
  """
536
  description = r"""
537
+ <b>Official Gradio demo</b> for <a href='https://github.com/pq-yang/MatAnyone2' target='_blank'><b>MatAnyone 2</b></a> and <a href='https://github.com/pq-yang/MatAnyone' target='_blank'><b>MatAnyone</b></a>.<br>
538
+ πŸ”₯ The MatAnyone series provides a practical human video matting framework supporting target assignment.<br>
539
+ 🧐 <b>We use <u>MatAnyone 2</u> as the default model. You can also choose <u>MatAnyone</u> in "Model Selection".</b><br>
540
+ πŸŽͺ Try to drop your video/image, assign the target masks with a few clicks, and get the matting results!<br>
541
 
542
  *Note: Due to the online GPU memory constraints, any input with too big resolution will be resized to 1080p.<br>*
543
+ πŸš€ <b> If you encounter any issue (e.g., frozen video output) or wish to run on higher resolution inputs, please consider duplicating this space or
544
+ launching the demo locally following the <a href='https://github.com/pq-yang/MatAnyone2?tab=readme-ov-file#-interactive-demo' target='_blank'>GitHub instructions</a>.</b>
545
  """
546
  article = r"""<h3>
547
+ <b>If our projects are helpful, please help to 🌟 the Github Repo for <a href='https://github.com/pq-yang/MatAnyone2' target='_blank'>MatAnyone 2</a> and <a href='https://github.com/pq-yang/MatAnyone' target='_blank'>MatAnyone</a>. Thanks!</b></h3>
548
 
549
  ---
550
 
 
552
  <br>
553
  If our work is useful for your research, please consider citing:
554
  ```bibtex
555
+ @InProceedings{yang2026matanyone2,
556
+ title = {{MatAnyone 2}: Scaling Video Matting via a Learned Quality Evaluator},
557
+ author = {Yang, Peiqing and Zhou, Shangchen and Hao, Kai and Tao, Qingyi},
558
+ booktitle = {CVPR},
559
+ year = {2026}
560
+ }
561
+
562
  @InProceedings{yang2025matanyone,
563
  title = {{MatAnyone}: Stable Video Matting with Consistent Memory Propagation},
564
  author = {Yang, Peiqing and Zhou, Shangchen and Zhao, Jixin and Tao, Qingyi and Loy, Chen Change},
 
654
  <div class="title-container">
655
  <h1 class="title is-2 publication-title"
656
  style="font-size:50px; font-family: 'Sarpanch', serif;
657
+ background: linear-gradient(to right, #000000, #2dc464);
658
  display: inline-block; -webkit-background-clip: text;
659
  -webkit-text-fill-color: transparent;">
660
+ MatAnyone Series
661
  </h1>
662
  </div>
663
  ''')
 
710
 
711
  with gr.Group(elem_classes="gr-monochrome-group", visible=True):
712
  with gr.Row():
713
+ model_selection = gr.Radio(
714
+ choices=available_models,
715
+ value=default_model,
716
+ label="Model Selection",
717
+ info="Choose the model to use for matting",
718
+ interactive=True)
719
+ with gr.Row():
720
+ with gr.Accordion('Model Settings (click to expand)', open=False):
721
  with gr.Row():
722
  erode_kernel_size = gr.Slider(label='Erode Kernel Size',
723
  minimum=0,
 
825
  # video matting
826
  matting_button.click(
827
  fn=video_matting,
828
+ inputs=[video_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, model_selection],
829
  outputs=[foreground_video_output, alpha_video_output]
830
  )
831
 
 
878
  gr.Markdown("---")
879
  gr.Markdown("## Examples")
880
  gr.Examples(
881
+ examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["test-sample-0-1080p.mp4", "test-sample-1-1080p.mp4", "test-sample-2-720p.mp4", "test-sample-3-720p.mp4", "test-sample-4-720p.mp4", "test-sample-5-720p.mp4"]],
882
  inputs=[video_input],
883
  )
884
 
 
914
 
915
  with gr.Group(elem_classes="gr-monochrome-group", visible=True):
916
  with gr.Row():
917
+ model_selection = gr.Radio(
918
+ choices=available_models,
919
+ value=default_model,
920
+ label="Model Selection",
921
+ info="Choose the model to use for matting",
922
+ interactive=True)
923
+ with gr.Row():
924
+ with gr.Accordion('Model Settings (click to expand)', open=False):
925
  with gr.Row():
926
  erode_kernel_size = gr.Slider(label='Erode Kernel Size',
927
  minimum=0,
 
1028
  # image matting
1029
  matting_button.click(
1030
  fn=image_matting,
1031
+ inputs=[image_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, image_selection_slider, model_selection],
1032
  outputs=[foreground_image_output, alpha_image_output]
1033
  )
1034
 
 
1081
  gr.Markdown("---")
1082
  gr.Markdown("## Examples")
1083
  gr.Examples(
1084
+ examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["test-sample-0.jpg", "test-sample-1.jpg", "test-sample-2.jpg", "test-sample-3.jpg"]],
1085
  inputs=[image_input],
1086
  )
1087
 
hugging_face/{matanyone_wrapper.py β†’ matanyone2_wrapper.py} RENAMED
@@ -1,9 +1,13 @@
 
1
  import tqdm
2
  import torch
3
  from torchvision.transforms.functional import to_tensor
4
  import numpy as np
5
  import random
6
  import cv2
 
 
 
7
 
8
  def gen_dilate(alpha, min_kernel_size, max_kernel_size):
9
  kernel_size = random.randint(min_kernel_size, max_kernel_size)
@@ -20,8 +24,8 @@ def gen_erosion(alpha, min_kernel_size, max_kernel_size):
20
  return erode.astype(np.float32)
21
 
22
  @torch.inference_mode()
23
- @torch.cuda.amp.autocast()
24
- def matanyone(processor, frames_np, mask, r_erode=0, r_dilate=0, n_warmup=10):
25
  """
26
  Args:
27
  frames_np: [(H,W,C)]*n, uint8
@@ -41,14 +45,14 @@ def matanyone(processor, frames_np, mask, r_erode=0, r_dilate=0, n_warmup=10):
41
  if r_erode > 0:
42
  mask = gen_erosion(mask, r_erode, r_erode)
43
 
44
- mask = torch.from_numpy(mask).cuda()
45
 
46
  frames_np = [frames_np[0]]* n_warmup + frames_np
47
 
48
  frames = []
49
  phas = []
50
  for ti, frame_single in tqdm.tqdm(enumerate(frames_np)):
51
- image = to_tensor(frame_single).cuda().float()
52
 
53
  if ti == 0:
54
  output_prob = processor.step(image, mask, objects=objects) # encode given mask
@@ -62,7 +66,7 @@ def matanyone(processor, frames_np, mask, r_erode=0, r_dilate=0, n_warmup=10):
62
  # convert output probabilities to an object mask
63
  mask = processor.output_prob_to_mask(output_prob)
64
 
65
- pha = mask.unsqueeze(2).cpu().numpy()
66
  com_np = frame_single / 255. * pha + bgr * (1 - pha)
67
 
68
  # DONOT save the warmup frames
@@ -70,4 +74,4 @@ def matanyone(processor, frames_np, mask, r_erode=0, r_dilate=0, n_warmup=10):
70
  frames.append((com_np*255).astype(np.uint8))
71
  phas.append((pha*255).astype(np.uint8))
72
 
73
- return frames, phas
 
1
+
2
  import tqdm
3
  import torch
4
  from torchvision.transforms.functional import to_tensor
5
  import numpy as np
6
  import random
7
  import cv2
8
+ from matanyone2.utils.device import get_default_device, safe_autocast_decorator
9
+
10
+ device = get_default_device()
11
 
12
  def gen_dilate(alpha, min_kernel_size, max_kernel_size):
13
  kernel_size = random.randint(min_kernel_size, max_kernel_size)
 
24
  return erode.astype(np.float32)
25
 
26
  @torch.inference_mode()
27
+ @safe_autocast_decorator()
28
+ def matanyone2(processor, frames_np, mask, r_erode=0, r_dilate=0, n_warmup=10):
29
  """
30
  Args:
31
  frames_np: [(H,W,C)]*n, uint8
 
45
  if r_erode > 0:
46
  mask = gen_erosion(mask, r_erode, r_erode)
47
 
48
+ mask = torch.from_numpy(mask).to(device)
49
 
50
  frames_np = [frames_np[0]]* n_warmup + frames_np
51
 
52
  frames = []
53
  phas = []
54
  for ti, frame_single in tqdm.tqdm(enumerate(frames_np)):
55
+ image = to_tensor(frame_single).float().to(device)
56
 
57
  if ti == 0:
58
  output_prob = processor.step(image, mask, objects=objects) # encode given mask
 
66
  # convert output probabilities to an object mask
67
  mask = processor.output_prob_to_mask(output_prob)
68
 
69
+ pha = mask.unsqueeze(2).detach().to("cpu").numpy()
70
  com_np = frame_single / 255. * pha + bgr * (1 - pha)
71
 
72
  # DONOT save the warmup frames
 
74
  frames.append((com_np*255).astype(np.uint8))
75
  phas.append((pha*255).astype(np.uint8))
76
 
77
+ return frames, phas
matanyone/utils/__init__.py DELETED
File without changes
matanyone2/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from matanyone2.inference.inference_core import InferenceCore
2
+ from matanyone2.model.matanyone2 import MatAnyone2
{matanyone β†’ matanyone2/config}/__init__.py RENAMED
File without changes
{matanyone β†’ matanyone2}/config/eval_matanyone_config.yaml RENAMED
@@ -9,7 +9,7 @@ hydra:
9
  output_subdir: ${now:%Y-%m-%d_%H-%M-%S}-hydra
10
 
11
  amp: False
12
- weights: pretrained_models/matanyone.pth # default (can be modified from outside)
13
  output_dir: null # defaults to run_dir; specify this to override
14
  flip_aug: False
15
 
 
9
  output_subdir: ${now:%Y-%m-%d_%H-%M-%S}-hydra
10
 
11
  amp: False
12
+ weights: pretrained_models/matanyone2.pth # default (can be modified from outside)
13
  output_dir: null # defaults to run_dir; specify this to override
14
  flip_aug: False
15
 
{matanyone β†’ matanyone2}/config/hydra/job_logging/custom-no-rank.yaml RENAMED
File without changes
{matanyone β†’ matanyone2}/config/hydra/job_logging/custom.yaml RENAMED
File without changes
{matanyone β†’ matanyone2}/config/model/base.yaml RENAMED
File without changes
{matanyone/config β†’ matanyone2/inference}/__init__.py RENAMED
File without changes
{matanyone β†’ matanyone2}/inference/image_feature_store.py RENAMED
@@ -1,7 +1,7 @@
1
  import warnings
2
  from typing import Iterable
3
  import torch
4
- from matanyone.model.matanyone import MatAnyone
5
 
6
 
7
  class ImageFeatureStore:
@@ -13,7 +13,7 @@ class ImageFeatureStore:
13
 
14
  Feature of a frame should be associated with a unique index -- typically the frame id.
15
  """
16
- def __init__(self, network: MatAnyone, no_warning: bool = False):
17
  self.network = network
18
  self._store = {}
19
  self.no_warning = no_warning
 
1
  import warnings
2
  from typing import Iterable
3
  import torch
4
+ from matanyone2.model.matanyone2 import MatAnyone2
5
 
6
 
7
  class ImageFeatureStore:
 
13
 
14
  Feature of a frame should be associated with a unique index -- typically the frame id.
15
  """
16
+ def __init__(self, network: MatAnyone2, no_warning: bool = False):
17
  self.network = network
18
  self._store = {}
19
  self.no_warning = no_warning
{matanyone β†’ matanyone2}/inference/inference_core.py RENAMED
@@ -1,16 +1,25 @@
1
- from typing import List, Optional, Iterable
2
  import logging
3
  from omegaconf import DictConfig
 
4
 
5
- import numpy as np
 
6
  import torch
 
 
 
 
 
7
  import torch.nn.functional as F
8
 
9
- from matanyone.inference.memory_manager import MemoryManager
10
- from matanyone.inference.object_manager import ObjectManager
11
- from matanyone.inference.image_feature_store import ImageFeatureStore
12
- from matanyone.model.matanyone import MatAnyone
13
- from matanyone.utils.tensor_utils import pad_divide_by, unpad, aggregate
 
 
 
14
 
15
  log = logging.getLogger()
16
 
@@ -18,11 +27,21 @@ log = logging.getLogger()
18
  class InferenceCore:
19
 
20
  def __init__(self,
21
- network: MatAnyone,
22
- cfg: DictConfig,
23
  *,
24
- image_feature_store: ImageFeatureStore = None):
25
- self.network = network
 
 
 
 
 
 
 
 
 
 
26
  self.cfg = cfg
27
  self.mem_every = cfg.mem_every
28
  stagger_updates = cfg.stagger_updates
@@ -404,3 +423,128 @@ class InferenceCore:
404
  new_mask[mask == tmp_id] = obj.id
405
 
406
  return new_mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import logging
2
  from omegaconf import DictConfig
3
+ from typing import List, Optional, Iterable, Union,Tuple
4
 
5
+ import os
6
+ import cv2
7
  import torch
8
+ import imageio
9
+ import tempfile
10
+ import numpy as np
11
+ from tqdm import tqdm
12
+ from PIL import Image
13
  import torch.nn.functional as F
14
 
15
+ from matanyone2.inference.memory_manager import MemoryManager
16
+ from matanyone2.inference.object_manager import ObjectManager
17
+ from matanyone2.inference.image_feature_store import ImageFeatureStore
18
+ from matanyone2.model.matanyone2 import MatAnyone2
19
+ from matanyone2.utils.tensor_utils import pad_divide_by, unpad, aggregate
20
+ from matanyone2.utils.inference_utils import gen_dilate, gen_erosion, read_frame_from_videos
21
+ from matanyone2.utils.device import get_default_device, safe_autocast
22
+
23
 
24
  log = logging.getLogger()
25
 
 
27
  class InferenceCore:
28
 
29
  def __init__(self,
30
+ network: Union[MatAnyone2,str],
31
+ cfg: DictConfig = None,
32
  *,
33
+ image_feature_store: ImageFeatureStore = None,
34
+ device: Optional[Union[str, torch.device]] = None
35
+ ):
36
+ if device is None:
37
+ device = get_default_device()
38
+ self.device = device
39
+ if isinstance(network, str):
40
+ network = MatAnyone2.from_pretrained(network)
41
+ network.to(device)
42
+ network.eval()
43
+ self.network = network
44
+ cfg = cfg if cfg is not None else network.cfg
45
  self.cfg = cfg
46
  self.mem_every = cfg.mem_every
47
  stagger_updates = cfg.stagger_updates
 
423
  new_mask[mask == tmp_id] = obj.id
424
 
425
  return new_mask
426
+
427
@torch.inference_mode()
@safe_autocast()
def process_video(
    self,
    input_path: str,
    mask_path: str,
    output_path: Optional[str] = None,
    n_warmup: int = 10,
    r_erode: int = 10,
    r_dilate: int = 10,
    suffix: str = "",
    save_image: bool = False,
    max_size: int = -1,
) -> Tuple:
    """
    Process a video for object segmentation and matting.

    Each frame is composited over a fixed green background using the predicted
    alpha matte. Copies of the first frame are prepended as warmup so the
    memory can stabilize; the warmup predictions are discarded from the output.

    Args:
        input_path (str): Path to the input video file.
        mask_path (str): Path to the first-frame mask image (read as grayscale).
        output_path (str, optional): Directory where output files are saved.
            Defaults to a freshly created temporary directory.
        n_warmup (int, optional): Number of warmup frames. Defaults to 10.
        r_erode (int, optional): Erosion radius for mask processing. Defaults to 10.
        r_dilate (int, optional): Dilation radius for mask processing. Defaults to 10.
        suffix (str, optional): Suffix appended to the output video name. Defaults to "".
        save_image (bool, optional): Also save per-frame PNGs. Defaults to False.
        max_size (int, optional): Cap on the shorter frame side; -1 disables. Defaults to -1.

    Returns:
        Tuple[str, str]: Paths to the foreground (_fgr) and alpha-matte (_pha)
        output videos.

    Output:
        - Saves foreground (_fgr.mp4) and alpha matte (_pha.mp4) videos.
        - If save_image=True, saves individual frames in pha/ and fgr/ subdirectories.
    """
    # mkdtemp (not TemporaryDirectory().name): TemporaryDirectory deletes its
    # directory as soon as the object is garbage-collected, which here would be
    # immediately; mkdtemp creates a directory that persists.
    output_path = output_path if output_path is not None else tempfile.mkdtemp()
    r_erode = int(r_erode)
    r_dilate = int(r_dilate)
    n_warmup = int(n_warmup)
    max_size = int(max_size)

    vframes, fps, length, video_name = read_frame_from_videos(input_path)
    # Prepend n_warmup copies of the first frame; their outputs are dropped below.
    repeated_frames = vframes[0].unsqueeze(0).repeat(n_warmup, 1, 1, 1)
    vframes = torch.cat([repeated_frames, vframes], dim=0).float()
    length += n_warmup

    new_h, new_w = vframes.shape[-2:]
    if max_size > 0:
        h, w = new_h, new_w
        min_side = min(h, w)
        if min_side > max_size:
            new_h = int(h / min_side * max_size)
            new_w = int(w / min_side * max_size)
        vframes = F.interpolate(vframes, size=(new_h, new_w), mode="area")

    os.makedirs(output_path, exist_ok=True)
    if suffix:
        video_name = f"{video_name}_{suffix}"
    if save_image:
        os.makedirs(f"{output_path}/{video_name}", exist_ok=True)
        os.makedirs(f"{output_path}/{video_name}/pha", exist_ok=True)
        os.makedirs(f"{output_path}/{video_name}/fgr", exist_ok=True)

    mask = np.array(Image.open(mask_path).convert("L"))
    if r_dilate > 0:
        mask = gen_dilate(mask, r_dilate, r_dilate)
    if r_erode > 0:
        mask = gen_erosion(mask, r_erode, r_erode)

    mask = torch.from_numpy(mask).float().to(self.device)
    if max_size > 0:
        # Nearest-neighbor resize avoids introducing intermediate mask values.
        mask = F.interpolate(
            mask.unsqueeze(0).unsqueeze(0), size=(new_h, new_w), mode="nearest"
        )[0, 0]

    # Fixed green-screen background color (RGB, normalized to [0, 1]).
    bgr = (np.array([120, 255, 155], dtype=np.float32) / 255).reshape((1, 1, 3))
    objects = [1]

    phas = []
    fgrs = []
    for ti in tqdm(range(length)):
        image = vframes[ti]
        image_np = np.array(image.permute(1, 2, 0))
        image = (image / 255.0).float().to(self.device)

        if ti == 0:
            output_prob = self.step(image, mask, objects=objects)  # encode the given mask
            output_prob = self.step(image, first_frame_pred=True)  # predict on the first frame
        else:
            if ti <= n_warmup:
                # Still in warmup: keep treating this as a first-frame prediction.
                output_prob = self.step(image, first_frame_pred=True)
            else:
                output_prob = self.step(image)

        # Convert output probabilities to an alpha matte (H, W, 1).
        mask = self.output_prob_to_mask(output_prob)
        pha = mask.unsqueeze(2).cpu().numpy()
        com_np = image_np / 255.0 * pha + bgr * (1 - pha)

        # Do NOT save the warmup frames.
        if ti > (n_warmup - 1):
            com_np = (com_np * 255).astype(np.uint8)
            pha = (pha * 255).astype(np.uint8)
            fgrs.append(com_np)
            phas.append(pha)
            if save_image:
                cv2.imwrite(
                    f"{output_path}/{video_name}/pha/{str(ti - n_warmup).zfill(5)}.png",
                    pha,
                )
                cv2.imwrite(
                    f"{output_path}/{video_name}/fgr/{str(ti - n_warmup).zfill(5)}.png",
                    com_np[..., [2, 1, 0]],  # RGB -> BGR for OpenCV
                )

    fgrs = np.array(fgrs)
    phas = np.array(phas)

    fgr_filename = f"{output_path}/{video_name}_fgr.mp4"
    alpha_filename = f"{output_path}/{video_name}_pha.mp4"

    imageio.mimwrite(fgr_filename, fgrs, fps=fps, quality=7)
    imageio.mimwrite(alpha_filename, phas, fps=fps, quality=7)

    return (fgr_filename, alpha_filename)
{matanyone β†’ matanyone2}/inference/kv_memory_store.py RENAMED
File without changes
{matanyone β†’ matanyone2}/inference/memory_manager.py RENAMED
@@ -3,10 +3,10 @@ from omegaconf import DictConfig
3
  from typing import List, Dict
4
  import torch
5
 
6
- from matanyone.inference.object_manager import ObjectManager
7
- from matanyone.inference.kv_memory_store import KeyValueMemoryStore
8
- from matanyone.model.matanyone import MatAnyone
9
- from matanyone.model.utils.memory_utils import get_similarity, do_softmax
10
 
11
  log = logging.getLogger()
12
 
@@ -113,7 +113,7 @@ class MemoryManager:
113
  return value
114
 
115
  def read_first_frame(self, last_msk_value, pix_feat: torch.Tensor,
116
- last_mask: torch.Tensor, network: MatAnyone, uncert_output=None) -> Dict[int, torch.Tensor]:
117
  """
118
  Read from all memory stores and returns a single memory readout tensor for each object
119
 
@@ -166,7 +166,7 @@ class MemoryManager:
166
  return all_readout_mem
167
 
168
  def read(self, pix_feat: torch.Tensor, query_key: torch.Tensor, selection: torch.Tensor,
169
- last_mask: torch.Tensor, network: MatAnyone, uncert_output=None, last_msk_value=None, ti=None,
170
  last_pix_feat=None, last_pred_mask=None) -> Dict[int, torch.Tensor]:
171
  """
172
  Read from all memory stores and returns a single memory readout tensor for each object
 
3
  from typing import List, Dict
4
  import torch
5
 
6
+ from matanyone2.inference.object_manager import ObjectManager
7
+ from matanyone2.inference.kv_memory_store import KeyValueMemoryStore
8
+ from matanyone2.model.matanyone2 import MatAnyone2
9
+ from matanyone2.model.utils.memory_utils import get_similarity, do_softmax
10
 
11
  log = logging.getLogger()
12
 
 
113
  return value
114
 
115
  def read_first_frame(self, last_msk_value, pix_feat: torch.Tensor,
116
+ last_mask: torch.Tensor, network: MatAnyone2, uncert_output=None) -> Dict[int, torch.Tensor]:
117
  """
118
  Read from all memory stores and returns a single memory readout tensor for each object
119
 
 
166
  return all_readout_mem
167
 
168
  def read(self, pix_feat: torch.Tensor, query_key: torch.Tensor, selection: torch.Tensor,
169
+ last_mask: torch.Tensor, network: MatAnyone2, uncert_output=None, last_msk_value=None, ti=None,
170
  last_pix_feat=None, last_pred_mask=None) -> Dict[int, torch.Tensor]:
171
  """
172
  Read from all memory stores and returns a single memory readout tensor for each object
{matanyone β†’ matanyone2}/inference/object_info.py RENAMED
File without changes
{matanyone β†’ matanyone2}/inference/object_manager.py RENAMED
@@ -1,7 +1,7 @@
1
  from typing import Union, List, Dict
2
 
3
  import torch
4
- from matanyone.inference.object_info import ObjectInfo
5
 
6
 
7
  class ObjectManager:
 
1
  from typing import Union, List, Dict
2
 
3
  import torch
4
+ from matanyone2.inference.object_info import ObjectInfo
5
 
6
 
7
  class ObjectManager:
{matanyone/inference β†’ matanyone2/inference/utils}/__init__.py RENAMED
File without changes
{matanyone β†’ matanyone2}/inference/utils/args_utils.py RENAMED
File without changes
{matanyone/inference/utils β†’ matanyone2/model}/__init__.py RENAMED
File without changes
{matanyone β†’ matanyone2}/model/aux_modules.py RENAMED
@@ -6,8 +6,8 @@ from omegaconf import DictConfig
6
  import torch
7
  import torch.nn as nn
8
 
9
- from matanyone.model.group_modules import GConv2d
10
- from matanyone.utils.tensor_utils import aggregate
11
 
12
 
13
  class LinearPredictor(nn.Module):
 
6
  import torch
7
  import torch.nn as nn
8
 
9
+ from matanyone2.model.group_modules import GConv2d
10
+ from matanyone2.utils.tensor_utils import aggregate
11
 
12
 
13
  class LinearPredictor(nn.Module):
{matanyone β†’ matanyone2}/model/big_modules.py RENAMED
@@ -14,9 +14,10 @@ import torch
14
  import torch.nn as nn
15
  import torch.nn.functional as F
16
 
17
- from matanyone.model.group_modules import MainToGroupDistributor, GroupFeatureFusionBlock, GConv2d
18
- from matanyone.model.utils import resnet
19
- from matanyone.model.modules import SensoryDeepUpdater, SensoryUpdater_fullscale, DecoderFeatureProcessor, MaskUpsampleBlock
 
20
 
21
  class UncertPred(nn.Module):
22
  def __init__(self, model_cfg: DictConfig):
@@ -330,7 +331,7 @@ class MaskDecoder(nn.Module):
330
  p4 = self.up_8_4(p8, f4)
331
  p2 = self.up_4_2(p4, f2)
332
  p1 = self.up_2_1(p2, f1)
333
- with torch.cuda.amp.autocast(enabled=False):
334
  if seg_pass:
335
  if last_mask is not None:
336
  res = self.pred_seg(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
 
14
  import torch.nn as nn
15
  import torch.nn.functional as F
16
 
17
+ from matanyone2.model.group_modules import MainToGroupDistributor, GroupFeatureFusionBlock, GConv2d
18
+ from matanyone2.model.utils import resnet
19
+ from matanyone2.model.modules import SensoryDeepUpdater, SensoryUpdater_fullscale, DecoderFeatureProcessor, MaskUpsampleBlock
20
+ from matanyone2.utils.device import safe_autocast
21
 
22
  class UncertPred(nn.Module):
23
  def __init__(self, model_cfg: DictConfig):
 
331
  p4 = self.up_8_4(p8, f4)
332
  p2 = self.up_4_2(p4, f2)
333
  p1 = self.up_2_1(p2, f1)
334
+ with safe_autocast(enabled=False):
335
  if seg_pass:
336
  if last_mask is not None:
337
  res = self.pred_seg(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
{matanyone β†’ matanyone2}/model/channel_attn.py RENAMED
File without changes
{matanyone β†’ matanyone2}/model/group_modules.py RENAMED
@@ -2,7 +2,7 @@ from typing import Optional
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
- from matanyone.model.channel_attn import CAResBlock
6
 
7
  def interpolate_groups(g: torch.Tensor, ratio: float, mode: str,
8
  align_corners: bool) -> torch.Tensor:
 
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
+ from matanyone2.model.channel_attn import CAResBlock
6
 
7
  def interpolate_groups(g: torch.Tensor, ratio: float, mode: str,
8
  align_corners: bool) -> torch.Tensor:
matanyone/model/matanyone.py β†’ matanyone2/model/matanyone2.py RENAMED
@@ -1,4 +1,4 @@
1
- from typing import List, Dict, Iterable
2
  import logging
3
  from omegaconf import DictConfig
4
  import torch
@@ -7,18 +7,21 @@ import torch.nn.functional as F
7
  from omegaconf import OmegaConf
8
  from huggingface_hub import PyTorchModelHubMixin
9
 
10
- from matanyone.model.big_modules import PixelEncoder, UncertPred, KeyProjection, MaskEncoder, PixelFeatureFuser, MaskDecoder
11
- from matanyone.model.aux_modules import AuxComputer
12
- from matanyone.model.utils.memory_utils import get_affinity, readout
13
- from matanyone.model.transformer.object_transformer import QueryTransformer
14
- from matanyone.model.transformer.object_summarizer import ObjectSummarizer
15
- from matanyone.utils.tensor_utils import aggregate
 
 
 
16
 
17
  log = logging.getLogger()
18
- class MatAnyone(nn.Module,
19
  PyTorchModelHubMixin,
20
- library_name="matanyone",
21
- repo_url="https://github.com/pq-yang/MatAnyone",
22
  coders={
23
  DictConfig: (
24
  lambda x: OmegaConf.to_container(x),
@@ -83,6 +86,8 @@ class MatAnyone(nn.Module,
83
  return uncert_output
84
 
85
  def encode_image(self, image: torch.Tensor, seq_length=None, last_feats=None) -> (Iterable[torch.Tensor], torch.Tensor): # type: ignore
 
 
86
  image = (image - self.pixel_mean) / self.pixel_std
87
  ms_image_feat = self.pixel_encoder(image, seq_length) # f16, f8, f4, f2, f1
88
  return ms_image_feat, self.pix_feat_proj(ms_image_feat[0])
@@ -96,7 +101,7 @@ class MatAnyone(nn.Module,
96
  *,
97
  deep_update: bool = True,
98
  chunk_size: int = -1,
99
- need_weights: bool = False) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
100
  image = (image - self.pixel_mean) / self.pixel_std
101
  others = self._get_others(masks)
102
  mask_value, new_sensory = self.mask_encoder(image,
@@ -113,7 +118,7 @@ class MatAnyone(nn.Module,
113
  final_pix_feat: torch.Tensor,
114
  *,
115
  need_sk: bool = True,
116
- need_ek: bool = True) -> (torch.Tensor, torch.Tensor, torch.Tensor):
117
  key, shrinkage, selection = self.key_proj(final_pix_feat, need_s=need_sk, need_e=need_ek)
118
  return key, shrinkage, selection
119
 
@@ -124,7 +129,7 @@ class MatAnyone(nn.Module,
124
  msk_value: torch.Tensor, obj_memory: torch.Tensor, pix_feat: torch.Tensor,
125
  sensory: torch.Tensor, last_mask: torch.Tensor,
126
  selector: torch.Tensor, uncert_output=None, seg_pass=False,
127
- last_pix_feat=None, last_pred_mask=None) -> (torch.Tensor, Dict[str, torch.Tensor]):
128
  """
129
  query_key : B * CK * H * W
130
  query_selection : B * CK * H * W
@@ -139,7 +144,7 @@ class MatAnyone(nn.Module,
139
  uncert_mask = uncert_output["mask"] if uncert_output is not None else None
140
 
141
  # read using visual attention
142
- with torch.cuda.amp.autocast(enabled=False):
143
  affinity = get_affinity(memory_key.float(), memory_shrinkage.float(), query_key.float(),
144
  query_selection.float(), uncert_mask=uncert_mask)
145
 
@@ -171,7 +176,7 @@ class MatAnyone(nn.Module,
171
  def read_first_frame_memory(self, pixel_readout,
172
  obj_memory: torch.Tensor, pix_feat: torch.Tensor,
173
  sensory: torch.Tensor, last_mask: torch.Tensor,
174
- selector: torch.Tensor, seg_pass=False) -> (torch.Tensor, Dict[str, torch.Tensor]):
175
  """
176
  query_key : B * CK * H * W
177
  query_selection : B * CK * H * W
@@ -218,7 +223,7 @@ class MatAnyone(nn.Module,
218
  *,
219
  selector=None,
220
  need_weights=False,
221
- seg_pass=False) -> (torch.Tensor, Dict[str, torch.Tensor]):
222
  return self.object_transformer(pixel_readout,
223
  obj_memory,
224
  selector=selector,
@@ -237,7 +242,7 @@ class MatAnyone(nn.Module,
237
  clamp_mat: bool = True,
238
  last_mask=None,
239
  sigmoid_residual=False,
240
- seg_mat=False) -> (torch.Tensor, torch.Tensor, torch.Tensor):
241
  """
242
  multi_scale_features is from the key encoder for skip-connection
243
  memory_readout is from working/long-term memory
 
1
+ from typing import List, Dict, Iterable, Tuple
2
  import logging
3
  from omegaconf import DictConfig
4
  import torch
 
7
  from omegaconf import OmegaConf
8
  from huggingface_hub import PyTorchModelHubMixin
9
 
10
+ from matanyone2.model.big_modules import PixelEncoder, UncertPred, KeyProjection, MaskEncoder, PixelFeatureFuser, MaskDecoder
11
+ from matanyone2.model.aux_modules import AuxComputer
12
+ from matanyone2.model.utils.memory_utils import get_affinity, readout
13
+ from matanyone2.model.transformer.object_transformer import QueryTransformer
14
+ from matanyone2.model.transformer.object_summarizer import ObjectSummarizer
15
+ from matanyone2.utils.tensor_utils import aggregate
16
+ from matanyone2.utils.device import get_default_device, safe_autocast
17
+
18
+ device = get_default_device()
19
 
20
  log = logging.getLogger()
21
+ class MatAnyone2(nn.Module,
22
  PyTorchModelHubMixin,
23
+ library_name="matanyone2",
24
+ repo_url="https://github.com/pq-yang/MatAnyone2",
25
  coders={
26
  DictConfig: (
27
  lambda x: OmegaConf.to_container(x),
 
86
  return uncert_output
87
 
88
  def encode_image(self, image: torch.Tensor, seq_length=None, last_feats=None) -> (Iterable[torch.Tensor], torch.Tensor): # type: ignore
89
+ self.pixel_mean = self.pixel_mean.to(device)
90
+ self.pixel_std = self.pixel_std.to(device)
91
  image = (image - self.pixel_mean) / self.pixel_std
92
  ms_image_feat = self.pixel_encoder(image, seq_length) # f16, f8, f4, f2, f1
93
  return ms_image_feat, self.pix_feat_proj(ms_image_feat[0])
 
101
  *,
102
  deep_update: bool = True,
103
  chunk_size: int = -1,
104
+ need_weights: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
105
  image = (image - self.pixel_mean) / self.pixel_std
106
  others = self._get_others(masks)
107
  mask_value, new_sensory = self.mask_encoder(image,
 
118
  final_pix_feat: torch.Tensor,
119
  *,
120
  need_sk: bool = True,
121
+ need_ek: bool = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
122
  key, shrinkage, selection = self.key_proj(final_pix_feat, need_s=need_sk, need_e=need_ek)
123
  return key, shrinkage, selection
124
 
 
129
  msk_value: torch.Tensor, obj_memory: torch.Tensor, pix_feat: torch.Tensor,
130
  sensory: torch.Tensor, last_mask: torch.Tensor,
131
  selector: torch.Tensor, uncert_output=None, seg_pass=False,
132
+ last_pix_feat=None, last_pred_mask=None) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
133
  """
134
  query_key : B * CK * H * W
135
  query_selection : B * CK * H * W
 
144
  uncert_mask = uncert_output["mask"] if uncert_output is not None else None
145
 
146
  # read using visual attention
147
+ with safe_autocast(enabled=False):
148
  affinity = get_affinity(memory_key.float(), memory_shrinkage.float(), query_key.float(),
149
  query_selection.float(), uncert_mask=uncert_mask)
150
 
 
176
  def read_first_frame_memory(self, pixel_readout,
177
  obj_memory: torch.Tensor, pix_feat: torch.Tensor,
178
  sensory: torch.Tensor, last_mask: torch.Tensor,
179
+ selector: torch.Tensor, seg_pass=False) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
180
  """
181
  query_key : B * CK * H * W
182
  query_selection : B * CK * H * W
 
223
  *,
224
  selector=None,
225
  need_weights=False,
226
+ seg_pass=False) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
227
  return self.object_transformer(pixel_readout,
228
  obj_memory,
229
  selector=selector,
 
242
  clamp_mat: bool = True,
243
  last_mask=None,
244
  sigmoid_residual=False,
245
+ seg_mat=False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
246
  """
247
  multi_scale_features is from the key encoder for skip-connection
248
  memory_readout is from working/long-term memory
{matanyone β†’ matanyone2}/model/modules.py RENAMED
@@ -3,7 +3,8 @@ import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
 
6
- from matanyone.model.group_modules import MainToGroupDistributor, GroupResBlock, upsample_groups, GConv2d, downsample_groups
 
7
 
8
 
9
  class UpsampleBlock(nn.Module):
@@ -78,7 +79,7 @@ class SensoryUpdater_fullscale(nn.Module):
78
  self.g2_conv(downsample_groups(g[3], ratio=1/8)) + \
79
  self.g1_conv(downsample_groups(g[4], ratio=1/16))
80
 
81
- with torch.cuda.amp.autocast(enabled=False):
82
  g = g.float()
83
  h = h.float()
84
  values = self.transform(torch.cat([g, h], dim=2))
@@ -102,7 +103,7 @@ class SensoryUpdater(nn.Module):
102
  g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \
103
  self.g4_conv(downsample_groups(g[2], ratio=1/4))
104
 
105
- with torch.cuda.amp.autocast(enabled=False):
106
  g = g.float()
107
  h = h.float()
108
  values = self.transform(torch.cat([g, h], dim=2))
@@ -119,7 +120,7 @@ class SensoryDeepUpdater(nn.Module):
119
  nn.init.xavier_normal_(self.transform.weight)
120
 
121
  def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
122
- with torch.cuda.amp.autocast(enabled=False):
123
  g = g.float()
124
  h = h.float()
125
  values = self.transform(torch.cat([g, h], dim=2))
@@ -146,4 +147,4 @@ class ResBlock(nn.Module):
146
 
147
  g = self.downsample(g)
148
 
149
- return out_g + g
 
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
 
6
+ from matanyone2.model.group_modules import MainToGroupDistributor, GroupResBlock, upsample_groups, GConv2d, downsample_groups
7
+ from matanyone2.utils.device import safe_autocast
8
 
9
 
10
  class UpsampleBlock(nn.Module):
 
79
  self.g2_conv(downsample_groups(g[3], ratio=1/8)) + \
80
  self.g1_conv(downsample_groups(g[4], ratio=1/16))
81
 
82
+ with safe_autocast(enabled=False):
83
  g = g.float()
84
  h = h.float()
85
  values = self.transform(torch.cat([g, h], dim=2))
 
103
  g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \
104
  self.g4_conv(downsample_groups(g[2], ratio=1/4))
105
 
106
+ with safe_autocast(enabled=False):
107
  g = g.float()
108
  h = h.float()
109
  values = self.transform(torch.cat([g, h], dim=2))
 
120
  nn.init.xavier_normal_(self.transform.weight)
121
 
122
  def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
123
+ with safe_autocast(enabled=False):
124
  g = g.float()
125
  h = h.float()
126
  values = self.transform(torch.cat([g, h], dim=2))
 
147
 
148
  g = self.downsample(g)
149
 
150
+ return out_g + g
{matanyone/model β†’ matanyone2/model/transformer}/__init__.py RENAMED
File without changes
{matanyone β†’ matanyone2}/model/transformer/object_summarizer.py RENAMED
@@ -4,7 +4,8 @@ from omegaconf import DictConfig
4
  import torch
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
- from matanyone.model.transformer.positional_encoding import PositionalEncoding
 
8
 
9
 
10
  # @torch.jit.script
@@ -75,7 +76,7 @@ class ObjectSummarizer(nn.Module):
75
  pe = self.pos_enc(value)
76
  value = value + pe
77
 
78
- with torch.cuda.amp.autocast(enabled=False):
79
  value = value.float()
80
  feature = self.feature_pred(value)
81
  logits = self.weights_pred(value)
@@ -86,4 +87,4 @@ class ObjectSummarizer(nn.Module):
86
  if need_weights:
87
  return summaries, logits
88
  else:
89
- return summaries, None
 
4
  import torch
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
+ from matanyone2.model.transformer.positional_encoding import PositionalEncoding
8
+ from matanyone2.utils.device import safe_autocast
9
 
10
 
11
  # @torch.jit.script
 
76
  pe = self.pos_enc(value)
77
  value = value + pe
78
 
79
+ with safe_autocast(enabled=False): # autocast disabled intentionally
80
  value = value.float()
81
  feature = self.feature_pred(value)
82
  logits = self.weights_pred(value)
 
87
  if need_weights:
88
  return summaries, logits
89
  else:
90
+ return summaries, None
{matanyone β†’ matanyone2}/model/transformer/object_transformer.py RENAMED
@@ -3,10 +3,10 @@ from omegaconf import DictConfig
3
 
4
  import torch
5
  import torch.nn as nn
6
- from matanyone.model.group_modules import GConv2d
7
- from matanyone.utils.tensor_utils import aggregate
8
- from matanyone.model.transformer.positional_encoding import PositionalEncoding
9
- from matanyone.model.transformer.transformer_layers import CrossAttention, SelfAttention, FFN, PixelFFN
10
 
11
 
12
  class QueryTransformerBlock(nn.Module):
 
3
 
4
  import torch
5
  import torch.nn as nn
6
+ from matanyone2.model.group_modules import GConv2d
7
+ from matanyone2.utils.tensor_utils import aggregate
8
+ from matanyone2.model.transformer.positional_encoding import PositionalEncoding
9
+ from matanyone2.model.transformer.transformer_layers import CrossAttention, SelfAttention, FFN, PixelFFN
10
 
11
 
12
  class QueryTransformerBlock(nn.Module):
{matanyone β†’ matanyone2}/model/transformer/positional_encoding.py RENAMED
@@ -7,6 +7,7 @@ import math
7
  import numpy as np
8
  import torch
9
  from torch import nn
 
10
 
11
 
12
  def get_emb(sin_inp: torch.Tensor) -> torch.Tensor:
@@ -98,8 +99,9 @@ class PositionalEncoding(nn.Module):
98
 
99
 
100
  if __name__ == '__main__':
101
- pe = PositionalEncoding(8).cuda()
102
- input = torch.ones((1, 8, 8, 8)).cuda()
 
103
  output = pe(input)
104
  # print(output)
105
  print(output[0, :, 0, 0])
 
7
  import numpy as np
8
  import torch
9
  from torch import nn
10
+ from matanyone2.utils.device import get_default_device
11
 
12
 
13
  def get_emb(sin_inp: torch.Tensor) -> torch.Tensor:
 
99
 
100
 
101
  if __name__ == '__main__':
102
+ device = get_default_device()
103
+ pe = PositionalEncoding(8).to(device)
104
+ input = torch.ones((1, 8, 8, 8), device=device)
105
  output = pe(input)
106
  # print(output)
107
  print(output[0, :, 0, 0])
{matanyone β†’ matanyone2}/model/transformer/transformer_layers.py RENAMED
@@ -6,7 +6,7 @@ import torch
6
  from torch import Tensor
7
  import torch.nn as nn
8
  import torch.nn.functional as F
9
- from matanyone.model.channel_attn import CAResBlock
10
 
11
 
12
  class SelfAttention(nn.Module):
 
6
  from torch import Tensor
7
  import torch.nn as nn
8
  import torch.nn.functional as F
9
+ from matanyone2.model.channel_attn import CAResBlock
10
 
11
 
12
  class SelfAttention(nn.Module):
{matanyone/model/transformer β†’ matanyone2/model/utils}/__init__.py RENAMED
File without changes
{matanyone β†’ matanyone2}/model/utils/memory_utils.py RENAMED
File without changes
{matanyone β†’ matanyone2}/model/utils/parameter_groups.py RENAMED
File without changes
{matanyone β†’ matanyone2}/model/utils/resnet.py RENAMED
File without changes
{matanyone/model β†’ matanyone2}/utils/__init__.py RENAMED
File without changes
matanyone2/utils/device.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import contextlib
import functools

import torch


def get_default_device():
    """Return the best available torch device: CUDA, then MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_built() and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def safe_autocast_decorator(enabled=True):
    """Decorator variant of :func:`safe_autocast`.

    Runs the decorated function under ``torch.amp.autocast`` on backends that
    support it (CUDA, CPU); on other backends (e.g. MPS) the function is
    called without autocast.

    Args:
        enabled (bool): Whether autocast is enabled inside the wrapper.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Resolved per call so device availability is checked at call time.
            device = get_default_device()
            if device.type in ("cuda", "cpu"):
                with torch.amp.autocast(device_type=device.type, enabled=enabled):
                    return func(*args, **kwargs)
            return func(*args, **kwargs)
        return wrapper
    return decorator


@contextlib.contextmanager
def safe_autocast(enabled=True):
    """Context manager: autocast on CUDA/CPU, a no-op on unsupported backends.

    MPS (and any other backend ``torch.amp.autocast`` does not accept) simply
    skips autocast instead of raising.

    Args:
        enabled (bool): Whether autocast is enabled inside the context.
    """
    device = get_default_device()
    if device.type in ("cuda", "cpu"):
        with torch.amp.autocast(device_type=device.type, enabled=enabled):
            yield
    else:
        yield  # MPS or other unsupported backends skip autocast
{matanyone β†’ matanyone2}/utils/get_default_model.py RENAMED
@@ -5,9 +5,9 @@ from omegaconf import open_dict
5
  from hydra import compose, initialize
6
 
7
  import torch
8
- from matanyone.model.matanyone import MatAnyone
9
 
10
- def get_matanyone_model(ckpt_path, device=None) -> MatAnyone:
11
  initialize(version_base='1.3.2', config_path="../config", job_name="eval_our_config")
12
  cfg = compose(config_name="eval_matanyone_config")
13
 
@@ -16,12 +16,12 @@ def get_matanyone_model(ckpt_path, device=None) -> MatAnyone:
16
 
17
  # Load the network weights
18
  if device is not None:
19
- matanyone = MatAnyone(cfg, single_object=True).to(device).eval()
20
  model_weights = torch.load(cfg.weights, map_location=device)
21
  else: # if device is not specified, `.cuda()` by default
22
- matanyone = MatAnyone(cfg, single_object=True).cuda().eval()
23
  model_weights = torch.load(cfg.weights)
24
 
25
- matanyone.load_weights(model_weights)
26
 
27
- return matanyone
 
5
  from hydra import compose, initialize
6
 
7
  import torch
8
+ from matanyone2.model.matanyone2 import MatAnyone2
9
 
10
+ def get_matanyone2_model(ckpt_path, device=None) -> MatAnyone2:
11
  initialize(version_base='1.3.2', config_path="../config", job_name="eval_our_config")
12
  cfg = compose(config_name="eval_matanyone_config")
13
 
 
16
 
17
  # Load the network weights
18
  if device is not None:
19
+ matanyone2 = MatAnyone2(cfg, single_object=True).to(device).eval()
20
  model_weights = torch.load(cfg.weights, map_location=device)
21
  else: # if device is not specified, `.cuda()` by default
22
+ matanyone2 = MatAnyone2(cfg, single_object=True).cuda().eval()
23
  model_weights = torch.load(cfg.weights)
24
 
25
+ matanyone2.load_weights(model_weights)
26
 
27
+ return matanyone2
matanyone2/utils/inference_utils.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import random
4
+ import numpy as np
5
+
6
+ import torch
7
+ import torchvision
8
+
9
# File-extension whitelists. Both lower- and upper-case variants are listed
# explicitly because some of the checks below use case-sensitive `str.endswith`.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG')
VIDEO_EXTENSIONS = ('.mp4', '.mov', '.avi', '.MP4', '.MOV', '.AVI')
11
+
12
def read_frame_from_videos(frame_root):
    """Load a clip either from a video file or from a directory of frames.

    Args:
        frame_root: path to a video file (one of VIDEO_EXTENSIONS) or to a
            directory containing per-frame image files.

    Returns:
        Tuple ``(frames, fps, length, video_name)`` where ``frames`` is a
        float TCHW RGB tensor, ``fps`` is the source frame rate (24 is
        assumed for frame directories, which carry no timing information),
        ``length`` is the number of frames, and ``video_name`` is the file
        basename without extension for videos / the directory name otherwise.
    """
    if frame_root.endswith(VIDEO_EXTENSIONS):  # single video file
        video_name = os.path.basename(frame_root)[:-4]
        frames, _, info = torchvision.io.read_video(filename=frame_root, pts_unit='sec', output_format='TCHW')  # RGB
        fps = info['video_fps']
    else:  # directory of per-frame images
        video_name = os.path.basename(frame_root)
        frames = []
        # Only consider recognized image files: stray directory entries
        # (e.g. .DS_Store, subdirectories) would make cv2.imread return
        # None and crash the channel-swap below.
        fr_lst = sorted(fr for fr in os.listdir(frame_root)
                        if fr.endswith(IMAGE_EXTENSIONS))
        for fr in fr_lst:
            frame = cv2.imread(os.path.join(frame_root, fr))
            if frame is None:  # unreadable/corrupt image: skip instead of crashing
                continue
            frames.append(frame[..., [2, 1, 0]])  # BGR -> RGB, HWC
        fps = 24  # default
        frames = torch.Tensor(np.array(frames)).permute(0, 3, 1, 2).contiguous()  # TCHW

    length = frames.shape[0]

    return frames, fps, length, video_name
30
+
31
def get_video_paths(input_root):
    """Recursively collect all video file paths under `input_root`, sorted."""
    found = [
        os.path.join(dirpath, name)
        for dirpath, _, names in os.walk(input_root)
        for name in names
        if name.lower().endswith(VIDEO_EXTENSIONS)
    ]
    return sorted(found)
38
+
39
def str_to_list(value):
    """Parse a comma-separated string of integers, e.g. "1,2,3" -> [1, 2, 3]."""
    return [int(token) for token in value.split(',')]
41
+
42
def gen_dilate(alpha, min_kernel_size, max_kernel_size):
    """Dilate the non-background region of an alpha map.

    A kernel size is drawn uniformly from [min_kernel_size, max_kernel_size],
    so repeated calls can produce different dilation amounts.

    Returns a float32 array with values in {0, 255}: 255 over the dilated
    foreground/unknown region, 0 elsewhere.
    """
    ksize = random.randint(min_kernel_size, max_kernel_size)
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize))
    # Everything that is not exact background (alpha != 0) gets dilated.
    mask = np.not_equal(alpha, 0).astype(np.float32)
    dilated = cv2.dilate(mask, ellipse, iterations=1) * 255
    return dilated.astype(np.float32)
48
+
49
def gen_erosion(alpha, min_kernel_size, max_kernel_size):
    """Erode the definite-foreground region of an alpha map.

    A kernel size is drawn uniformly from [min_kernel_size, max_kernel_size],
    so repeated calls can produce different erosion amounts.

    Returns a float32 array with values in {0, 255}: 255 over the eroded
    definite-foreground (alpha == 255) region, 0 elsewhere.
    """
    ksize = random.randint(min_kernel_size, max_kernel_size)
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize))
    # Only pixels that are fully foreground (alpha == 255) are kept and eroded.
    mask = np.equal(alpha, 255).astype(np.float32)
    eroded = cv2.erode(mask, ellipse, iterations=1) * 255
    return eroded.astype(np.float32)
{matanyone β†’ matanyone2}/utils/tensor_utils.py RENAMED
@@ -1,7 +1,7 @@
1
  from typing import List, Iterable
2
  import torch
3
  import torch.nn.functional as F
4
-
5
 
6
  # STM
7
  def pad_divide_by(in_img: torch.Tensor, d: int) -> (torch.Tensor, Iterable[int]):
@@ -45,7 +45,7 @@ def unpad(img: torch.Tensor, pad: Iterable[int]) -> torch.Tensor:
45
 
46
  # @torch.jit.script
47
  def aggregate(prob: torch.Tensor, dim: int) -> torch.Tensor:
48
- with torch.cuda.amp.autocast(enabled=False):
49
  prob = prob.float()
50
  new_prob = torch.cat([torch.prod(1 - prob, dim=dim, keepdim=True), prob],
51
  dim).clamp(1e-7, 1 - 1e-7)
@@ -59,4 +59,4 @@ def cls_to_one_hot(cls_gt: torch.Tensor, num_objects: int) -> torch.Tensor:
59
  # cls_gt: B*1*H*W
60
  B, _, H, W = cls_gt.shape
61
  one_hot = torch.zeros(B, num_objects + 1, H, W, device=cls_gt.device).scatter_(1, cls_gt, 1)
62
- return one_hot
 
1
  from typing import List, Iterable
2
  import torch
3
  import torch.nn.functional as F
4
+ from matanyone2.utils.device import safe_autocast
5
 
6
  # STM
7
  def pad_divide_by(in_img: torch.Tensor, d: int) -> (torch.Tensor, Iterable[int]):
 
45
 
46
  # @torch.jit.script
47
  def aggregate(prob: torch.Tensor, dim: int) -> torch.Tensor:
48
+ with safe_autocast(enabled=False):
49
  prob = prob.float()
50
  new_prob = torch.cat([torch.prod(1 - prob, dim=dim, keepdim=True), prob],
51
  dim).clamp(1e-7, 1 - 1e-7)
 
59
  # cls_gt: B*1*H*W
60
  B, _, H, W = cls_gt.shape
61
  one_hot = torch.zeros(B, num_objects + 1, H, W, device=cls_gt.device).scatter_(1, cls_gt, 1)
62
+ return one_hot