usingcolor committed on
Commit
93331e9
·
1 Parent(s): b468d75

chore: initialize git repository and remove Dockerfile

Browse files
Files changed (4) hide show
  1. Dockerfile +0 -42
  2. README.md +3 -1
  3. app.py +48 -25
  4. requirements.txt +1 -0
Dockerfile DELETED
@@ -1,42 +0,0 @@
1
- FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel
2
-
3
- ENV DEBIAN_FRONTEND=noninteractive
4
-
5
- # Install system dependencies
6
- RUN apt-get update && apt-get install -y \
7
- git ffmpeg libsm6 libxext6 cmake rsync libgl1 curl \
8
- && rm -rf /var/lib/apt/lists/*
9
-
10
- # Install uv for fast python package installations
11
- RUN pip install uv
12
-
13
- # Create non-root user required by HuggingFace Spaces
14
- RUN useradd -m -u 1000 user
15
- USER user
16
- ENV HOME=/home/user \
17
- PATH=/home/user/.local/bin:$PATH \
18
- PYTHONWARNINGS=ignore \
19
- GRADIO_SERVER_NAME="0.0.0.0" \
20
- GRADIO_SERVER_PORT="7860"
21
-
22
- WORKDIR $HOME/app
23
-
24
- # Copy requirements and install base dependencies via uv
25
- COPY --chown=user requirements.txt $HOME/app/
26
- RUN uv pip install --system --upgrade pip
27
- RUN uv pip install --system -r requirements.txt
28
-
29
- # Specify CUDA architectures to compile for various GPUs commonly found on HF Spaces
30
- # 7.5: T4, 8.0: A100, 8.6: A10G/RTX3090, 8.9: L4, 9.0: H100
31
- ENV TORCH_CUDA_ARCH_LIST="7.5 8.0 8.6 8.9 9.0+PTX"
32
-
33
- # Install causal-conv1d and mamba-ssm requiring compilation
34
- # Since this uses the `devel` image, `nvcc` is available safely!
35
- RUN uv pip install --system causal-conv1d==1.5.0.post8 --no-build-isolation
36
- RUN uv pip install --system mamba-ssm==2.2.4 --no-build-isolation
37
-
38
- # Copy the rest of the application code
39
- COPY --chown=user . $HOME/app/
40
-
41
- # Run the Gradio app
42
- CMD ["python", "app.py"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -3,7 +3,9 @@ title: MambaEye
3
  emoji: 👁️
4
  colorFrom: blue
5
  colorTo: green
6
- sdk: docker
 
 
7
  pinned: false
8
  ---
9
 
 
3
  emoji: 👁️
4
  colorFrom: blue
5
  colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 6.10.0
8
+ app_file: app.py
9
  pinned: false
10
  ---
11
 
app.py CHANGED
@@ -1,7 +1,26 @@
1
  import sys
2
  import os
 
3
  import time
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  # Add the cloned MambaEye repository to the Python path
6
  sys.path.append(os.path.join(os.path.dirname(__file__), "MambaEye"))
7
 
@@ -13,6 +32,7 @@ from PIL import Image, ImageDraw
13
  import torchvision.transforms as T
14
  from torchvision.models import ResNet50_Weights
15
  from huggingface_hub import hf_hub_download
 
16
 
17
  # MambaEye Imports
18
  from mambaeye.model import MambaEye
@@ -52,8 +72,10 @@ def get_model():
52
  try:
53
  checkpoint_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
54
  model = MambaEye(**MODEL_CONFIG)
55
- model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
56
- model.to(DEVICE)
 
 
57
  model.eval()
58
  _GLOBAL_MODEL = model
59
  print("Model loaded successfully.")
@@ -99,8 +121,6 @@ def preprocess_image(image_arr):
99
  return canvas, x_offset, y_offset, new_h, new_w
100
 
101
  def extract_patch(canvas_tensor, px, py):
102
- # px, py are coordinates on the canvas
103
- # Bound them
104
  px = max(0, min(px, TARGET_CANVAS_SIZE - PATCH_SIZE))
105
  py = max(0, min(py, TARGET_CANVAS_SIZE - PATCH_SIZE))
106
  patch = canvas_tensor[:, px : px + PATCH_SIZE, py : py + PATCH_SIZE]
@@ -114,7 +134,6 @@ def draw_patches_on_image(image_arr, positions, x_offset, y_offset, h, w):
114
  ratio = min(TARGET_CANVAS_SIZE / orig_w, TARGET_CANVAS_SIZE / orig_h)
115
 
116
  for i, (px, py) in enumerate(positions):
117
- # Map back to original image coordinates
118
  orig_y = (py - y_offset) / ratio
119
  orig_x = (px - x_offset) / ratio
120
  orig_px_size = PATCH_SIZE / ratio
@@ -138,7 +157,7 @@ def init_state_for_image(image):
138
  return {
139
  'inference_params': None,
140
  'cur_location': None,
141
- 'canvas_tensor': canvas_tensor,
142
  'x_offset': x_offset,
143
  'y_offset': y_offset,
144
  'h': h,
@@ -148,11 +167,14 @@ def init_state_for_image(image):
148
  'sequence_length': 0
149
  }
150
 
 
151
  def run_auto_scan(image, scan_pattern, sequence_length):
152
  if image is None:
153
  return None, {"Upload Image": 1.0}, None, "Upload Image"
154
 
155
  model = get_model()
 
 
156
  state = init_state_for_image(image)
157
 
158
  x_end = max(state['x_offset'] + 1, state['x_offset'] + state['h'])
@@ -168,7 +190,6 @@ def run_auto_scan(image, scan_pattern, sequence_length):
168
  scan_pattern=scan_pattern, rng=rng
169
  )
170
 
171
- # We allow up to Max Seq Length (say 4000) for ongoing clicks later.
172
  inference_params = InferenceParams(max_seqlen=4000, max_batch_size=1)
173
  state['inference_params'] = inference_params
174
 
@@ -181,21 +202,23 @@ def run_auto_scan(image, scan_pattern, sequence_length):
181
  move_emb = _compute_move_embedding(loc_tensor, cur_location)
182
  cur_location = loc_tensor
183
 
184
- patch = extract_patch(state['canvas_tensor'], px, py)
185
  patches_list.append(patch)
186
  moves_list.append(move_emb.squeeze(0))
187
 
188
- img_seq = torch.stack(patches_list, dim=0).unsqueeze(0).to(DEVICE) # (1, L, 768)
189
- move_seq = torch.stack(moves_list, dim=0).unsqueeze(0).to(DEVICE) # (1, L, 512)
190
 
191
  with torch.no_grad():
192
  out = model(img_seq, move_seq, inference_params=inference_params)
193
  final_probs = F.softmax(out[0, -1], dim=-1).cpu().numpy()
194
  inference_params.seqlen_offset += img_seq.shape[1]
195
 
196
- state['cur_location'] = cur_location
197
  state['drawn_positions'] = positions
198
  state['sequence_length'] = sequence_length
 
 
199
 
200
  img_display, _ = draw_patches_on_image(
201
  state['original_image'], state['drawn_positions'],
@@ -204,16 +227,18 @@ def run_auto_scan(image, scan_pattern, sequence_length):
204
 
205
  return img_display, format_predictions(final_probs), state, f"Auto Scan Complete. Extracted {sequence_length} patches. Click to add more!"
206
 
 
207
  def on_click(evt: gr.SelectData, original_image, state):
208
  if original_image is None:
209
  return None, {"Upload Image": 1.0}, state, "Upload Image"
210
 
 
 
 
211
  if state is None or state.get('inference_params') is None:
212
  # Initialize state to begin a new purely user-guided sequence
213
  state = init_state_for_image(original_image)
214
  state['inference_params'] = InferenceParams(max_seqlen=4000, max_batch_size=1)
215
-
216
- model = get_model()
217
 
218
  x_orig, y_orig = evt.index
219
  orig_h, orig_w = state['original_image'].shape[:2]
@@ -225,20 +250,21 @@ def on_click(evt: gr.SelectData, original_image, state):
225
  px = (canvas_x // PATCH_SIZE) * PATCH_SIZE
226
  py = (canvas_y // PATCH_SIZE) * PATCH_SIZE
227
 
 
228
  loc_tensor = torch.tensor([[px, py]], dtype=torch.long, device=DEVICE)
229
- move_emb = _compute_move_embedding(loc_tensor, state['cur_location'])
230
 
231
- patch = extract_patch(state['canvas_tensor'], px, py)
232
 
233
- img_seq = patch.unsqueeze(0).unsqueeze(0).to(DEVICE) # (1, 1, 768)
234
- move_seq = move_emb.unsqueeze(0).to(DEVICE) # (1, 1, 512)
235
 
236
  with torch.no_grad():
237
  out = model(img_seq, move_seq, inference_params=state['inference_params'])
238
  final_probs = F.softmax(out[0, -1], dim=-1).cpu().numpy()
239
  state['inference_params'].seqlen_offset += 1
240
 
241
- state['cur_location'] = loc_tensor
242
  state['drawn_positions'].append((px, py))
243
  state['sequence_length'] += 1
244
 
@@ -247,13 +273,12 @@ def on_click(evt: gr.SelectData, original_image, state):
247
  state['x_offset'], state['y_offset'], state['h'], state['w']
248
  )
249
 
250
- return img_display, format_predictions(final_probs), state, f"Added patch {state['sequence_length']} (Total {state['inference_params'].seqlen_offset} inference steps done)."
251
 
252
  def on_upload(image):
253
  if image is None:
254
  return None, {"Waiting...": 1.0}, None, "Upload Image"
255
- # Pre-warm model in background
256
- get_model()
257
  return image, {"Click Auto Scan or click the image": 1.0}, None, "Ready. You can Auto Scan or click."
258
 
259
  def on_clear(original_image):
@@ -261,11 +286,10 @@ def on_clear(original_image):
261
  return None, {"Cleared": 1.0}, None, "Cleared"
262
  return original_image, {"Cleared": 1.0}, init_state_for_image(original_image), "Selections cleared. Ready for new patch sequence."
263
 
264
-
265
  # Build the Gradio App Blocks
266
  with gr.Blocks(title="MambaEye Interactive Demo", theme=gr.themes.Soft()) as demo:
267
- gr.Markdown("# MambaEye Interactive Inference Demo")
268
- gr.Markdown("This interface incorporates the full **MambaEye-base** model inference in real-time.")
269
 
270
  with gr.Row():
271
  with gr.Column(scale=2):
@@ -297,7 +321,6 @@ with gr.Blocks(title="MambaEye Interactive Demo", theme=gr.themes.Soft()) as dem
297
  inputs=[input_image],
298
  outputs=[input_image, model_output_label, state, status_text]
299
  ).then(
300
- # Save original image separately for redraw and clearing
301
  fn=lambda img: img, inputs=[input_image], outputs=[original_image_state]
302
  )
303
 
 
1
  import sys
2
  import os
3
+ import subprocess
4
  import time
5
 
6
+ # --- Dynamic Dependency Injection for HuggingFace Spaces ---
7
+ # HuggingFace ZeroGPU builder environments lack `nvcc`.
8
+ # If either import fails, we pip-install both packages at startup with their CUDA
9
+ # extension builds skipped, so installation succeeds without a CUDA compiler.
10
+ try:
11
+ import mamba_ssm
12
+ import causal_conv1d
13
+ except ImportError:
14
+ print("Installing mamba_ssm and causal_conv1d in backend...", flush=True)
15
+ env = os.environ.copy()
16
+ # Skip building the CUDA extensions because nvcc is unavailable both locally and in the standard Hub build container
17
+ env["MAMBA_SKIP_CUDA_BUILD"] = "TRUE"
18
+ env["CAUSAL_CONV1D_SKIP_CUDA_BUILD"] = "TRUE"
19
+ subprocess.check_call(
20
+ [sys.executable, "-m", "pip", "install", "causal-conv1d==1.5.0.post8", "mamba-ssm==2.2.4", "--no-build-isolation"],
21
+ env=env
22
+ )
23
+
24
  # Add the cloned MambaEye repository to the Python path
25
  sys.path.append(os.path.join(os.path.dirname(__file__), "MambaEye"))
26
 
 
32
  import torchvision.transforms as T
33
  from torchvision.models import ResNet50_Weights
34
  from huggingface_hub import hf_hub_download
35
+ import spaces
36
 
37
  # MambaEye Imports
38
  from mambaeye.model import MambaEye
 
72
  try:
73
  checkpoint_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
74
  model = MambaEye(**MODEL_CONFIG)
75
+
76
+ # On ZeroGPU, load the checkpoint weights onto the CPU first; the model is moved to DEVICE later, inside the GPU-decorated handlers
77
+ map_loc = torch.device('cpu')
78
+ model.load_state_dict(torch.load(checkpoint_path, map_location=map_loc))
79
  model.eval()
80
  _GLOBAL_MODEL = model
81
  print("Model loaded successfully.")
 
121
  return canvas, x_offset, y_offset, new_h, new_w
122
 
123
  def extract_patch(canvas_tensor, px, py):
 
 
124
  px = max(0, min(px, TARGET_CANVAS_SIZE - PATCH_SIZE))
125
  py = max(0, min(py, TARGET_CANVAS_SIZE - PATCH_SIZE))
126
  patch = canvas_tensor[:, px : px + PATCH_SIZE, py : py + PATCH_SIZE]
 
134
  ratio = min(TARGET_CANVAS_SIZE / orig_w, TARGET_CANVAS_SIZE / orig_h)
135
 
136
  for i, (px, py) in enumerate(positions):
 
137
  orig_y = (py - y_offset) / ratio
138
  orig_x = (px - x_offset) / ratio
139
  orig_px_size = PATCH_SIZE / ratio
 
157
  return {
158
  'inference_params': None,
159
  'cur_location': None,
160
+ 'canvas_tensor': canvas_tensor.cpu(),
161
  'x_offset': x_offset,
162
  'y_offset': y_offset,
163
  'h': h,
 
167
  'sequence_length': 0
168
  }
169
 
170
+ @spaces.GPU
171
  def run_auto_scan(image, scan_pattern, sequence_length):
172
  if image is None:
173
  return None, {"Upload Image": 1.0}, None, "Upload Image"
174
 
175
  model = get_model()
176
+ model.to(DEVICE)
177
+
178
  state = init_state_for_image(image)
179
 
180
  x_end = max(state['x_offset'] + 1, state['x_offset'] + state['h'])
 
190
  scan_pattern=scan_pattern, rng=rng
191
  )
192
 
 
193
  inference_params = InferenceParams(max_seqlen=4000, max_batch_size=1)
194
  state['inference_params'] = inference_params
195
 
 
202
  move_emb = _compute_move_embedding(loc_tensor, cur_location)
203
  cur_location = loc_tensor
204
 
205
+ patch = extract_patch(state['canvas_tensor'], px, py).to(DEVICE)
206
  patches_list.append(patch)
207
  moves_list.append(move_emb.squeeze(0))
208
 
209
+ img_seq = torch.stack(patches_list, dim=0).unsqueeze(0) # (1, L, 768)
210
+ move_seq = torch.stack(moves_list, dim=0).unsqueeze(0) # (1, L, 512)
211
 
212
  with torch.no_grad():
213
  out = model(img_seq, move_seq, inference_params=inference_params)
214
  final_probs = F.softmax(out[0, -1], dim=-1).cpu().numpy()
215
  inference_params.seqlen_offset += img_seq.shape[1]
216
 
217
+ state['cur_location'] = cur_location.cpu()
218
  state['drawn_positions'] = positions
219
  state['sequence_length'] = sequence_length
220
+ # On ZeroGPU Spaces, move tensors back to the CPU before storing them in the state
221
+ state['canvas_tensor'] = state['canvas_tensor'].cpu()
222
 
223
  img_display, _ = draw_patches_on_image(
224
  state['original_image'], state['drawn_positions'],
 
227
 
228
  return img_display, format_predictions(final_probs), state, f"Auto Scan Complete. Extracted {sequence_length} patches. Click to add more!"
229
 
230
+ @spaces.GPU
231
  def on_click(evt: gr.SelectData, original_image, state):
232
  if original_image is None:
233
  return None, {"Upload Image": 1.0}, state, "Upload Image"
234
 
235
+ model = get_model()
236
+ model.to(DEVICE)
237
+
238
  if state is None or state.get('inference_params') is None:
239
  # Initialize state to begin a new purely user-guided sequence
240
  state = init_state_for_image(original_image)
241
  state['inference_params'] = InferenceParams(max_seqlen=4000, max_batch_size=1)
 
 
242
 
243
  x_orig, y_orig = evt.index
244
  orig_h, orig_w = state['original_image'].shape[:2]
 
250
  px = (canvas_x // PATCH_SIZE) * PATCH_SIZE
251
  py = (canvas_y // PATCH_SIZE) * PATCH_SIZE
252
 
253
+ cur_loc = state['cur_location'].to(DEVICE) if state['cur_location'] is not None else None
254
  loc_tensor = torch.tensor([[px, py]], dtype=torch.long, device=DEVICE)
255
+ move_emb = _compute_move_embedding(loc_tensor, cur_loc)
256
 
257
+ patch = extract_patch(state['canvas_tensor'], px, py).to(DEVICE)
258
 
259
+ img_seq = patch.unsqueeze(0).unsqueeze(0) # (1, 1, 768)
260
+ move_seq = move_emb.unsqueeze(0) # (1, 1, 512)
261
 
262
  with torch.no_grad():
263
  out = model(img_seq, move_seq, inference_params=state['inference_params'])
264
  final_probs = F.softmax(out[0, -1], dim=-1).cpu().numpy()
265
  state['inference_params'].seqlen_offset += 1
266
 
267
+ state['cur_location'] = loc_tensor.cpu()
268
  state['drawn_positions'].append((px, py))
269
  state['sequence_length'] += 1
270
 
 
273
  state['x_offset'], state['y_offset'], state['h'], state['w']
274
  )
275
 
276
+ return img_display, format_predictions(final_probs), state, f"Added patch {state['sequence_length']} (Total {state['inference_params'].seqlen_offset} inference steps)."
277
 
278
  def on_upload(image):
279
  if image is None:
280
  return None, {"Waiting...": 1.0}, None, "Upload Image"
281
+ # Delay model loading until the first auto-scan or click instead of pre-warming it on upload
 
282
  return image, {"Click Auto Scan or click the image": 1.0}, None, "Ready. You can Auto Scan or click."
283
 
284
  def on_clear(original_image):
 
286
  return None, {"Cleared": 1.0}, None, "Cleared"
287
  return original_image, {"Cleared": 1.0}, init_state_for_image(original_image), "Selections cleared. Ready for new patch sequence."
288
 
 
289
  # Build the Gradio App Blocks
290
  with gr.Blocks(title="MambaEye Interactive Demo", theme=gr.themes.Soft()) as demo:
291
+ gr.Markdown("# MambaEye Interactive inference Demo")
292
+ gr.Markdown("This interface incorporates the full **MambaEye-base** model inference natively. Using **ZeroGPU** inference via PyTorch equivalents.")
293
 
294
  with gr.Row():
295
  with gr.Column(scale=2):
 
321
  inputs=[input_image],
322
  outputs=[input_image, model_output_label, state, status_text]
323
  ).then(
 
324
  fn=lambda img: img, inputs=[input_image], outputs=[original_image_state]
325
  )
326
 
requirements.txt CHANGED
@@ -6,3 +6,4 @@ torchvision==0.21.0
6
  lightning==2.6.1
7
  huggingface_hub
8
  omegaconf==2.3.0
 
 
6
  lightning==2.6.1
7
  huggingface_hub
8
  omegaconf==2.3.0
9
+ spaces