Phase 1: Add SAM2/MatAnyone optimization infrastructure
Browse files- Dockerfile +51 -5
- models/__init__.py +0 -0
- models/matanyone_loader.py +72 -0
- models/sam2_loader.py +75 -0
- requirements.txt +37 -2
- utils/__init__.py +0 -0
- utils/accelerator.py +34 -0
Dockerfile
CHANGED
|
@@ -1,11 +1,57 @@
|
|
| 1 |
-
#
|
| 2 |
-
FROM
|
| 3 |
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
RUN pip install --no-cache-dir -r requirements.txt
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
COPY . .
|
| 10 |
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Production Dockerfile for BackgroundFX Pro with SAM2 + MatAnyone
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04

# System dependencies (fixed: libglib2.0-0 was listed twice)
RUN apt-get update && apt-get install -y --no-install-recommends \
    git ffmpeg libgl1 libglib2.0-0 libsm6 libxrender1 libxext6 \
    python3.10 python3.10-venv python3-pip \
    && rm -rf /var/lib/apt/lists/*

# Upgrade pip
RUN python3 -m pip install --upgrade pip

# Environment variables for caching and performance
# NOTE: TRANSFORMERS_CACHE is deprecated in recent transformers releases in
# favor of HF_HOME, but it is harmless and kept for older versions.
ENV HF_HOME=/home/user/.cache/huggingface \
    TORCH_HOME=/home/user/.cache/torch \
    TRANSFORMERS_CACHE=/home/user/.cache/transformers \
    MPLCONFIGDIR=/home/user/.config/matplotlib

# CUDA and memory optimizations for T4
ENV PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128,expandable_segments:True \
    CUDA_LAUNCH_BLOCKING=0 \
    OMP_NUM_THREADS=2 \
    MKL_NUM_THREADS=2 \
    HF_HUB_ENABLE_HF_TRANSFER=1 \
    TOKENIZERS_PARALLELISM=false

# Create working directory
WORKDIR /home/user/app

# Copy and install Python dependencies first (better layer caching than COPY . .)
# Use `python3 -m pip` consistently — the base image has no guaranteed bare `pip`.
COPY requirements.txt ./requirements.txt
RUN python3 -m pip install --no-cache-dir -r requirements.txt

# Vendor SAM2 and MatAnyone at build time (more reliable than runtime git)
RUN git clone --depth=1 https://github.com/facebookresearch/segment-anything-2 /home/user/app/third_party/sam2
RUN git clone --depth=1 https://github.com/pq-yang/MatAnyone /home/user/app/third_party/matanyone
# Single PYTHONPATH layer instead of two stacked ENV lines.
# ${PYTHONPATH} may be unset at build time; a trailing ':' is harmless.
ENV PYTHONPATH=/home/user/app/third_party/sam2:/home/user/app/third_party/matanyone:${PYTHONPATH}

# Copy application code
COPY . .

# Create cache directories
RUN mkdir -p /home/user/.cache/huggingface /home/user/.cache/torch /home/user/.cache/transformers

# Expose Gradio port
EXPOSE 7860

# Environment for Gradio
ENV GRADIO_SERVER_NAME=0.0.0.0 \
    GRADIO_SERVER_PORT=7860

# Run the application
CMD ["python3", "app.py"]
|
models/__init__.py
ADDED
|
File without changes
|
models/matanyone_loader.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# models/matanyone_loader.py
|
| 2 |
+
import os, logging, torch, gc
|
| 3 |
+
import numpy as np
|
| 4 |
+
from typing import Optional, Tuple
|
| 5 |
+
|
| 6 |
+
log = logging.getLogger("matany_loader")
|
| 7 |
+
|
| 8 |
+
def _import_inference_core():
    """Resolve MatAnyone's InferenceCore class.

    Returns the class object, or None (with an error logged) when the
    vendored MatAnyone repo is not importable.
    """
    try:
        # Import path matches the vendored pq-yang/MatAnyone repo layout.
        from matanyone.inference_core import InferenceCore
    except Exception as exc:
        log.error("MatAnyone import failed (vendoring/repo path?): %s", exc)
        return None
    return InferenceCore
|
| 16 |
+
|
| 17 |
+
def _to_chw01(img):
|
| 18 |
+
# img: HWC uint8 or float01 -> CHW float01
|
| 19 |
+
if img.dtype != np.float32:
|
| 20 |
+
img = img.astype("float32")/255.0
|
| 21 |
+
return np.transpose(img, (2,0,1))
|
| 22 |
+
|
| 23 |
+
def _to_1hw01(mask):
|
| 24 |
+
# mask: HxW [0,1]
|
| 25 |
+
m = mask.astype("float32")
|
| 26 |
+
return m[None, ...]
|
| 27 |
+
|
| 28 |
+
class MatAnyoneSession:
    """Thin wrapper around MatAnyone's InferenceCore for frame-by-frame matting."""

    def __init__(self, device: torch.device, precision: str = "fp16"):
        # Target device and requested numeric precision; the core is created
        # lazily by load().
        self.device = device
        self.precision = precision
        self.core = None

    def load(self, ckpt_path: Optional[str] = None, repo_id: Optional[str] = None, filename: Optional[str] = None):
        """Instantiate InferenceCore, downloading the checkpoint from HF if needed.

        Raises RuntimeError when the MatAnyone package cannot be imported.
        Returns self for chaining.
        """
        InferenceCore = _import_inference_core()
        if InferenceCore is None:
            raise RuntimeError("MatAnyone not importable")

        # Resolve a checkpoint from the Hub only when no local path was given.
        if ckpt_path is None and repo_id and filename:
            from huggingface_hub import hf_hub_download
            ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename, local_dir=os.environ.get("HF_HOME"))

        # Build the inference core on the configured device.
        self.core = InferenceCore(ckpt_path, device=str(self.device))
        return self

    @torch.inference_mode()
    def step(self, image_rgb, seed_mask: Optional[np.ndarray] = None):
        """
        image_rgb: HxWx3 uint8/float01
        seed_mask: HxW float01 for first frame, else None
        returns alpha HxW float01
        """
        assert self.core is not None, "MatAnyone not loaded"
        frame = _to_chw01(image_rgb)  # CHW
        prompt = _to_1hw01(seed_mask) if seed_mask is not None else None  # 1HW or None
        alpha = self.core.step(frame, prompt)
        # Normalize the core's output (ndarray or tensor) to HxW float32 ndarray.
        if isinstance(alpha, np.ndarray):
            return alpha.astype("float32")
        if torch.is_tensor(alpha):
            return alpha.detach().float().cpu().numpy()
        raise RuntimeError("MatAnyone returned unknown alpha type")

    def reset(self):
        """Clear temporal state (if the core supports it) and free cached memory."""
        if self.core and hasattr(self.core, "reset"):
            self.core.reset()
        torch.cuda.empty_cache()
        gc.collect()
|
models/sam2_loader.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# models/sam2_loader.py
|
| 2 |
+
import os, logging, torch
|
| 3 |
+
from huggingface_hub import hf_hub_download
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
log = logging.getLogger("sam2_loader")
|
| 8 |
+
|
| 9 |
+
DEFAULT_MODEL_ID = os.environ.get("SAM2_MODEL_ID", "facebook/sam2")
|
| 10 |
+
DEFAULT_VARIANT = os.environ.get("SAM2_VARIANT", "sam2_hiera_large")
|
| 11 |
+
|
| 12 |
+
# Map variant -> filenames (SAM2 releases follow this pattern)
|
| 13 |
+
VARIANT_FILES = {
|
| 14 |
+
"sam2_hiera_small": ("sam2_hiera_small.pt", "configs/sam2/sam2_hiera_s.yaml"),
|
| 15 |
+
"sam2_hiera_base": ("sam2_hiera_base.pt", "configs/sam2/sam2_hiera_b.yaml"),
|
| 16 |
+
"sam2_hiera_large": ("sam2_hiera_large.pt", "configs/sam2/sam2_hiera_l.yaml"),
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
def _download_checkpoint(model_id: str, ckpt_name: str) -> str:
    """Fetch the named SAM2 checkpoint from the HF Hub; returns the local path."""
    cache_dir = os.environ.get("HF_HOME")
    return hf_hub_download(repo_id=model_id, filename=ckpt_name, local_dir=cache_dir)
|
| 21 |
+
|
| 22 |
+
def _find_sam2_build():
    """Resolve SAM2's build_sam2 factory, or None (with an error logged)."""
    try:
        from sam2.build_sam import build_sam2
    except Exception as exc:
        log.error("SAM2 not importable (check Dockerfile vendoring): %s", exc)
        return None
    return build_sam2
|
| 29 |
+
|
| 30 |
+
class SAM2Predictor:
    """Loads a SAM2 checkpoint and exposes a first-frame segmentation helper."""

    def __init__(self, device: torch.device):
        # Model and predictor are created lazily by load().
        self.device = device
        self.model = None
        self.predictor = None

    def load(self, variant: str = DEFAULT_VARIANT, model_id: str = DEFAULT_MODEL_ID):
        """Build the SAM2 model for `variant` and wrap it in a predictor.

        Raises RuntimeError when the SAM2 package cannot be imported.
        Returns self for chaining.
        """
        build_sam2 = _find_sam2_build()
        if build_sam2 is None:
            raise RuntimeError("SAM2 build function not available")

        # Unknown variants fall back to the large model.
        ckpt_name, cfg_path = VARIANT_FILES.get(variant, VARIANT_FILES["sam2_hiera_large"])
        ckpt = _download_checkpoint(model_id, ckpt_name)

        # Compose config via hydra-free path (using explicit path args)
        sam2_model = build_sam2(config_file=cfg_path, ckpt_path=ckpt, device=str(self.device))
        sam2_model.eval()
        self.model = sam2_model

        # Prefer the video predictor; fall back to the image predictor.
        try:
            from sam2.sam2_video_predictor import SAM2VideoPredictor
            self.predictor = SAM2VideoPredictor(self.model)
        except Exception:
            from sam2.sam2_image_predictor import SAM2ImagePredictor
            self.predictor = SAM2ImagePredictor(self.model)

        return self

    @torch.inference_mode()
    def first_frame_mask(self, image_rgb01):
        """
        Returns an initial binary-ish mask for the foreground subject from first frame.
        You can refine prompts here (points/boxes) if you add UI hooks later.
        """
        if hasattr(self.predictor, "set_image"):
            self.predictor.set_image((image_rgb01 * 255).astype("uint8"))
            # simple auto-box prompt (tight box)
            height, width = image_rgb01.shape[:2]
            prompt_box = np.array([1, 1, width - 2, height - 2])
            masks, _, _ = self.predictor.predict(box=prompt_box, multimask_output=False)
            result = masks[0]  # HxW bool/float
        else:
            # video predictor path: run_single_frame if available
            result = (image_rgb01[..., 0] > -1)  # dummy, should not happen
        return result.astype("float32")
|
requirements.txt
CHANGED
|
@@ -1,8 +1,43 @@
|
|
|
|
|
|
|
|
| 1 |
torch==2.2.2
|
| 2 |
torchvision==0.17.2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
opencv-python-headless==4.10.0.84
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
numpy==1.26.4
|
| 5 |
-
|
|
|
|
| 6 |
gradio==5.42.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
requests==2.31.0
|
| 8 |
-
|
|
|
|
| 1 |
+
# ===== Core runtime =====
# Option A: Keep your current Torch stack (safe for existing builds)
torch==2.2.2
torchvision==0.17.2
torchaudio==2.2.2

# Option B: Faster CUDA 12.1 wheels for T4 (uncomment to use instead)
# torch==2.3.1+cu121
# torchvision==0.18.1+cu121
# torchaudio==2.3.1+cu121
# --extra-index-url https://download.pytorch.org/whl/cu121

# ===== Video / image IO =====
opencv-python-headless==4.10.0.84
imageio==2.35.1
imageio-ffmpeg==0.5.1
moviepy==1.0.3
decord==0.6.0
Pillow==10.4.0
numpy==1.26.4

# ===== Gradio UI =====
gradio==5.42.0

# ===== SAM2 Dependencies =====
hydra-core==1.3.2
omegaconf==2.3.0
einops==0.8.0
timm==1.0.9
pyyaml==6.0.2
matplotlib==3.9.2

# ===== MatAnyone Dependencies =====
kornia==0.7.3
scikit-image==0.24.0
tqdm==4.66.5

# ===== Helpers / caching =====
huggingface_hub==0.24.6
# Required because the Dockerfile sets HF_HUB_ENABLE_HF_TRANSFER=1:
# huggingface_hub raises at download time when that env var is set but
# the hf_transfer package is missing.
hf_transfer==0.1.8
ffmpeg-python==0.2.0
psutil==6.0.0
requests==2.31.0
scikit-learn==1.5.1
|
utils/__init__.py
ADDED
|
File without changes
|
utils/accelerator.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# utils/accelerator.py
|
| 2 |
+
import os, torch, logging, psutil, gc
|
| 3 |
+
|
| 4 |
+
log = logging.getLogger("accelerator")
|
| 5 |
+
|
| 6 |
+
def pick_device():
    """Select CUDA when available, otherwise CPU, logging the choice."""
    if not torch.cuda.is_available():
        log.warning("CUDA not available; falling back to CPU.")
        return torch.device("cpu")
    # Lazy %-style args avoid formatting when the log level filters this out.
    log.info("Using GPU: %s", torch.cuda.get_device_name(0))
    return torch.device("cuda")
|
| 14 |
+
|
| 15 |
+
def torch_global_tuning():
    """Enable the faster ("high") float32 matmul mode; best-effort.

    Older torch builds lack this knob, so failures are silently ignored.
    """
    try:
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass
|
| 21 |
+
|
| 22 |
+
def memory_checkpoint(tag=""):
    """Log the currently allocated CUDA memory (MB) under `tag`.

    No-op on CPU-only hosts; any logging/CUDA error is swallowed so this is
    always safe to sprinkle through a pipeline.
    """
    try:
        if not torch.cuda.is_available():
            return
        mem = torch.cuda.memory_allocated() / (1024**2)
        log.info(f"[CUDA mem] {tag}: {mem:.1f} MB")
    except Exception:
        pass
|
| 29 |
+
|
| 30 |
+
def cleanup():
    """Flush pending CUDA work, release cached GPU memory, then force a GC pass."""
    if torch.cuda.is_available():
        # Synchronize first so empty_cache sees a quiescent allocator.
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
    gc.collect()
|