denizaybey committed on
Commit
050c72c
·
2 Parent(s): ad90bbfbd7ca9e

Merge branch 'feat/download-segment-zip' into 'main'

Browse files

Add Segmentation Editing tab with 5-row layout

See merge request sonne-technology/bsod-tv/waveform-matching-gradio-front-end!27

Files changed (2) hide show
  1. app.py +556 -27
  2. requirements.txt +2 -0
app.py CHANGED
@@ -8,11 +8,16 @@
8
  import os
9
  import re
10
  import cv2
 
11
  import time
12
  import base64
13
  import modal
14
  import logging
 
 
15
  import gradio as gr
 
 
16
 
17
  logging.basicConfig(level=logging.DEBUG)
18
  logger = logging.getLogger(__name__)
@@ -189,6 +194,7 @@ def submit_magic_code(magic_code):
189
  gr.update(visible=False),
190
  gr.update(visible=False),
191
  gr.update(visible=False),
 
192
  gr.update(visible=False)
193
  )
194
 
@@ -199,6 +205,7 @@ def submit_magic_code(magic_code):
199
  gr.update(visible=True),
200
  gr.update(visible=True),
201
  gr.update(visible=True),
 
202
  gr.update(visible=True)
203
  )
204
 
@@ -210,6 +217,7 @@ def submit_magic_code(magic_code):
210
  gr.update(visible=False),
211
  gr.update(visible=False),
212
  gr.update(visible=False),
 
213
  gr.update(visible=False)
214
  )
215
 
@@ -221,33 +229,360 @@ def submit_magic_code(magic_code):
221
  gr.update(visible=False),
222
  gr.update(visible=False),
223
  gr.update(visible=False),
 
224
  gr.update(visible=False)
225
  )
226
 
227
 
228
- def load_segment_frame(segment_id, frame_number):
229
  """
230
- Placeholder function to load a specific frame from a segment for editing.
231
- TODO: Connect to backend API to retrieve frame image from video segment.
 
 
 
 
232
  """
233
- if not segment_id:
234
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
- logger.info(f"Loading segment: {segment_id}, frame: {frame_number}")
237
- # Placeholder: Return None (no image loaded yet)
238
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
 
240
 
241
- def download_segment(segment_id, edited_image, frame_number):
242
  """
243
- Placeholder function to download edited segment.
244
- TODO: Connect to backend API to save edited segmentation and download result.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
  """
246
  if not segment_id:
247
- return "No segment selected"
 
 
 
 
248
 
249
  logger.info(f"Download requested for segment: {segment_id}")
250
- return f"Download functionality coming soon for {segment_id}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
 
252
 
253
  # Create a professional Gradio interface using the Golden ratio (1.618) for proportions
@@ -436,6 +771,53 @@ label {
436
  padding: 15px;
437
  }
438
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
439
  """
440
 
441
  # Create a Blocks interface for more customization
@@ -480,6 +862,11 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="indigo", secondary_hue
480
  with gr.Tab("Segmentation Editing"):
481
  gr.Markdown("### Segmentation Editing Workspace")
482
 
 
 
 
 
 
483
  # Row 1: Magic Code textbox + Submit button
484
  with gr.Row(elem_classes="input-section"):
485
  with gr.Column(scale=3):
@@ -500,8 +887,16 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="indigo", secondary_hue
500
  info="Select a segment to edit"
501
  )
502
 
503
- # Row 3: ImageEditor component (SAM-3 style interaction)
504
- with gr.Row(elem_classes="input-section", visible=False) as seg_row3:
 
 
 
 
 
 
 
 
505
  seg_image_editor = gr.ImageEditor(
506
  label="Image Click Segmentation",
507
  type="pil",
@@ -509,8 +904,8 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="indigo", secondary_hue
509
  brush=gr.Brush(colors=["#FF0000"], color_mode="fixed")
510
  )
511
 
512
- # Row 4: Frame number slider
513
- with gr.Row(elem_classes="input-section", visible=False) as seg_row4:
514
  seg_frame_slider = gr.Slider(
515
  minimum=0,
516
  maximum=100,
@@ -521,9 +916,71 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="indigo", secondary_hue
521
  info="Select video frame to segment"
522
  )
523
 
524
- # Row 5: Download Segment button
525
- with gr.Row(visible=False) as seg_row5:
526
  seg_download_btn = gr.Button("Download Segment", variant="secondary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
527
 
528
  # Wire Content Moderation processing
529
  cm_process_btn.click(
@@ -536,25 +993,97 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="indigo", secondary_hue
536
  seg_submit_btn.click(
537
  fn=submit_magic_code,
538
  inputs=[seg_magic_code],
539
- outputs=[seg_id_dropdown, seg_magic_code, seg_row2, seg_row3, seg_row4, seg_row5]
540
  )
541
 
 
542
  seg_id_dropdown.change(
543
- fn=load_segment_frame,
544
- inputs=[seg_id_dropdown, seg_frame_slider],
545
- outputs=[seg_image_editor]
546
  )
547
 
 
548
  seg_frame_slider.change(
549
  fn=load_segment_frame,
550
- inputs=[seg_id_dropdown, seg_frame_slider],
551
- outputs=[seg_image_editor]
552
  )
553
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
554
  seg_download_btn.click(
555
  fn=download_segment,
556
- inputs=[seg_id_dropdown, seg_image_editor, seg_frame_slider],
557
- outputs=[seg_magic_code] # Display status message in magic code textbox
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
558
  )
559
 
560
  if __name__ == "__main__":
 
8
  import os
9
  import re
10
  import cv2
11
+ import io
12
  import time
13
  import base64
14
  import modal
15
  import logging
16
+ import tempfile
17
+ import numpy as np
18
  import gradio as gr
19
+ from PIL import Image
20
+ from typing import Dict, Optional, Tuple
21
 
22
  logging.basicConfig(level=logging.DEBUG)
23
  logger = logging.getLogger(__name__)
 
194
  gr.update(visible=False),
195
  gr.update(visible=False),
196
  gr.update(visible=False),
197
+ gr.update(visible=False),
198
  gr.update(visible=False)
199
  )
200
 
 
205
  gr.update(visible=True),
206
  gr.update(visible=True),
207
  gr.update(visible=True),
208
+ gr.update(visible=True),
209
  gr.update(visible=True)
210
  )
211
 
 
217
  gr.update(visible=False),
218
  gr.update(visible=False),
219
  gr.update(visible=False),
220
+ gr.update(visible=False),
221
  gr.update(visible=False)
222
  )
223
 
 
229
  gr.update(visible=False),
230
  gr.update(visible=False),
231
  gr.update(visible=False),
232
+ gr.update(visible=False),
233
  gr.update(visible=False)
234
  )
235
 
236
 
237
def _read_volume_image(volume, remote_path: str) -> Image.Image:
    """Read one file from a Modal volume and decode it into an in-memory PIL image."""
    payload = volume.read_file(remote_path)
    # Volume.read_file streams the file as an iterator of byte chunks; join
    # them into a single buffer before decoding. A plain bytes object is
    # accepted as-is so the helper works with either return shape.
    if not isinstance(payload, (bytes, bytearray)):
        payload = b"".join(payload)
    with Image.open(io.BytesIO(payload)) as img:
        # copy() loads the pixel data, so the result does not depend on the buffer.
        return img.copy()


def download_segment_files(magic_code: str, segment_id: str) -> Tuple[Dict[int, Image.Image], Dict[int, Image.Image], int]:
    """
    Download all frame images and alpha masks for a given segment from Modal volume.

    Files are looked up under ``/{magic_code}/{segment_id}`` in the volume named by
    the ``MODERATION_MODAL_VOLUME`` environment variable. Only entries named
    ``frame_<n>.jpg|png`` and ``alpha_frame_<n>.png`` are considered; a file that
    fails to download or decode is logged and skipped.

    Args:
        magic_code: First path component of the segment directory.
        segment_id: Second path component of the segment directory.

    Returns:
        frames_dict: Dict mapping frame_index -> PIL.Image (original frames)
        masks_dict: Dict mapping frame_index -> PIL.Image (alpha masks)
        max_frame: Maximum frame index found (0 when no frame was loaded)

        On any unexpected failure the function logs the error and returns
        ``({}, {}, 0)`` instead of raising.
    """
    frame_pattern = re.compile(r'^frame_(\d+)\.(jpg|png)$')
    alpha_pattern = re.compile(r'^alpha_frame_(\d+)\.png$')

    try:
        modal_volume_name = os.environ['MODERATION_MODAL_VOLUME']
        volume = modal.Volume.from_name(modal_volume_name)

        segment_path = f"/{magic_code}/{segment_id}"
        frames_dict: Dict[int, Image.Image] = {}
        masks_dict: Dict[int, Image.Image] = {}

        logger.info(f"Downloading files from {segment_path}")

        # List all files in the segment directory
        try:
            files = list(volume.listdir(segment_path))
        except Exception as e:
            logger.error(f"Failed to list segment directory: {e}")
            return {}, {}, 0

        # (pattern, destination dict, label used in the debug log)
        targets = (
            (frame_pattern, frames_dict, "frame"),
            (alpha_pattern, masks_dict, "alpha mask"),
        )

        for entry in files:
            if entry.type != modal.volume.FileEntryType.FILE:
                continue

            filename = os.path.basename(entry.path)

            for pattern, destination, label in targets:
                match = pattern.match(filename)
                if not match:
                    continue

                frame_idx = int(match.group(1))
                try:
                    destination[frame_idx] = _read_volume_image(volume, f"{segment_path}/{filename}")
                    logger.debug(f"Downloaded {label} {frame_idx}")
                except Exception as e:
                    # Best effort: one unreadable file must not abort the whole segment.
                    logger.error(f"Failed to download {filename}: {e}")
                # The two patterns are mutually exclusive, so stop at the first match.
                break

        max_frame = max(frames_dict) if frames_dict else 0
        logger.info(f"Downloaded {len(frames_dict)} frames and {len(masks_dict)} masks. Max frame: {max_frame}")

        return frames_dict, masks_dict, max_frame

    except Exception as e:
        logger.error(f"Error downloading segment files: {e}")
        return {}, {}, 0
309
+
310
+
311
def composite_image_with_mask(
    frame: Image.Image,
    mask: Optional[Image.Image],
    show_mask: bool,
    overlay_rgba: Tuple[int, int, int, int] = (255, 0, 0, 128),
) -> Image.Image:
    """
    Composite the original frame with the alpha mask overlay.

    Args:
        frame: Original RGB/RGBA image
        mask: Alpha mask (grayscale or RGBA); converted to 'L' and resized to
            the frame size when needed
        show_mask: Whether to show the mask overlay
        overlay_rgba: Overlay colour; the fourth component is the opacity
            reached where the mask is fully white (default: semi-transparent red)

    Returns:
        A copy of ``frame`` when ``show_mask`` is false or ``mask`` is None;
        otherwise an RGBA image with the tinted overlay blended on top.
    """
    if not show_mask or mask is None:
        return frame.copy()

    # Image.alpha_composite requires both operands to be RGBA.
    frame_rgba = frame.copy() if frame.mode == 'RGBA' else frame.convert('RGBA')

    # Reduce the mask to a single 8-bit channel.
    mask_gray = mask.copy() if mask.mode == 'L' else mask.convert('L')

    # Resize mask to match frame if needed
    if mask_gray.size != frame_rgba.size:
        mask_gray = mask_gray.resize(frame_rgba.size, Image.Resampling.LANCZOS)

    # putalpha() *replaces* the overlay's alpha channel, so the mask has to be
    # scaled by the requested opacity first; otherwise fully-white mask pixels
    # would paint the overlay colour completely opaque over the frame.
    peak_alpha = overlay_rgba[3]
    overlay_alpha = mask_gray.point(lambda value: value * peak_alpha // 255)

    overlay = Image.new('RGBA', frame_rgba.size, tuple(overlay_rgba))
    overlay.putalpha(overlay_alpha)

    return Image.alpha_composite(frame_rgba, overlay)
352
+
353
+
354
def load_segment_frame(segment_id, frame_number, show_mask, magic_code_state, frames_state, masks_state):
    """
    Render the requested frame, optionally tinted with its alpha mask.

    Returns a pair ``(image, checkbox_update)``. The image is ``None`` when no
    segment is selected, no frames are loaded, or the requested index is not
    among the loaded frames. The second element is always an empty
    ``gr.update()``. ``magic_code_state`` is accepted but not used here.
    """
    if frames_state is None or not segment_id:
        return None, gr.update()

    frame_idx = int(frame_number)

    try:
        frame = frames_state[frame_idx]
    except KeyError:
        logger.warning(f"Frame {frame_idx} not found in downloaded frames")
        return None, gr.update()

    # A frame without a stored mask is rendered as-is by the compositor.
    result_image = composite_image_with_mask(frame, masks_state.get(frame_idx), show_mask)

    logger.info(f"Loaded frame {frame_idx} with mask overlay: {show_mask}")

    return result_image, gr.update()
379
+
380
+
381
def handle_keyboard_navigation(key_code, segment_id, current_frame, show_mask, magic_code_state, frames_state, masks_state):
    """
    Handle left/right arrow key navigation for frame slider.

    Navigation steps to the nearest *available* frame index in the requested
    direction, so gaps in the downloaded frame numbering are skipped rather
    than landing on a missing frame.

    Args:
        key_code: JavaScript key code ('ArrowLeft' or 'ArrowRight')
        segment_id: Current segment ID
        current_frame: Current frame number
        show_mask: Whether to show alpha mask overlay
        magic_code_state: Magic code state (unused here)
        frames_state: Frames dictionary state
        masks_state: Masks dictionary state

    Returns:
        Tuple of (updated image, updated slider value). When the key press does
        not move to a different frame (unknown key, or already at the first/last
        frame) both elements are empty ``gr.update()`` objects so the image
        currently shown is left untouched.
    """
    if not segment_id or frames_state is None:
        return None, gr.update()

    available_frames = sorted(frames_state)
    if not available_frames:
        return None, gr.update()

    current = int(current_frame)

    if key_code == 'ArrowLeft':
        earlier = [idx for idx in available_frames if idx < current]
        new_frame = earlier[-1] if earlier else current
    elif key_code == 'ArrowRight':
        later = [idx for idx in available_frames if idx > current]
        new_frame = later[0] if later else current
    else:
        # Unknown key: leave both the image and the slider as they are.
        return gr.update(), gr.update()

    # At a boundary nothing moves; returning None here would blank the editor.
    if new_frame == current:
        return gr.update(), gr.update()

    logger.info(f"Keyboard navigation: {key_code} -> frame {new_frame}")

    mask = masks_state.get(new_frame) if masks_state else None
    result_image = composite_image_with_mask(frames_state[new_frame], mask, show_mask)

    # Return updated image and new slider value
    return result_image, gr.update(value=new_frame)
441
+
442
+
443
def handle_segment_selection(segment_id, magic_code):
    """
    Handle segment selection: download all files and initialize the view.

    Args:
        segment_id: Segment chosen in the dropdown.
        magic_code: Magic code currently entered by the user.

    Returns:
        A 6-tuple of (image, slider update, frames dict, masks dict, magic code,
        download-button update). When nothing is selected or no frame could be
        downloaded, the image is None, both dicts are empty, the slider is reset
        to 0 and the download button is disabled.
    """
    if not segment_id or not magic_code:
        return None, gr.update(maximum=0, value=0), {}, {}, magic_code, gr.update(interactive=False)

    logger.info(f"Segment selected: {segment_id}")

    # Download all frames and masks
    frames_dict, masks_dict, max_frame = download_segment_files(magic_code, segment_id)

    if not frames_dict:
        logger.error("No frames downloaded")
        return None, gr.update(maximum=0, value=0), {}, {}, magic_code, gr.update(interactive=False)

    # Prefer frame 0; otherwise start at the lowest index that actually exists.
    # The slider value, the image and the mask lookup must all use this same
    # index, or the slider would point at a frame that was never downloaded.
    first_frame = min(frames_dict)
    start_idx = 0 if 0 in frames_dict else first_frame

    # Initial display without mask
    result_image = composite_image_with_mask(frames_dict[start_idx], masks_dict.get(start_idx), False)

    # Enable download button only if masks are available
    has_masks = len(masks_dict) > 0

    return (
        result_image,
        gr.update(minimum=first_frame, maximum=max_frame, value=start_idx),
        frames_dict,
        masks_dict,
        magic_code,
        gr.update(interactive=has_masks)
    )
479
 
480
 
481
def handle_image_click(segment_id, frame_number, magic_code, frames_state, masks_state, evt: gr.SelectData):
    """
    Handle click on ImageEditor: send coordinates to SAM-3, get new mask, update display.

    Args:
        segment_id: Currently selected segment; falsy means nothing is selected.
        frame_number: Slider value; converted with int() to index the state dicts.
        magic_code: Received from the wiring but not used in this function.
        frames_state: Dict of frame index -> image for the current segment, or None.
        masks_state: Dict of frame index -> mask; updated in place on success.
        evt: Gradio select event; ``evt.index`` is read as an (x, y) pair.
            NOTE(review): assumes index holds pixel coordinates for ImageEditor
            select events - confirm against the Gradio version in use.

    Returns:
        Tuple of (image, masks_state, download-button update). The image is None
        when preconditions fail or an unexpected error occurs; if only the remote
        segmentation call fails, the frame is re-rendered with whatever mask is
        already stored for it.
    """
    if not segment_id or frames_state is None or evt is None:
        return None, masks_state, gr.update()

    # Extract click coordinates
    x, y = evt.index[0], evt.index[1]
    frame_idx = int(frame_number)

    logger.info(f"Click detected at ({x}, {y}) on frame {frame_idx}")

    if frame_idx not in frames_state:
        logger.error(f"Frame {frame_idx} not in state")
        return None, masks_state, gr.update()

    try:
        # Get the original (un-composited) frame
        frame = frames_state[frame_idx]

        # Serialize the frame as PNG bytes for the remote call
        img_byte_arr = io.BytesIO()
        frame.save(img_byte_arr, format='PNG')
        img_bytes = img_byte_arr.getvalue()

        # Call Modal SAM-3 segmentation function
        logger.info(f"Calling SAM-3 segmentation with coordinates ({x}, {y})")

        try:
            sam_function = modal.Function.from_name("Content-Moderation", "sam3_segmentation_function")
            mask_bytes = sam_function.remote(image_bytes=img_bytes, x=x, y=y)

            # Parse returned mask
            # presumably the remote function returns an encoded image as bytes - verify against the Modal app
            new_mask = Image.open(io.BytesIO(mask_bytes))

            # Store the new mask for this frame (mutates the state dict in place)
            masks_state[frame_idx] = new_mask.copy()

            # Composite and return updated image (always show mask after click)
            result_image = composite_image_with_mask(frame, new_mask, True)

            logger.info(f"Successfully updated mask for frame {frame_idx}")

            # Enable download button since we now have at least one mask
            return result_image, masks_state, gr.update(interactive=True)

        except Exception as e:
            logger.error(f"Error calling SAM-3 function: {e}")
            # Segmentation failed: redraw the frame with its previous mask (if any)
            frame = frames_state[frame_idx]
            mask = masks_state.get(frame_idx, None)
            result_image = composite_image_with_mask(frame, mask, True)
            return result_image, masks_state, gr.update()

    except Exception as e:
        # Reached for failures outside the remote call (e.g. encoding the frame)
        # and for errors raised by the fallback rendering above.
        logger.error(f"Error handling image click: {e}")
        return None, masks_state, gr.update()
539
+
540
+
541
def download_segment(segment_id, frames_state, masks_state):
    """
    Package and download only the alpha masks for the selected segment as a ZIP file.
    ZIP filename will be {segment_id}.zip (path separators in the id are stripped).

    Args:
        segment_id: Selected segment; used as the archive's base name.
        frames_state: Frames dictionary state (unused; only masks are packaged).
        masks_state: Dict of frame index -> mask image to write into the archive.

    Returns:
        Tuple of (status_message, file_path, status_visibility). On success the
        second element is an update that sets the archive path AND makes the
        file component visible; on failure it is None.
    """
    if not segment_id:
        return gr.update(value="No segment selected", visible=True), None, gr.update(visible=True)

    if not masks_state:
        logger.warning(f"No alpha masks available for segment: {segment_id}")
        return gr.update(value="No alpha masks available", visible=True), None, gr.update(visible=True)

    logger.info(f"Download requested for segment: {segment_id}")

    try:
        import shutil

        # Keep only the final path component so the id cannot escape the output dir.
        archive_name = os.path.basename(str(segment_id)) or "segment"

        # A fresh directory per request avoids collisions between sessions that
        # export the same segment, and works on platforms without /tmp.
        # NOTE: this directory is not removed here; it must outlive the request
        # so the file can still be served.
        output_dir = tempfile.mkdtemp(prefix="segment_zip_")

        # Staging directory for the PNGs; removed as soon as the archive exists.
        with tempfile.TemporaryDirectory() as tmpdir:
            # Save only alpha masks
            for frame_idx, mask_img in masks_state.items():
                mask_path = os.path.join(tmpdir, f"alpha_frame_{int(frame_idx):06d}.png")
                mask_img.save(mask_path)

            # make_archive appends ".zip" itself and returns the final path.
            zip_path = shutil.make_archive(os.path.join(output_dir, archive_name), 'zip', tmpdir)

        logger.info(f"Created ZIP at {zip_path} with {len(masks_state)} alpha masks")
        return (
            gr.update(value=f"✓ Downloaded {len(masks_state)} alpha masks", visible=True),
            # The file component is created hidden; reveal it along with the value.
            gr.update(value=zip_path, visible=True),
            gr.update(visible=True)
        )

    except Exception as e:
        logger.error(f"Error creating download package: {e}")
        return (
            gr.update(value=f"Error: {str(e)}", visible=True),
            None,
            gr.update(visible=True)
        )
586
 
587
 
588
  # Create a professional Gradio interface using the Golden ratio (1.618) for proportions
 
771
  padding: 15px;
772
  }
773
  }
774
+
775
+ /* Loading modal overlay */
776
+ #download-loading-modal {
777
+ display: none;
778
+ position: fixed;
779
+ top: 0;
780
+ left: 0;
781
+ width: 100%;
782
+ height: 100%;
783
+ background-color: rgba(0, 0, 0, 0.7);
784
+ z-index: 9999;
785
+ justify-content: center;
786
+ align-items: center;
787
+ }
788
+
789
+ #download-loading-modal.show {
790
+ display: flex;
791
+ }
792
+
793
+ .loading-content {
794
+ background-color: var(--card-bg);
795
+ padding: 40px;
796
+ border-radius: var(--border-radius);
797
+ text-align: center;
798
+ box-shadow: 0 10px 40px rgba(0, 0, 0, 0.5);
799
+ }
800
+
801
+ .spinner {
802
+ border: 4px solid var(--border-color);
803
+ border-top: 4px solid var(--primary-color);
804
+ border-radius: 50%;
805
+ width: 50px;
806
+ height: 50px;
807
+ animation: spin 1s linear infinite;
808
+ margin: 0 auto 20px;
809
+ }
810
+
811
+ @keyframes spin {
812
+ 0% { transform: rotate(0deg); }
813
+ 100% { transform: rotate(360deg); }
814
+ }
815
+
816
+ .loading-text {
817
+ color: var(--text-color);
818
+ font-size: 18px;
819
+ font-weight: 500;
820
+ }
821
  """
822
 
823
  # Create a Blocks interface for more customization
 
862
  with gr.Tab("Segmentation Editing"):
863
  gr.Markdown("### Segmentation Editing Workspace")
864
 
865
+ # State management for session data
866
+ magic_code_state = gr.State(value=None)
867
+ frames_state = gr.State(value=None)
868
+ masks_state = gr.State(value=None)
869
+
870
  # Row 1: Magic Code textbox + Submit button
871
  with gr.Row(elem_classes="input-section"):
872
  with gr.Column(scale=3):
 
887
  info="Select a segment to edit"
888
  )
889
 
890
+ # Row 3: Show Alpha Mask checkbox
891
+ with gr.Row(elem_classes="input-section", visible=False) as seg_row3_checkbox:
892
+ seg_show_mask = gr.Checkbox(
893
+ label="Show Alpha Mask",
894
+ value=False,
895
+ interactive=True
896
+ )
897
+
898
+ # Row 4: ImageEditor component (SAM-3 style interaction)
899
+ with gr.Row(elem_classes="input-section", visible=False) as seg_row4:
900
  seg_image_editor = gr.ImageEditor(
901
  label="Image Click Segmentation",
902
  type="pil",
 
904
  brush=gr.Brush(colors=["#FF0000"], color_mode="fixed")
905
  )
906
 
907
+ # Row 5: Frame number slider
908
+ with gr.Row(elem_classes="input-section", visible=False) as seg_row5:
909
  seg_frame_slider = gr.Slider(
910
  minimum=0,
911
  maximum=100,
 
916
  info="Select video frame to segment"
917
  )
918
 
919
+ # Row 6: Download Segment button
920
+ with gr.Row(visible=False) as seg_row6:
921
  seg_download_btn = gr.Button("Download Segment", variant="secondary")
922
+ seg_download_status = gr.Textbox(label="Status", value="", visible=False, interactive=False)
923
+ seg_download_file = gr.File(label="Download", visible=False)
924
+
925
+ # Hidden component for keyboard event capture
926
+ seg_keyboard_input = gr.Textbox(visible=False, elem_id="seg_keyboard_input")
927
+
928
+ # Loading modal HTML
929
+ gr.HTML("""
930
+ <div id="download-loading-modal">
931
+ <div class="loading-content">
932
+ <div class="spinner"></div>
933
+ <div class="loading-text">Preparing download…</div>
934
+ </div>
935
+ </div>
936
+ <script>
937
+ function showDownloadLoading() {
938
+ const modal = document.getElementById('download-loading-modal');
939
+ if (modal) modal.classList.add('show');
940
+ }
941
+
942
+ function hideDownloadLoading() {
943
+ const modal = document.getElementById('download-loading-modal');
944
+ if (modal) modal.classList.remove('show');
945
+ }
946
+
947
+ // Listen for download button clicks
948
+ document.addEventListener('DOMContentLoaded', function() {
949
+ const checkButton = setInterval(function() {
950
+ const downloadBtn = document.querySelector('button:has-text("Download Segment")');
951
+ if (!downloadBtn) {
952
+ // Fallback: find button by content
953
+ const buttons = document.querySelectorAll('button');
954
+ for (let btn of buttons) {
955
+ if (btn.textContent.includes('Download Segment')) {
956
+ setupDownloadButton(btn);
957
+ clearInterval(checkButton);
958
+ break;
959
+ }
960
+ }
961
+ } else {
962
+ setupDownloadButton(downloadBtn);
963
+ clearInterval(checkButton);
964
+ }
965
+ }, 500);
966
+
967
+ function setupDownloadButton(btn) {
968
+ btn.addEventListener('click', function() {
969
+ showDownloadLoading();
970
+ // Hide modal after function completes (watch for Gradio loading to finish)
971
+ setTimeout(function checkLoading() {
972
+ const gradioLoading = document.querySelector('.loading');
973
+ if (!gradioLoading) {
974
+ setTimeout(hideDownloadLoading, 500);
975
+ } else {
976
+ setTimeout(checkLoading, 200);
977
+ }
978
+ }, 200);
979
+ });
980
+ }
981
+ });
982
+ </script>
983
+ """)
984
 
985
  # Wire Content Moderation processing
986
  cm_process_btn.click(
 
993
  seg_submit_btn.click(
994
  fn=submit_magic_code,
995
  inputs=[seg_magic_code],
996
+ outputs=[seg_id_dropdown, seg_magic_code, seg_row2, seg_row3_checkbox, seg_row4, seg_row5, seg_row6]
997
  )
998
 
999
+ # Segment selection handler
1000
  seg_id_dropdown.change(
1001
+ fn=handle_segment_selection,
1002
+ inputs=[seg_id_dropdown, seg_magic_code],
1003
+ outputs=[seg_image_editor, seg_frame_slider, frames_state, masks_state, magic_code_state, seg_download_btn]
1004
  )
1005
 
1006
+ # Frame slider handler
1007
  seg_frame_slider.change(
1008
  fn=load_segment_frame,
1009
+ inputs=[seg_id_dropdown, seg_frame_slider, seg_show_mask, magic_code_state, frames_state, masks_state],
1010
+ outputs=[seg_image_editor, seg_show_mask]
1011
  )
1012
 
1013
+ # Show mask checkbox handler
1014
+ seg_show_mask.change(
1015
+ fn=load_segment_frame,
1016
+ inputs=[seg_id_dropdown, seg_frame_slider, seg_show_mask, magic_code_state, frames_state, masks_state],
1017
+ outputs=[seg_image_editor, seg_show_mask]
1018
+ )
1019
+
1020
+ # Image click handler (for SAM-3 segmentation)
1021
+ seg_image_editor.select(
1022
+ fn=handle_image_click,
1023
+ inputs=[seg_id_dropdown, seg_frame_slider, magic_code_state, frames_state, masks_state],
1024
+ outputs=[seg_image_editor, masks_state, seg_download_btn]
1025
+ )
1026
+
1027
+ # Download button handler
1028
  seg_download_btn.click(
1029
  fn=download_segment,
1030
+ inputs=[seg_id_dropdown, frames_state, masks_state],
1031
+ outputs=[seg_download_status, seg_download_file, seg_download_status]
1032
+ )
1033
+
1034
+ # Keyboard navigation handler
1035
+ seg_keyboard_input.change(
1036
+ fn=handle_keyboard_navigation,
1037
+ inputs=[
1038
+ seg_keyboard_input,
1039
+ seg_id_dropdown,
1040
+ seg_frame_slider,
1041
+ seg_show_mask,
1042
+ magic_code_state,
1043
+ frames_state,
1044
+ masks_state
1045
+ ],
1046
+ outputs=[seg_image_editor, seg_frame_slider]
1047
+ )
1048
+
1049
+ # Add JavaScript to capture arrow key events
1050
+ demo.load(
1051
+ None,
1052
+ None,
1053
+ None,
1054
+ js="""
1055
+ () => {
1056
+ // Wait for the DOM to be ready
1057
+ setTimeout(() => {
1058
+ const keyboardInput = document.getElementById('seg_keyboard_input');
1059
+ if (!keyboardInput) {
1060
+ console.warn('Keyboard input element not found');
1061
+ return;
1062
+ }
1063
+
1064
+ // Add keydown listener to document
1065
+ document.addEventListener('keydown', (e) => {
1066
+ // Only handle arrow keys
1067
+ if (e.key === 'ArrowLeft' || e.key === 'ArrowRight') {
1068
+ // Check if we're in the Segmentation Editing tab
1069
+ const segTab = document.querySelector('[id*="segmentation-editing"]');
1070
+ const activeTab = document.querySelector('.tab-nav button.selected');
1071
+
1072
+ if (activeTab && activeTab.textContent.includes('Segmentation Editing')) {
1073
+ e.preventDefault();
1074
+
1075
+ // Update the hidden input to trigger the change event
1076
+ const textarea = keyboardInput.querySelector('textarea');
1077
+ if (textarea) {
1078
+ textarea.value = e.key;
1079
+ textarea.dispatchEvent(new Event('input', { bubbles: true }));
1080
+ }
1081
+ }
1082
+ }
1083
+ });
1084
+ }, 1000);
1085
+ }
1086
+ """
1087
  )
1088
 
1089
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -1,3 +1,5 @@
1
  modal
2
  gradio
3
  opencv-python-headless
 
 
 
1
  modal
2
  gradio
3
  opencv-python-headless
4
+ Pillow
5
+ numpy