Secking Claude committed on
Commit
32d4b75
·
1 Parent(s): 97106f1

Implement full Modal segmentation integration for Segmentation Editing tab

Browse files

- Add session state management (frames_state, masks_state, magic_code_state)
- Implement download_segment_files() to fetch frames and alpha masks from Modal volume
- Add composite_image_with_mask() for alpha overlay rendering
- Implement handle_segment_selection() with automatic file download and slider init
- Add handle_image_click() with SAM-3 segmentation function integration via Modal
- Implement load_segment_frame() with show/hide mask toggle support
- Add "Show Alpha Mask" checkbox to UI (Row 3)
- Wire all event handlers: segment selection, frame slider, mask toggle, image click
- Implement download_segment() to export edited frames + masks as ZIP
- Add Pillow and numpy to requirements.txt
- Update submit_magic_code() output signature to match new row visibility

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (2) hide show
  1. app.py +313 -27
  2. requirements.txt +2 -0
app.py CHANGED
@@ -8,11 +8,16 @@
8
  import os
9
  import re
10
  import cv2
 
11
  import time
12
  import base64
13
  import modal
14
  import logging
 
 
15
  import gradio as gr
 
 
16
 
17
  logging.basicConfig(level=logging.DEBUG)
18
  logger = logging.getLogger(__name__)
@@ -189,6 +194,7 @@ def submit_magic_code(magic_code):
189
  gr.update(visible=False),
190
  gr.update(visible=False),
191
  gr.update(visible=False),
 
192
  gr.update(visible=False)
193
  )
194
 
@@ -199,6 +205,7 @@ def submit_magic_code(magic_code):
199
  gr.update(visible=True),
200
  gr.update(visible=True),
201
  gr.update(visible=True),
 
202
  gr.update(visible=True)
203
  )
204
 
@@ -210,6 +217,7 @@ def submit_magic_code(magic_code):
210
  gr.update(visible=False),
211
  gr.update(visible=False),
212
  gr.update(visible=False),
 
213
  gr.update(visible=False)
214
  )
215
 
@@ -221,33 +229,280 @@ def submit_magic_code(magic_code):
221
  gr.update(visible=False),
222
  gr.update(visible=False),
223
  gr.update(visible=False),
 
224
  gr.update(visible=False)
225
  )
226
 
227
 
228
- def load_segment_frame(segment_id, frame_number):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  """
230
- Placeholder function to load a specific frame from a segment for editing.
231
- TODO: Connect to backend API to retrieve frame image from video segment.
232
  """
233
- if not segment_id:
234
- return None
 
 
 
235
 
236
- logger.info(f"Loading segment: {segment_id}, frame: {frame_number}")
237
- # Placeholder: Return None (no image loaded yet)
238
- return None
239
 
 
 
 
240
 
241
- def download_segment(segment_id, edited_image, frame_number):
 
 
 
 
 
 
 
 
 
 
 
242
  """
243
- Placeholder function to download edited segment.
244
- TODO: Connect to backend API to save edited segmentation and download result.
245
  """
246
- if not segment_id:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  return "No segment selected"
248
 
249
  logger.info(f"Download requested for segment: {segment_id}")
250
- return f"Download functionality coming soon for {segment_id}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
 
252
 
253
  # Create a professional Gradio interface using the Golden ratio (1.618) for proportions
@@ -480,6 +735,11 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="indigo", secondary_hue
480
  with gr.Tab("Segmentation Editing"):
481
  gr.Markdown("### Segmentation Editing Workspace")
482
 
 
 
 
 
 
483
  # Row 1: Magic Code textbox + Submit button
484
  with gr.Row(elem_classes="input-section"):
485
  with gr.Column(scale=3):
@@ -500,8 +760,16 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="indigo", secondary_hue
500
  info="Select a segment to edit"
501
  )
502
 
503
- # Row 3: ImageEditor component (SAM-3 style interaction)
504
- with gr.Row(elem_classes="input-section", visible=False) as seg_row3:
 
 
 
 
 
 
 
 
505
  seg_image_editor = gr.ImageEditor(
506
  label="Image Click Segmentation",
507
  type="pil",
@@ -509,8 +777,8 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="indigo", secondary_hue
509
  brush=gr.Brush(colors=["#FF0000"], color_mode="fixed")
510
  )
511
 
512
- # Row 4: Frame number slider
513
- with gr.Row(elem_classes="input-section", visible=False) as seg_row4:
514
  seg_frame_slider = gr.Slider(
515
  minimum=0,
516
  maximum=100,
@@ -521,9 +789,10 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="indigo", secondary_hue
521
  info="Select video frame to segment"
522
  )
523
 
524
- # Row 5: Download Segment button
525
- with gr.Row(visible=False) as seg_row5:
526
  seg_download_btn = gr.Button("Download Segment", variant="secondary")
 
527
 
528
  # Wire Content Moderation processing
529
  cm_process_btn.click(
@@ -536,25 +805,42 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="indigo", secondary_hue
536
  seg_submit_btn.click(
537
  fn=submit_magic_code,
538
  inputs=[seg_magic_code],
539
- outputs=[seg_id_dropdown, seg_magic_code, seg_row2, seg_row3, seg_row4, seg_row5]
540
  )
541
 
 
542
  seg_id_dropdown.change(
543
- fn=load_segment_frame,
544
- inputs=[seg_id_dropdown, seg_frame_slider],
545
- outputs=[seg_image_editor]
546
  )
547
 
 
548
  seg_frame_slider.change(
549
  fn=load_segment_frame,
550
- inputs=[seg_id_dropdown, seg_frame_slider],
551
- outputs=[seg_image_editor]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
552
  )
553
 
 
554
  seg_download_btn.click(
555
  fn=download_segment,
556
- inputs=[seg_id_dropdown, seg_image_editor, seg_frame_slider],
557
- outputs=[seg_magic_code] # Display status message in magic code textbox
558
  )
559
 
560
  if __name__ == "__main__":
 
8
  import os
9
  import re
10
  import cv2
11
+ import io
12
  import time
13
  import base64
14
  import modal
15
  import logging
16
+ import tempfile
17
+ import numpy as np
18
  import gradio as gr
19
+ from PIL import Image
20
+ from typing import Dict, Optional, Tuple
21
 
22
  logging.basicConfig(level=logging.DEBUG)
23
  logger = logging.getLogger(__name__)
 
194
  gr.update(visible=False),
195
  gr.update(visible=False),
196
  gr.update(visible=False),
197
+ gr.update(visible=False),
198
  gr.update(visible=False)
199
  )
200
 
 
205
  gr.update(visible=True),
206
  gr.update(visible=True),
207
  gr.update(visible=True),
208
+ gr.update(visible=True),
209
  gr.update(visible=True)
210
  )
211
 
 
217
  gr.update(visible=False),
218
  gr.update(visible=False),
219
  gr.update(visible=False),
220
+ gr.update(visible=False),
221
  gr.update(visible=False)
222
  )
223
 
 
229
  gr.update(visible=False),
230
  gr.update(visible=False),
231
  gr.update(visible=False),
232
+ gr.update(visible=False),
233
  gr.update(visible=False)
234
  )
235
 
236
 
237
def _read_volume_image(volume, remote_path):
    """Fetch ``remote_path`` from ``volume`` and decode it into a fully loaded PIL image."""
    payload = volume.read_file(remote_path)
    # Modal documents Volume.read_file as yielding the file in byte chunks, which
    # io.BytesIO cannot consume directly; join them first. A plain bytes payload is
    # accepted as well so the helper works with either return shape.
    if not isinstance(payload, (bytes, bytearray)):
        payload = b"".join(payload)
    with Image.open(io.BytesIO(payload)) as decoded:
        # copy() forces the pixel data to load so the image outlives the buffer.
        return decoded.copy()


def download_segment_files(magic_code: str, segment_id: str) -> Tuple[Dict[int, Image.Image], Dict[int, Image.Image], int]:
    """
    Download all frame images and alpha masks for a given segment from Modal volume.

    Files are listed under ``/{magic_code}/{segment_id}`` in the volume named by the
    ``MODERATION_MODAL_VOLUME`` environment variable. Entries named
    ``frame_<n>.jpg`` / ``frame_<n>.png`` are collected as frames and
    ``alpha_frame_<n>.png`` as masks, keyed by ``int(<n>)``. A file that fails to
    download or decode is logged and skipped; the remaining files are still processed.

    Args:
        magic_code: First path component of the segment directory.
        segment_id: Second path component of the segment directory.

    Returns:
        frames_dict: Dict mapping frame_index -> PIL.Image (original frames)
        masks_dict: Dict mapping frame_index -> PIL.Image (alpha masks)
        max_frame: Maximum frame index found (0 when no frames were downloaded)

        ``({}, {}, 0)`` is returned when the volume cannot be opened, the
        directory cannot be listed, or any other unexpected error occurs.
    """
    frame_pattern = re.compile(r'^frame_(\d+)\.(jpg|png)$')
    alpha_pattern = re.compile(r'^alpha_frame_(\d+)\.png$')

    try:
        modal_volume_name = os.environ['MODERATION_MODAL_VOLUME']
        volume = modal.Volume.from_name(modal_volume_name)

        segment_path = f"/{magic_code}/{segment_id}"
        frames_dict = {}
        masks_dict = {}

        logger.info(f"Downloading files from {segment_path}")

        # List all files in the segment directory
        try:
            files = list(volume.listdir(segment_path))
        except Exception as e:
            logger.error(f"Failed to list segment directory: {e}")
            return {}, {}, 0

        # Both patterns are anchored, so a filename can match at most one of them.
        targets = (
            (frame_pattern, frames_dict, "frame"),
            (alpha_pattern, masks_dict, "alpha mask"),
        )

        for entry in files:
            if entry.type != modal.volume.FileEntryType.FILE:
                continue

            filename = os.path.basename(entry.path)

            for pattern, bucket, label in targets:
                match = pattern.match(filename)
                if match is None:
                    continue

                frame_idx = int(match.group(1))
                try:
                    bucket[frame_idx] = _read_volume_image(volume, f"{segment_path}/{filename}")
                    logger.debug(f"Downloaded {label} {frame_idx}")
                except Exception as e:
                    # Best effort: one unreadable file must not abort the whole segment.
                    logger.error(f"Failed to download {filename}: {e}")
                break

        max_frame = max(frames_dict) if frames_dict else 0
        logger.info(f"Downloaded {len(frames_dict)} frames and {len(masks_dict)} masks. Max frame: {max_frame}")

        return frames_dict, masks_dict, max_frame

    except Exception as e:
        logger.error(f"Error downloading segment files: {e}")
        return {}, {}, 0
309
+
310
+
311
def composite_image_with_mask(
    frame: Image.Image,
    mask: Optional[Image.Image],
    show_mask: bool,
    overlay_color: Tuple[int, int, int] = (255, 0, 0),
    overlay_alpha: int = 128,
) -> Image.Image:
    """
    Composite the original frame with the alpha mask overlay.

    The mask is used as a per-pixel weight for a tinted overlay: where the mask
    is 255 the overlay is drawn at ``overlay_alpha`` opacity, where it is 0 the
    frame shows through unchanged.

    Args:
        frame: Original RGB/RGBA image
        mask: Alpha mask (grayscale or RGBA); resized to the frame size if it differs
        show_mask: Whether to show the mask overlay
        overlay_color: RGB colour of the tint (default red)
        overlay_alpha: Maximum overlay opacity in 0..255 (default 128, semi-transparent)

    Returns:
        Composited PIL Image. When ``show_mask`` is false or ``mask`` is None this
        is a copy of ``frame`` in its original mode; otherwise it is an RGBA image.
    """
    if not show_mask or mask is None:
        return frame.copy()

    frame_rgba = frame.convert('RGBA') if frame.mode != 'RGBA' else frame.copy()

    # NOTE(review): convert('L') derives the weight from the colour channels; if the
    # stored masks carry their information in an RGBA alpha channel instead, this
    # should read that channel - confirm against the mask producer.
    mask_gray = mask.convert('L') if mask.mode != 'L' else mask.copy()

    # Resize mask to match frame if needed
    if mask_gray.size != frame_rgba.size:
        mask_gray = mask_gray.resize(frame_rgba.size, Image.Resampling.LANCZOS)

    # putalpha() replaces the overlay's alpha channel outright, so the intended
    # semi-transparency has to be folded into the mask itself; otherwise fully
    # masked pixels become opaque and hide the underlying frame.
    opacity = max(0, min(255, int(overlay_alpha)))
    scaled_alpha = mask_gray.point(lambda value: value * opacity // 255)

    overlay = Image.new('RGBA', frame_rgba.size, tuple(overlay_color[:3]) + (0,))
    overlay.putalpha(scaled_alpha)

    return Image.alpha_composite(frame_rgba, overlay)
352
+
353
+
354
def load_segment_frame(segment_id, frame_number, show_mask, magic_code_state, frames_state, masks_state):
    """
    Load and display a specific frame with optional alpha mask overlay.

    Args:
        segment_id: Currently selected segment; falsy means nothing is selected.
        frame_number: Frame index requested by the slider (converted with ``int``).
        show_mask: Whether the alpha mask overlay should be drawn.
        magic_code_state: Accepted for the event wiring; not used here.
        frames_state: Dict mapping frame index -> frame image, or None before a
            segment has been loaded.
        masks_state: Dict mapping frame index -> mask image; None is treated as
            "no masks available".

    Returns:
        ``(image, gr.update())`` - the composited image (or None when there is
        nothing to show) and a no-op update for the second output component.
    """
    if not segment_id or frames_state is None:
        return None, gr.update()

    try:
        frame_idx = int(frame_number)
    except (TypeError, ValueError):
        # A cleared or malformed slider value must not crash the event handler.
        logger.warning(f"Invalid frame number: {frame_number!r}")
        return None, gr.update()

    if frame_idx not in frames_state:
        logger.warning(f"Frame {frame_idx} not found in downloaded frames")
        return None, gr.update()

    # Tolerate a missing mask dict so a frame can still be displayed without masks.
    available_masks = masks_state or {}

    result_image = composite_image_with_mask(
        frames_state[frame_idx],
        available_masks.get(frame_idx),
        show_mask,
    )

    logger.info(f"Loaded frame {frame_idx} with mask overlay: {show_mask}")

    return result_image, gr.update()
379
+
380
+
381
def handle_segment_selection(segment_id, magic_code):
    """
    Handle segment selection: download all files and initialize the view.

    Args:
        segment_id: Segment chosen in the dropdown.
        magic_code: Magic code used to locate the segment's files.

    Returns:
        A 5-tuple ``(image, slider_update, frames_dict, masks_dict, magic_code)``.
        When nothing is selected or no frames could be downloaded, the image is
        None, the slider is reset to ``maximum=0, value=0`` and both dicts are empty.
        Otherwise the first available frame is shown without its mask and the
        slider is positioned on that same frame.
    """
    if not segment_id or not magic_code:
        return None, gr.update(maximum=0, value=0), {}, {}, magic_code

    logger.info(f"Segment selected: {segment_id}")

    # Download all frames and masks
    frames_dict, masks_dict, max_frame = download_segment_files(magic_code, segment_id)

    if not frames_dict:
        logger.error("No frames downloaded")
        return None, gr.update(maximum=0, value=0), {}, {}, magic_code

    # Prefer frame 0; if it does not exist start at the lowest frame present and
    # point the slider there too, so the slider and the displayed image agree.
    first_idx = 0 if 0 in frames_dict else min(frames_dict)

    # Initial display without mask
    result_image = composite_image_with_mask(
        frames_dict[first_idx],
        masks_dict.get(first_idx),
        False,
    )

    return (
        result_image,
        gr.update(minimum=first_idx, maximum=max_frame, value=first_idx),
        frames_dict,
        masks_dict,
        magic_code,
    )
413
+
414
+
415
def handle_image_click(segment_id, frame_number, magic_code, frames_state, masks_state, evt: gr.SelectData):
    """
    Handle click on ImageEditor: send coordinates to SAM-3, get new mask, update display.

    Args:
        segment_id: Currently selected segment; falsy means nothing is selected.
        frame_number: Frame index shown in the editor (converted with ``int``).
        magic_code: Accepted for the event wiring; not used here.
        frames_state: Dict mapping frame index -> frame image, or None.
        masks_state: Dict mapping frame index -> mask image, or None. It is never
            mutated; an updated copy is returned on success.
        evt: Gradio select event; ``evt.index`` supplies the click ``(x, y)``.

    Returns:
        ``(image, masks)``:
        - on success, the frame composited with the new mask and a masks dict
          containing the new mask for this frame;
        - if the remote segmentation call fails, the frame composited with the
          existing mask (if any) and the unchanged ``masks_state``;
        - ``(None, masks_state)`` for missing inputs, unknown frames or any
          other error.
    """
    if not segment_id or frames_state is None or evt is None:
        return None, masks_state

    # Extract click coordinates
    x, y = evt.index[0], evt.index[1]
    frame_idx = int(frame_number)

    logger.info(f"Click detected at ({x}, {y}) on frame {frame_idx}")

    if frame_idx not in frames_state:
        logger.error(f"Frame {frame_idx} not in state")
        return None, masks_state

    try:
        frame = frames_state[frame_idx]

        # Work on a shallow copy: tolerates masks_state being None and guarantees
        # the stored state is left untouched on every failure path.
        current_masks = dict(masks_state) if masks_state else {}

        # Serialize the frame as PNG for the remote call.
        buffer = io.BytesIO()
        frame.save(buffer, format='PNG')
        img_bytes = buffer.getvalue()

        logger.info(f"Calling SAM-3 segmentation with coordinates ({x}, {y})")

        try:
            sam_function = modal.Function.from_name("Content-Moderation", "sam3_segmentation_function")
            # Presumably returns an encoded image as bytes - confirm against the
            # deployed Modal function.
            mask_bytes = sam_function.remote(image_bytes=img_bytes, x=x, y=y)

            with Image.open(io.BytesIO(mask_bytes)) as decoded:
                new_mask = decoded.copy()

            current_masks[frame_idx] = new_mask

            # Composite and return updated image (always show mask after click)
            result_image = composite_image_with_mask(frame, new_mask, True)

            logger.info(f"Successfully updated mask for frame {frame_idx}")

            return result_image, current_masks

        except Exception as e:
            logger.error(f"Error calling SAM-3 function: {e}")
            # Return current view without update
            previous_mask = (masks_state or {}).get(frame_idx)
            result_image = composite_image_with_mask(frame, previous_mask, True)
            return result_image, masks_state

    except Exception as e:
        logger.error(f"Error handling image click: {e}")
        return None, masks_state
472
+
473
+
474
def download_segment(segment_id, frames_state, masks_state):
    """
    Package and download the edited segment (frames + masks) as a ZIP file.

    Every frame is written as ``frame_<idx:06d>.png`` and, when a mask exists for
    that index, ``alpha_frame_<idx:06d>.png``. The archive is created in its own
    freshly made temporary directory so requests for the same segment id do not
    overwrite each other.

    Args:
        segment_id: Selected segment id; used (sanitised) in the archive name.
        frames_state: Dict mapping frame index -> image object with ``save(path)``.
        masks_state: Dict mapping frame index -> mask image, or None for no masks.

    Returns:
        Path to the created ZIP file, or None when nothing is selected or the
        archive could not be built. The result feeds a file output component, so
        a status string must not be returned here: it would be treated as a path.
    """
    if not segment_id or not frames_state:
        return None

    logger.info(f"Download requested for segment: {segment_id}")

    import shutil

    masks = masks_state or {}
    # Keep only filename-safe characters so the id cannot inject path separators.
    safe_id = re.sub(r'[^A-Za-z0-9._-]', '_', str(segment_id))

    try:
        # Stage the images in a scratch directory that is removed afterwards.
        with tempfile.TemporaryDirectory() as staging_dir:
            for frame_idx, frame_img in frames_state.items():
                frame_img.save(os.path.join(staging_dir, f"frame_{frame_idx:06d}.png"))

                mask_img = masks.get(frame_idx)
                if mask_img is not None:
                    mask_img.save(os.path.join(staging_dir, f"alpha_frame_{frame_idx:06d}.png"))

            # The archive must outlive the staging directory, so it gets its own
            # directory. make_archive appends ".zip" and returns the final path.
            output_dir = tempfile.mkdtemp(prefix="segment_download_")
            zip_path = shutil.make_archive(
                os.path.join(output_dir, f"segment_{safe_id}"), 'zip', staging_dir
            )

        logger.info(f"Created ZIP at {zip_path}")
        # NOTE(review): the file component receiving this value is created with
        # visible=False and nothing in view makes it visible - confirm the UI
        # actually exposes the download.
        return zip_path

    except Exception as e:
        logger.error(f"Error creating download package: {e}")
        return None
506
 
507
 
508
  # Create a professional Gradio interface using the Golden ratio (1.618) for proportions
 
735
  with gr.Tab("Segmentation Editing"):
736
  gr.Markdown("### Segmentation Editing Workspace")
737
 
738
+ # State management for session data
739
+ magic_code_state = gr.State(value=None)
740
+ frames_state = gr.State(value=None)
741
+ masks_state = gr.State(value=None)
742
+
743
  # Row 1: Magic Code textbox + Submit button
744
  with gr.Row(elem_classes="input-section"):
745
  with gr.Column(scale=3):
 
760
  info="Select a segment to edit"
761
  )
762
 
763
+ # Row 3: Show Alpha Mask checkbox
764
+ with gr.Row(elem_classes="input-section", visible=False) as seg_row3_checkbox:
765
+ seg_show_mask = gr.Checkbox(
766
+ label="Show Alpha Mask",
767
+ value=False,
768
+ interactive=True
769
+ )
770
+
771
+ # Row 4: ImageEditor component (SAM-3 style interaction)
772
+ with gr.Row(elem_classes="input-section", visible=False) as seg_row4:
773
  seg_image_editor = gr.ImageEditor(
774
  label="Image Click Segmentation",
775
  type="pil",
 
777
  brush=gr.Brush(colors=["#FF0000"], color_mode="fixed")
778
  )
779
 
780
+ # Row 5: Frame number slider
781
+ with gr.Row(elem_classes="input-section", visible=False) as seg_row5:
782
  seg_frame_slider = gr.Slider(
783
  minimum=0,
784
  maximum=100,
 
789
  info="Select video frame to segment"
790
  )
791
 
792
+ # Row 6: Download Segment button
793
+ with gr.Row(visible=False) as seg_row6:
794
  seg_download_btn = gr.Button("Download Segment", variant="secondary")
795
+ seg_download_file = gr.File(label="Download", visible=False)
796
 
797
  # Wire Content Moderation processing
798
  cm_process_btn.click(
 
805
  seg_submit_btn.click(
806
  fn=submit_magic_code,
807
  inputs=[seg_magic_code],
808
+ outputs=[seg_id_dropdown, seg_magic_code, seg_row2, seg_row3_checkbox, seg_row4, seg_row5, seg_row6]
809
  )
810
 
811
+ # Segment selection handler
812
  seg_id_dropdown.change(
813
+ fn=handle_segment_selection,
814
+ inputs=[seg_id_dropdown, seg_magic_code],
815
+ outputs=[seg_image_editor, seg_frame_slider, frames_state, masks_state, magic_code_state]
816
  )
817
 
818
+ # Frame slider handler
819
  seg_frame_slider.change(
820
  fn=load_segment_frame,
821
+ inputs=[seg_id_dropdown, seg_frame_slider, seg_show_mask, magic_code_state, frames_state, masks_state],
822
+ outputs=[seg_image_editor, seg_show_mask]
823
+ )
824
+
825
+ # Show mask checkbox handler
826
+ seg_show_mask.change(
827
+ fn=load_segment_frame,
828
+ inputs=[seg_id_dropdown, seg_frame_slider, seg_show_mask, magic_code_state, frames_state, masks_state],
829
+ outputs=[seg_image_editor, seg_show_mask]
830
+ )
831
+
832
+ # Image click handler (for SAM-3 segmentation)
833
+ seg_image_editor.select(
834
+ fn=handle_image_click,
835
+ inputs=[seg_id_dropdown, seg_frame_slider, magic_code_state, frames_state, masks_state],
836
+ outputs=[seg_image_editor, masks_state]
837
  )
838
 
839
+ # Download button handler
840
  seg_download_btn.click(
841
  fn=download_segment,
842
+ inputs=[seg_id_dropdown, frames_state, masks_state],
843
+ outputs=[seg_download_file]
844
  )
845
 
846
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -1,3 +1,5 @@
1
  modal
2
  gradio
3
  opencv-python-headless
 
 
 
1
  modal
2
  gradio
3
  opencv-python-headless
4
+ Pillow
5
+ numpy