Update utils/refinement.py
Browse files- utils/refinement.py +33 -5
utils/refinement.py
CHANGED
|
@@ -141,16 +141,29 @@ def _refine_with_matanyone(
|
|
| 141 |
image_tensor = torch.from_numpy(image_rgb).permute(2, 0, 1).float() / 255.0
|
| 142 |
image_tensor = image_tensor.unsqueeze(0).to(device) # Add batch dimension and move to GPU
|
| 143 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
# Ensure mask is binary uint8
|
| 145 |
if mask.dtype != np.uint8:
|
| 146 |
mask = (mask * 255).astype(np.uint8) if mask.max() <= 1 else mask.astype(np.uint8)
|
| 147 |
-
|
| 148 |
-
|
|
|
|
|
|
|
| 149 |
|
| 150 |
# Convert mask to tensor and move to GPU
|
| 151 |
mask_tensor = torch.from_numpy(mask).float() / 255.0
|
| 152 |
mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).to(device) # (1, 1, H, W) on GPU
|
| 153 |
|
|
|
|
|
|
|
|
|
|
| 154 |
# Try different methods on InferenceCore
|
| 155 |
result = None
|
| 156 |
|
|
@@ -224,10 +237,18 @@ def _refine_batch_with_matanyone(
|
|
| 224 |
|
| 225 |
# Prepare first mask for initialization
|
| 226 |
first_mask = masks[0]
|
| 227 |
-
|
| 228 |
-
|
| 229 |
if first_mask.ndim == 3:
|
| 230 |
-
first_mask =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
|
| 232 |
# Convert first mask to tensor and move to GPU
|
| 233 |
first_mask_tensor = torch.from_numpy(first_mask).float() / 255.0
|
|
@@ -262,6 +283,13 @@ def _refine_batch_with_matanyone(
|
|
| 262 |
# Fallback to processing each frame with its mask
|
| 263 |
log.warning("MatAnyone batch processing not available, using frame-by-frame")
|
| 264 |
for frame_tensor, mask in zip(frame_tensors, masks):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
mask_tensor = torch.from_numpy(mask).float() / 255.0
|
| 266 |
mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).to(device)
|
| 267 |
frame_on_device = frame_tensor.unsqueeze(0).to(device)
|
|
|
|
| 141 |
image_tensor = torch.from_numpy(image_rgb).permute(2, 0, 1).float() / 255.0
|
| 142 |
image_tensor = image_tensor.unsqueeze(0).to(device) # Add batch dimension and move to GPU
|
| 143 |
|
| 144 |
+
# CRITICAL: Ensure mask is 2D before processing
|
| 145 |
+
if mask.ndim == 3:
|
| 146 |
+
# Convert multi-channel to single channel
|
| 147 |
+
if mask.shape[2] == 3:
|
| 148 |
+
mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
|
| 149 |
+
else:
|
| 150 |
+
mask = mask[:, :, 0]
|
| 151 |
+
|
| 152 |
# Ensure mask is binary uint8
|
| 153 |
if mask.dtype != np.uint8:
|
| 154 |
mask = (mask * 255).astype(np.uint8) if mask.max() <= 1 else mask.astype(np.uint8)
|
| 155 |
+
|
| 156 |
+
# Final verification that mask is 2D
|
| 157 |
+
assert mask.ndim == 2, f"Mask must be 2D after conversion, got shape {mask.shape}"
|
| 158 |
+
assert mask.shape == (h, w), f"Mask shape {mask.shape} doesn't match image shape ({h}, {w})"
|
| 159 |
|
| 160 |
# Convert mask to tensor and move to GPU
|
| 161 |
mask_tensor = torch.from_numpy(mask).float() / 255.0
|
| 162 |
mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).to(device) # (1, 1, H, W) on GPU
|
| 163 |
|
| 164 |
+
# Verify tensor dimensions
|
| 165 |
+
assert mask_tensor.shape == (1, 1, h, w), f"Mask tensor wrong shape: {mask_tensor.shape}, expected (1, 1, {h}, {w})"
|
| 166 |
+
|
| 167 |
# Try different methods on InferenceCore
|
| 168 |
result = None
|
| 169 |
|
|
|
|
| 237 |
|
| 238 |
# Prepare first mask for initialization
|
| 239 |
first_mask = masks[0]
|
| 240 |
+
|
| 241 |
+
# CRITICAL: Ensure first mask is 2D
|
| 242 |
if first_mask.ndim == 3:
|
| 243 |
+
if first_mask.shape[2] == 3:
|
| 244 |
+
first_mask = cv2.cvtColor(first_mask, cv2.COLOR_BGR2GRAY)
|
| 245 |
+
else:
|
| 246 |
+
first_mask = first_mask[:, :, 0]
|
| 247 |
+
|
| 248 |
+
if first_mask.dtype != np.uint8:
|
| 249 |
+
first_mask = (first_mask * 255).astype(np.uint8) if first_mask.max() <= 1 else first_mask.astype(np.uint8)
|
| 250 |
+
|
| 251 |
+
assert first_mask.ndim == 2, f"First mask must be 2D, got shape {first_mask.shape}"
|
| 252 |
|
| 253 |
# Convert first mask to tensor and move to GPU
|
| 254 |
first_mask_tensor = torch.from_numpy(first_mask).float() / 255.0
|
|
|
|
| 283 |
# Fallback to processing each frame with its mask
|
| 284 |
log.warning("MatAnyone batch processing not available, using frame-by-frame")
|
| 285 |
for frame_tensor, mask in zip(frame_tensors, masks):
|
| 286 |
+
# Ensure each mask is 2D
|
| 287 |
+
if mask.ndim == 3:
|
| 288 |
+
if mask.shape[2] == 3:
|
| 289 |
+
mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
|
| 290 |
+
else:
|
| 291 |
+
mask = mask[:, :, 0]
|
| 292 |
+
|
| 293 |
mask_tensor = torch.from_numpy(mask).float() / 255.0
|
| 294 |
mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).to(device)
|
| 295 |
frame_on_device = frame_tensor.unsqueeze(0).to(device)
|