MogensR committed on
Commit
28e0f6c
·
1 Parent(s): e21220a

Create utils/two_stage_processor.py

Browse files
Files changed (1) hide show
  1. utils/two_stage_processor.py +306 -0
utils/two_stage_processor.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Fixed SAM2 + MatAnyone Integration
3
+ Corrects tensor dimension mismatches and ensures proper model cooperation
4
+ """
5
+
6
+ import torch
7
+ import numpy as np
8
+ import cv2
9
+ from typing import Optional, Tuple, List
10
+ import logging
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
class TwoStageProcessor:
    """Two-stage background removal: SAM2 segmentation followed by MatAnyone matting.

    Stage 1 obtains a binary subject mask from SAM2; stage 2 refines it into a
    soft alpha matte with MatAnyone. Every stage has a fallback so a frame is
    always returned, even when a model is missing or raises.
    """

    def __init__(self, sam2_model, matanyone_model, device='cuda'):
        """
        Args:
            sam2_model: SAM2 model or None. Used when it exposes ``generate_mask``.
            matanyone_model: MatAnyone model or None. Invoked as ``model(image, mask)``.
            device: Torch device string ('cuda' or 'cpu').
        """
        self.sam2 = sam2_model
        self.matanyone = matanyone_model
        self.device = device
        # bfloat16 autocast is a CUDA feature; keep it disabled on CPU so
        # torch.cuda.amp does not warn or fail on non-CUDA devices.
        self._use_amp = 'cuda' in str(device)
        logger.info(f"TwoStageProcessor initialized on {device}")

    def process_frame(self, frame: np.ndarray, prev_mask: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray]:
        """
        Process a single frame through SAM2 + MatAnyone.

        Args:
            frame: RGB frame (H, W, 3) as numpy array
            prev_mask: Optional previous frame mask for temporal consistency

        Returns:
            processed_frame: Frame with background removed (H, W, 4) RGBA
            mask: Binary mask (H, W) as uint8 (values 0/255)
        """
        H, W = frame.shape[:2]

        try:
            # Step 1: segmentation mask from SAM2 (falls back internally, never None)
            mask = self._get_sam2_mask(frame, prev_mask)

            # Step 2: high-quality matting with MatAnyone, when available
            if self.matanyone is not None and mask is not None:
                processed = self._process_with_matanyone(frame, mask)
                if processed is not None:
                    return processed, mask

            # Fallback: simple alpha composite if MatAnyone is absent or failed
            return self._simple_composite(frame, mask), mask

        except Exception as e:
            logger.error(f"Frame processing failed: {e}")
            # Last resort: return the original frame fully opaque
            rgba = np.zeros((H, W, 4), dtype=np.uint8)
            rgba[:, :, :3] = frame
            rgba[:, :, 3] = 255
            return rgba, np.ones((H, W), dtype=np.uint8) * 255

    def _get_sam2_mask(self, frame: np.ndarray, prev_mask: Optional[np.ndarray]) -> np.ndarray:
        """Get a binary (H, W) uint8 segmentation mask from SAM2, with fallback."""
        H, W = frame.shape[:2]

        try:
            if hasattr(self.sam2, 'generate_mask'):
                # Proper SAM2 call
                with torch.cuda.amp.autocast(enabled=self._use_amp, dtype=torch.bfloat16):
                    # (H, W, 3) uint8 -> (1, 3, H, W) float in [0, 1]
                    frame_tensor = torch.from_numpy(frame).to(self.device).float() / 255.0
                    frame_tensor = frame_tensor.permute(2, 0, 1).unsqueeze(0)

                    if prev_mask is not None:
                        prev_mask_tensor = torch.from_numpy(prev_mask).to(self.device).float() / 255.0
                        prev_mask_tensor = prev_mask_tensor.unsqueeze(0).unsqueeze(0)  # (1, 1, H, W)
                        mask_logits = self.sam2.generate_mask(frame_tensor, prev_mask_tensor)
                    else:
                        mask_logits = self.sam2.generate_mask(frame_tensor)

                # .float() first: numpy cannot represent bfloat16, so a direct
                # .cpu().numpy() on an autocast output would raise.
                mask = (mask_logits.squeeze().float().cpu().numpy() > 0).astype(np.uint8) * 255
                return mask
            else:
                # Fallback SAM2 - create center-weighted mask
                logger.warning("Using fallback mask generation")
                return self._generate_center_mask(H, W)

        except Exception as e:
            logger.error(f"SAM2 mask generation failed: {e}")
            return self._generate_center_mask(H, W)

    def _generate_center_mask(self, H: int, W: int) -> np.ndarray:
        """Generate a center-weighted elliptical mask as fallback."""
        mask = np.zeros((H, W), dtype=np.uint8)
        center_x, center_y = W // 2, H // 2
        axes_x, axes_y = W // 3, H // 3

        # Fill an axis-aligned ellipse covering the central third of the frame
        y, x = np.ogrid[:H, :W]
        mask_area = ((x - center_x) / axes_x) ** 2 + ((y - center_y) / axes_y) ** 2 <= 1
        mask[mask_area] = 255

        # Smooth, then re-threshold so the result stays strictly binary
        mask = cv2.GaussianBlur(mask, (21, 21), 10)
        mask = (mask > 128).astype(np.uint8) * 255

        return mask

    def _process_with_matanyone(self, frame: np.ndarray, mask: np.ndarray) -> Optional[np.ndarray]:
        """Process frame with MatAnyone for high-quality matting.

        Returns an (H, W, 4) RGBA uint8 image, or None on any failure so the
        caller can fall back to a simple composite.
        """
        try:
            H, W = frame.shape[:2]

            # Frame must be (H, W, 3) uint8
            if frame.dtype != np.uint8:
                frame = (frame * 255).astype(np.uint8) if frame.max() <= 1 else frame.astype(np.uint8)

            # Mask must be (H, W, 1) float32 normalized to [0, 1]
            mask_input = mask.astype(np.float32) / 255.0
            if len(mask_input.shape) == 2:
                mask_input = np.expand_dims(mask_input, axis=2)  # (H, W, 1)

            with torch.cuda.amp.autocast(enabled=self._use_amp, dtype=torch.bfloat16):
                # Convert to tensors with the dimensions MatAnyone expects
                frame_tensor = torch.from_numpy(frame).to(self.device).float() / 255.0
                frame_tensor = frame_tensor.permute(2, 0, 1).unsqueeze(0)  # (1, 3, H, W)

                mask_tensor = torch.from_numpy(mask_input).to(self.device).float()
                mask_tensor = mask_tensor.permute(2, 0, 1).unsqueeze(0)  # (1, 1, H, W)

                if hasattr(self.matanyone, '__call__'):
                    # MatAnyone expects: image (1, 3, H, W), mask (1, 1, H, W)
                    result = self.matanyone(frame_tensor, mask_tensor)

                    if result is not None:
                        # Extract the alpha matte; tuple results put alpha first
                        alpha = result[0] if isinstance(result, tuple) else result

                        # .float() before .numpy(): bfloat16 is not a numpy dtype
                        alpha = alpha.squeeze(0).squeeze(0).float().cpu().numpy()  # (H, W)
                        alpha = (alpha * 255).astype(np.uint8)

                        rgba = np.zeros((H, W, 4), dtype=np.uint8)
                        rgba[:, :, :3] = frame
                        rgba[:, :, 3] = alpha
                        return rgba
                elif hasattr(self.matanyone, 'process'):
                    # Alternative MatAnyone API
                    result = self.matanyone.process(frame, mask_input)
                    if result is not None:
                        return result

            return None

        except Exception as e:
            logger.warning(f"MatAnyone processing failed: {e}")
            return None

    def _simple_composite(self, frame: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """Simple RGBA composite as final fallback (mask becomes the alpha)."""
        H, W = frame.shape[:2]

        # Light edge refinement: remove speckles, then feather slightly
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
        mask = cv2.GaussianBlur(mask, (5, 5), 1)

        rgba = np.zeros((H, W, 4), dtype=np.uint8)
        rgba[:, :, :3] = frame
        rgba[:, :, 3] = mask

        return rgba

    def process_video(self, video_path: str, output_path: str, progress_callback=None):
        """Process an entire video frame-by-frame through the pipeline.

        Args:
            video_path: Input video file path.
            output_path: Output video file path (subject composited on green).
            progress_callback: Optional callable receiving progress in [0, 1].

        Raises:
            IOError: If the input video cannot be opened.
        """
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise IOError(f"Cannot open video: {video_path}")

        # Some containers report fps / frame count as 0; guard both so we
        # neither write a broken 0-fps file nor divide by zero below.
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        W = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        H = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # mp4v carries no alpha channel, so RGBA frames are composited onto
        # a green screen before writing.
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (W, H), True)

        prev_mask = None
        frame_idx = 0

        logger.info(f"Processing {total_frames} frames at {fps}fps")

        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                # OpenCV decodes BGR; the pipeline works in RGB
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                processed, mask = self.process_frame(frame_rgb, prev_mask)
                prev_mask = mask  # temporal consistency hint for the next frame

                if processed.shape[2] == 4:
                    # Composite RGBA onto pure green for codec compatibility
                    green_bg = np.zeros((H, W, 3), dtype=np.uint8)
                    green_bg[:, :, 1] = 255

                    alpha = processed[:, :, 3:4] / 255.0
                    rgb = processed[:, :, :3]

                    composited = (rgb * alpha + green_bg * (1 - alpha)).astype(np.uint8)
                    output_bgr = cv2.cvtColor(composited, cv2.COLOR_RGB2BGR)
                else:
                    output_bgr = cv2.cvtColor(processed, cv2.COLOR_RGB2BGR)

                out.write(output_bgr)

                frame_idx += 1
                if progress_callback and total_frames > 0:
                    progress_callback(frame_idx / total_frames)

                if frame_idx % 30 == 0:
                    logger.info(f"Processed {frame_idx}/{total_frames} frames")

            logger.info(f"Video processing complete: {output_path}")

        finally:
            cap.release()
            out.release()
            cv2.destroyAllWindows()
241
+
242
+
243
# Fix for the current MatAnyone loader issue
class MatAnyoneLoaderFix:
    """Fixes for the MatAnyone dimension mismatch issues."""

    @staticmethod
    def fix_matanyone_call(matanyone_model):
        """Patch *matanyone_model* so image/mask tensors are coerced to
        (1, 3, H, W) / (1, 1, H, W) before the real forward pass runs.

        Python resolves ``obj(...)`` through ``type(obj).__call__``, not the
        instance attribute, so the previous approach of assigning
        ``model.__call__ = wrapper`` never intercepted ``model(image, mask)``
        calls. The fix moves the instance onto a dynamically created subclass
        whose ``__call__`` performs the coercion and then delegates to the
        original call.

        Args:
            matanyone_model: The callable MatAnyone model instance (or None).

        Returns:
            The same model object, patched in place when possible.
        """
        if matanyone_model is None or not callable(matanyone_model):
            return matanyone_model

        # Bound method captured *before* patching; it keeps invoking the
        # original implementation after the instance's class is swapped.
        original_call = matanyone_model.__call__

        def _coerce(image, mask):
            """Normalize tensors to image (1, 3, H, W) and mask (1, 1, H, W)."""
            if len(image.shape) == 3:
                image = image.unsqueeze(0)
            if image.shape[1] != 3:
                # channels-last (N, H, W, C) -> channels-first
                image = image.permute(0, 3, 1, 2)

            if len(mask.shape) == 2:
                mask = mask.unsqueeze(0).unsqueeze(0)   # (H, W) -> (1, 1, H, W)
            elif len(mask.shape) == 3:
                if mask.shape[-1] == 1:
                    mask = mask.permute(2, 0, 1)        # (H, W, 1) -> (1, H, W)
                mask = mask.unsqueeze(0)                # -> (1, 1, H, W)

            # Resample the mask when its spatial size differs from the image
            if image.shape[-2:] != mask.shape[-2:]:
                mask = torch.nn.functional.interpolate(
                    mask, size=image.shape[-2:], mode='bilinear', align_corners=False
                )
            return image, mask

        def fixed_call(_self, image, mask, *args, **kwargs):
            try:
                image, mask = _coerce(image, mask)
                return original_call(image, mask, *args, **kwargs)
            except Exception as e:
                logger.error(f"MatAnyone call fix failed: {e}")
                return None

        try:
            base = type(matanyone_model)
            patched = type(base.__name__ + 'DimFixed', (base,), {'__call__': fixed_call})
            matanyone_model.__class__ = patched
        except TypeError:
            # Some extension types forbid __class__ reassignment; fall back to
            # the instance attribute (only helps explicit model.__call__(...)
            # users, as obj(...) bypasses instance attributes).
            matanyone_model.__call__ = (
                lambda image, mask, *a, **kw: fixed_call(matanyone_model, image, mask, *a, **kw)
            )

        return matanyone_model
290
+
291
+
292
# Integration with existing code
def initialize_two_stage_processor(sam2_loader, matanyone_loader, device='cuda'):
    """Build the fixed two-stage processor from the project's model loaders.

    The MatAnyone model is wrapped with the dimension-coercion fix before
    being handed to the processor. Either loader may be None.

    Args:
        sam2_loader: Loader object exposing ``.model`` for SAM2, or None.
        matanyone_loader: Loader object exposing ``.model`` for MatAnyone, or None.
        device: Torch device string forwarded to the processor.

    Returns:
        A configured TwoStageProcessor instance.
    """
    # Patch MatAnyone so malformed tensor shapes are coerced before inference
    if matanyone_loader and hasattr(matanyone_loader, 'model'):
        matanyone_loader.model = MatAnyoneLoaderFix.fix_matanyone_call(matanyone_loader.model)

    sam2_model = sam2_loader.model if sam2_loader else None
    matanyone_model = matanyone_loader.model if matanyone_loader else None

    return TwoStageProcessor(
        sam2_model=sam2_model,
        matanyone_model=matanyone_model,
        device=device,
    )