usingcolor committed on
Commit
81be487
·
1 Parent(s): df9d184

Use Xet for assets

Browse files
Files changed (5) hide show
  1. .gitattributes +2 -0
  2. app.py +31 -26
  3. assets/dog.jpg +0 -0
  4. assets/green_mamba.jpg +3 -0
  5. assets/leo.jpg +3 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/green_mamba.jpg filter=lfs diff=lfs merge=lfs -text
37
+ assets/leo.jpg filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -73,22 +73,22 @@ CSS_STYLE = """
73
  # -----------------------------
74
 
75
  def get_model():
76
- global _GLOBAL_MODEL
 
77
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
78
- if _GLOBAL_MODEL is None:
79
- print(f"Downloading {MODEL_FILENAME} from {MODEL_REPO}...")
80
- try:
81
- checkpoint_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
82
- model = MambaEye(**MODEL_CONFIG)
83
- model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
84
- model.to(device)
85
- model.eval()
86
- _GLOBAL_MODEL = model
87
- print("Model loaded successfully.")
88
- except Exception as e:
89
- print(f"Failed to load model: {e}")
90
- raise
91
- return _GLOBAL_MODEL, device
92
 
93
  def transfer_inference_params(params, device):
94
  if params is None or getattr(params, "key_value_memory_dict", None) is None:
@@ -248,7 +248,7 @@ def run_auto_scan(image, scan_pattern, sequence_length):
248
  state['x_offset'], state['y_offset'], state['h'], state['w']
249
  )
250
 
251
- return img_display, format_predictions(final_probs), state, f"Auto Scan Complete. Extracted {sequence_length} patches. Click to add more!"
252
 
253
  @spaces.GPU
254
  def process_click_inference(x_orig, y_orig, original_image, state):
@@ -269,6 +269,9 @@ def process_click_inference(x_orig, y_orig, original_image, state):
269
  canvas_y = int(x_orig * ratio) + state['y_offset']
270
  canvas_x = int(y_orig * ratio) + state['x_offset']
271
 
 
 
 
272
  # 1px flexible precision anchoring the patch directly onto the exact center click
273
  px = max(0, min(int(canvas_x - PATCH_SIZE / 2), TARGET_CANVAS_SIZE - PATCH_SIZE))
274
  py = max(0, min(int(canvas_y - PATCH_SIZE / 2), TARGET_CANVAS_SIZE - PATCH_SIZE))
@@ -298,7 +301,7 @@ def process_click_inference(x_orig, y_orig, original_image, state):
298
  state['x_offset'], state['y_offset'], state['h'], state['w']
299
  )
300
 
301
- return img_display, format_predictions(final_probs), state, f"Added patch {state['sequence_length']} (Total {state['inference_params'].seqlen_offset} steps)."
302
 
303
  def on_click(evt: gr.SelectData, original_image, state):
304
  x_orig, y_orig = evt.index
@@ -306,26 +309,27 @@ def on_click(evt: gr.SelectData, original_image, state):
306
 
307
  def on_upload(image):
308
  if image is None:
309
- return None, None, {"Waiting...": 1.0}, None, "Upload Image"
310
 
311
  # Pre-render the grey background immediately on upload
312
  grey_base = Image.fromarray(image).convert("L").convert("RGB")
313
  grey_base_np = (np.array(grey_base).astype(float) * 0.4 + 160).clip(0, 255).astype(np.uint8)
314
 
315
- return grey_base_np, image, {"Click Auto Scan or click the image": 1.0}, None, "Ready. You can Auto Scan or click."
316
 
317
  def on_clear(original_image):
318
  if original_image is None:
319
- return None, {"Cleared": 1.0}, None, "Cleared"
320
 
321
  grey_base = Image.fromarray(original_image).convert("L").convert("RGB")
322
  grey_base_np = (np.array(grey_base).astype(float) * 0.4 + 160).clip(0, 255).astype(np.uint8)
323
 
324
- return grey_base_np, {"Cleared": 1.0}, init_state_for_image(original_image), "Selections cleared. Ready for new patch sequence."
325
 
 
326
  with gr.Blocks(title="MambaEye Interactive Demo", css=CSS_STYLE) as demo:
327
  gr.Markdown("# MambaEye Interactive Inference Demo")
328
- gr.Markdown("This interface incorporates the full **MambaEye-base** model natively.")
329
 
330
  with gr.Row():
331
  with gr.Column(scale=2):
@@ -346,6 +350,7 @@ with gr.Blocks(title="MambaEye Interactive Demo", css=CSS_STYLE) as demo:
346
 
347
  with gr.Column(scale=1):
348
  model_output_label = gr.Label(label="MambaEye Output Predictions", num_top_classes=5)
 
349
  status_text = gr.Markdown("Status: Waiting for image upload...")
350
 
351
  state = gr.State(None)
@@ -354,25 +359,25 @@ with gr.Blocks(title="MambaEye Interactive Demo", css=CSS_STYLE) as demo:
354
  input_image.upload(
355
  fn=on_upload,
356
  inputs=[input_image],
357
- outputs=[input_image, original_image_state, model_output_label, state, status_text]
358
  )
359
 
360
  auto_btn.click(
361
  fn=run_auto_scan,
362
  inputs=[original_image_state, scan_pattern, seq_length],
363
- outputs=[input_image, model_output_label, state, status_text]
364
  )
365
 
366
  input_image.select(
367
  fn=on_click,
368
  inputs=[original_image_state, state],
369
- outputs=[input_image, model_output_label, state, status_text]
370
  )
371
 
372
  clear_btn.click(
373
  fn=on_clear,
374
  inputs=[original_image_state],
375
- outputs=[input_image, model_output_label, state, status_text]
376
  )
377
 
378
  if __name__ == "__main__":
 
73
  # -----------------------------
74
 
75
  def get_model():
76
+ # The @spaces.GPU worker is forked from the main process, so it inherits the
77
+ # _GLOBAL_CPU_MODEL reference; we then move the model's parameters onto the selected device (GPU VRAM when available).
78
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
79
+ _GLOBAL_CPU_MODEL.to(device)
80
+ return _GLOBAL_CPU_MODEL, device
81
+
82
+ # --- FALLBACK CSS INJECTION ---
83
+ # We use a CSS override to display a crosshair cursor, since custom dynamic HTML div overlays
84
+ # are not supported inside Gradio's image component (its canvas is rendered in a shadow DOM).
85
+ CSS_STYLE = """
86
+ .gradio-image-hook, .gradio-image-hook * {
87
+ cursor: crosshair !important;
88
+ }
89
+ """
90
+
91
+ # --- HOVER SCRIPT INJECTION ---
 
92
 
93
  def transfer_inference_params(params, device):
94
  if params is None or getattr(params, "key_value_memory_dict", None) is None:
 
248
  state['x_offset'], state['y_offset'], state['h'], state['w']
249
  )
250
 
251
+ return img_display, format_predictions(final_probs), state, f"Auto Scan Complete. Extracted {sequence_length} patches. Click to add more!", sequence_length
252
 
253
  @spaces.GPU
254
  def process_click_inference(x_orig, y_orig, original_image, state):
 
269
  canvas_y = int(x_orig * ratio) + state['y_offset']
270
  canvas_x = int(y_orig * ratio) + state['x_offset']
271
 
272
+ # 1px flexible precision anchoring the patch directly onto the exact center click
273
+ px = max(0, min(int(canvas_x - PATCH_SIZE / 2), TARGET_CANVAS_SIZE - PATCH_SIZE))
274
+ py = max(0, min(int(canvas_y - PATCH_SIZE / 2), TARGET_CANVAS_SIZE - PATCH_SIZE))
275
  # 1px flexible precision anchoring the patch directly onto the exact center click
276
  px = max(0, min(int(canvas_x - PATCH_SIZE / 2), TARGET_CANVAS_SIZE - PATCH_SIZE))
277
  py = max(0, min(int(canvas_y - PATCH_SIZE / 2), TARGET_CANVAS_SIZE - PATCH_SIZE))
 
301
  state['x_offset'], state['y_offset'], state['h'], state['w']
302
  )
303
 
304
+ return img_display, format_predictions(final_probs), state, f"Added patch {state['sequence_length']} (Total {state['inference_params'].seqlen_offset} steps).", state['sequence_length']
305
 
306
  def on_click(evt: gr.SelectData, original_image, state):
307
  x_orig, y_orig = evt.index
 
309
 
310
  def on_upload(image):
311
  if image is None:
312
+ return None, None, {"Waiting...": 1.0}, None, "Upload Image", 0
313
 
314
  # Pre-render the grey background immediately on upload
315
  grey_base = Image.fromarray(image).convert("L").convert("RGB")
316
  grey_base_np = (np.array(grey_base).astype(float) * 0.4 + 160).clip(0, 255).astype(np.uint8)
317
 
318
+ return grey_base_np, image, {"Click Auto Scan or click the image": 1.0}, None, "Ready. You can Auto Scan or click.", 0
319
 
320
  def on_clear(original_image):
321
  if original_image is None:
322
+ return None, {"Cleared": 1.0}, None, "Cleared", 0
323
 
324
  grey_base = Image.fromarray(original_image).convert("L").convert("RGB")
325
  grey_base_np = (np.array(grey_base).astype(float) * 0.4 + 160).clip(0, 255).astype(np.uint8)
326
 
327
+ return grey_base_np, {"Cleared": 1.0}, init_state_for_image(original_image), "Selections cleared. Ready for new patch sequence.", 0
328
 
329
+ with gr.Blocks(title="MambaEye Interactive Demo", css=CSS_STYLE) as demo:
330
  with gr.Blocks(title="MambaEye Interactive Demo", css=CSS_STYLE) as demo:
331
  gr.Markdown("# MambaEye Interactive Inference Demo")
332
+ gr.Markdown("This interface incorporates the full **MambaEye-base** model natively.\n\n**Note**: The first inference or Auto Scan may take **1~2 minutes** to compile CUDA kernels and build hardware cache. Subsequent patch clicks will be dramatically faster!")
333
 
334
  with gr.Row():
335
  with gr.Column(scale=2):
 
350
 
351
  with gr.Column(scale=1):
352
  model_output_label = gr.Label(label="MambaEye Output Predictions", num_top_classes=5)
353
+ seq_len_display = gr.Number(label="Total Sequenced Patches", value=0, interactive=False)
354
  status_text = gr.Markdown("Status: Waiting for image upload...")
355
 
356
  state = gr.State(None)
 
359
  input_image.upload(
360
  fn=on_upload,
361
  inputs=[input_image],
362
+ outputs=[input_image, original_image_state, model_output_label, state, status_text, seq_len_display]
363
  )
364
 
365
  auto_btn.click(
366
  fn=run_auto_scan,
367
  inputs=[original_image_state, scan_pattern, seq_length],
368
+ outputs=[input_image, model_output_label, state, status_text, seq_len_display]
369
  )
370
 
371
  input_image.select(
372
  fn=on_click,
373
  inputs=[original_image_state, state],
374
+ outputs=[input_image, model_output_label, state, status_text, seq_len_display]
375
  )
376
 
377
  clear_btn.click(
378
  fn=on_clear,
379
  inputs=[original_image_state],
380
+ outputs=[input_image, model_output_label, state, status_text, seq_len_display]
381
  )
382
 
383
  if __name__ == "__main__":
assets/dog.jpg ADDED
assets/green_mamba.jpg ADDED

Git LFS Details

  • SHA256: 1d269ad4a9cbc7283b6c34fb4cce2cce9e3be503d140476a8c0f8c62d93a2175
  • Pointer size: 132 Bytes
  • Size of remote file: 1.06 MB
assets/leo.jpg ADDED

Git LFS Details

  • SHA256: 70ce1fd8334e58776b96c89ceaae9fa26c9533be05b2e3d8d129ba097228b2f6
  • Pointer size: 131 Bytes
  • Size of remote file: 209 kB