| """ |
| VideoMaMa Inference Wrapper |
| Handles video matting with mask conditioning |
| """ |
|
|
| import os |
| import torch |
|
|
| |
| os.environ['TORCH_HOME'] = '/tmp/torch_cache' |
| os.environ['HUB_DIR'] = '/tmp/torch_hub' |
| os.environ['TMPDIR'] = '/tmp' |
| torch.hub.set_dir('/tmp/torch_hub') |
|
|
| import os |
| import torch |
| import numpy as np |
| from PIL import Image |
| from pathlib import Path |
| from typing import List |
| import tqdm |
|
|
| from pipeline_svd_mask import VideoInferencePipeline |
|
|
|
|
def videomama(pipeline, frames_np, mask_frames_np, target_size=(1024, 576),
              seed=42, mask_cond_mode="vae"):
    """
    Run VideoMaMa inference on video frames with mask conditioning.

    Args:
        pipeline: VideoInferencePipeline instance (any object exposing a
            compatible ``run(cond_frames, mask_frames, seed, mask_cond_mode)``).
        frames_np: List of numpy arrays, [(H,W,3)]*n, uint8 RGB frames.
        mask_frames_np: List of numpy arrays, [(H,W)]*n, uint8 grayscale masks.
        target_size: (width, height) the frames are resized to before
            inference. Defaults to (1024, 576) — presumably the SVD-native
            resolution; confirm against the pipeline's training setup.
        seed: Random seed forwarded to the pipeline.
        mask_cond_mode: Mask conditioning mode forwarded to the pipeline.

    Returns:
        output_frames: List of numpy arrays, [(H,W,3)]*n, uint8 RGB outputs,
        resized back to the original input resolution.

    Raises:
        ValueError: If no frames are given or frame/mask counts differ.
    """
    if not frames_np:
        raise ValueError("frames_np must contain at least one frame")
    if len(frames_np) != len(mask_frames_np):
        raise ValueError(
            f"frame/mask count mismatch: {len(frames_np)} frames vs "
            f"{len(mask_frames_np)} masks"
        )

    frames_pil = [Image.fromarray(f) for f in frames_np]
    mask_frames_pil = [Image.fromarray(m, mode='L') for m in mask_frames_np]

    # Resize everything to the model's working resolution.
    target_width, target_height = target_size
    frames_resized = [f.resize((target_width, target_height), Image.Resampling.BILINEAR)
                      for f in frames_pil]
    masks_resized = [m.resize((target_width, target_height), Image.Resampling.BILINEAR)
                     for m in mask_frames_pil]

    print(f"Running VideoMaMa inference on {len(frames_resized)} frames...")
    output_frames_pil = pipeline.run(
        cond_frames=frames_resized,
        mask_frames=masks_resized,
        seed=seed,
        mask_cond_mode=mask_cond_mode
    )

    # Restore the caller's original resolution (PIL .size is (W, H)).
    original_size = frames_pil[0].size
    output_frames_resized = [f.resize(original_size, Image.Resampling.BILINEAR)
                             for f in output_frames_pil]

    output_frames_np = [np.array(f) for f in output_frames_resized]

    return output_frames_np
|
|
|
|
def load_videomama_pipeline(device="cuda"):
    """
    Load VideoMaMa pipeline with pretrained weights.

    Args:
        device: Device to run on (e.g. "cuda").

    Returns:
        VideoInferencePipeline instance

    Raises:
        FileNotFoundError: If either required checkpoint directory is missing.
    """
    base_model_path = os.path.join("checkpoints", "stable-video-diffusion-img2vid-xt")
    unet_checkpoint_path = os.path.join("checkpoints", "videomama")

    # Fail fast with a download hint if either checkpoint is absent on disk.
    required_checkpoints = (
        ("SVD base model", base_model_path),
        ("VideoMaMa checkpoint", unet_checkpoint_path),
    )
    for label, path in required_checkpoints:
        if not os.path.exists(path):
            raise FileNotFoundError(
                f"{label} not found at {path}. "
                "Please run download_checkpoints.sh first."
            )

    print(f"Loading VideoMaMa pipeline from {unet_checkpoint_path}...")

    pipeline = VideoInferencePipeline(
        base_model_path=base_model_path,
        unet_checkpoint_path=unet_checkpoint_path,
        weight_dtype=torch.float16,
        device=device
    )

    print("VideoMaMa pipeline loaded successfully!")

    return pipeline
|
|