MogensR committed on
Commit
e94d263
·
1 Parent(s): 19a2b07

Update utils/refinement.py

Browse files
Files changed (1) hide show
  1. utils/refinement.py +33 -5
utils/refinement.py CHANGED
@@ -141,16 +141,29 @@ def _refine_with_matanyone(
141
  image_tensor = torch.from_numpy(image_rgb).permute(2, 0, 1).float() / 255.0
142
  image_tensor = image_tensor.unsqueeze(0).to(device) # Add batch dimension and move to GPU
143
 
 
 
 
 
 
 
 
 
144
  # Ensure mask is binary uint8
145
  if mask.dtype != np.uint8:
146
  mask = (mask * 255).astype(np.uint8) if mask.max() <= 1 else mask.astype(np.uint8)
147
- if mask.ndim == 3:
148
- mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
 
 
149
 
150
  # Convert mask to tensor and move to GPU
151
  mask_tensor = torch.from_numpy(mask).float() / 255.0
152
  mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).to(device) # (1, 1, H, W) on GPU
153
 
 
 
 
154
  # Try different methods on InferenceCore
155
  result = None
156
 
@@ -224,10 +237,18 @@ def _refine_batch_with_matanyone(
224
 
225
  # Prepare first mask for initialization
226
  first_mask = masks[0]
227
- if first_mask.dtype != np.uint8:
228
- first_mask = (first_mask * 255).astype(np.uint8)
229
  if first_mask.ndim == 3:
230
- first_mask = cv2.cvtColor(first_mask, cv2.COLOR_BGR2GRAY)
 
 
 
 
 
 
 
 
231
 
232
  # Convert first mask to tensor and move to GPU
233
  first_mask_tensor = torch.from_numpy(first_mask).float() / 255.0
@@ -262,6 +283,13 @@ def _refine_batch_with_matanyone(
262
  # Fallback to processing each frame with its mask
263
  log.warning("MatAnyone batch processing not available, using frame-by-frame")
264
  for frame_tensor, mask in zip(frame_tensors, masks):
 
 
 
 
 
 
 
265
  mask_tensor = torch.from_numpy(mask).float() / 255.0
266
  mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).to(device)
267
  frame_on_device = frame_tensor.unsqueeze(0).to(device)
 
141
  image_tensor = torch.from_numpy(image_rgb).permute(2, 0, 1).float() / 255.0
142
  image_tensor = image_tensor.unsqueeze(0).to(device) # Add batch dimension and move to GPU
143
 
144
+ # CRITICAL: Ensure mask is 2D before processing
145
+ if mask.ndim == 3:
146
+ # Convert multi-channel to single channel
147
+ if mask.shape[2] == 3:
148
+ mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
149
+ else:
150
+ mask = mask[:, :, 0]
151
+
152
  # Ensure mask is binary uint8
153
  if mask.dtype != np.uint8:
154
  mask = (mask * 255).astype(np.uint8) if mask.max() <= 1 else mask.astype(np.uint8)
155
+
156
+ # Final verification that mask is 2D
157
+ assert mask.ndim == 2, f"Mask must be 2D after conversion, got shape {mask.shape}"
158
+ assert mask.shape == (h, w), f"Mask shape {mask.shape} doesn't match image shape ({h}, {w})"
159
 
160
  # Convert mask to tensor and move to GPU
161
  mask_tensor = torch.from_numpy(mask).float() / 255.0
162
  mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).to(device) # (1, 1, H, W) on GPU
163
 
164
+ # Verify tensor dimensions
165
+ assert mask_tensor.shape == (1, 1, h, w), f"Mask tensor wrong shape: {mask_tensor.shape}, expected (1, 1, {h}, {w})"
166
+
167
  # Try different methods on InferenceCore
168
  result = None
169
 
 
237
 
238
  # Prepare first mask for initialization
239
  first_mask = masks[0]
240
+
241
+ # CRITICAL: Ensure first mask is 2D
242
  if first_mask.ndim == 3:
243
+ if first_mask.shape[2] == 3:
244
+ first_mask = cv2.cvtColor(first_mask, cv2.COLOR_BGR2GRAY)
245
+ else:
246
+ first_mask = first_mask[:, :, 0]
247
+
248
+ if first_mask.dtype != np.uint8:
249
+ first_mask = (first_mask * 255).astype(np.uint8) if first_mask.max() <= 1 else first_mask.astype(np.uint8)
250
+
251
+ assert first_mask.ndim == 2, f"First mask must be 2D, got shape {first_mask.shape}"
252
 
253
  # Convert first mask to tensor and move to GPU
254
  first_mask_tensor = torch.from_numpy(first_mask).float() / 255.0
 
283
  # Fallback to processing each frame with its mask
284
  log.warning("MatAnyone batch processing not available, using frame-by-frame")
285
  for frame_tensor, mask in zip(frame_tensors, masks):
286
+ # Ensure each mask is 2D
287
+ if mask.ndim == 3:
288
+ if mask.shape[2] == 3:
289
+ mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
290
+ else:
291
+ mask = mask[:, :, 0]
292
+
293
  mask_tensor = torch.from_numpy(mask).float() / 255.0
294
  mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).to(device)
295
  frame_on_device = frame_tensor.unsqueeze(0).to(device)