# MambaEye / app.py
# Author: usingcolor
# Commit ed21522: chore: initialize git repository and add standard hook templates
import sys
import os
import subprocess
import time
# --- DEPENDENCY BOOTSTRAP ---
# Ensure the MambaEye source tree and the mamba kernels are importable before
# the heavy third-party imports below run.
import importlib
import shutil

mamba_dir = os.path.join(os.path.dirname(__file__), "MambaEye")
# A missing directory also lacks the "mambaeye" package, so this single check
# covers both "never cloned" and "partial / broken checkout".
if not os.path.exists(os.path.join(mamba_dir, "mambaeye")):
    print("Cloning MambaEye repository from GitHub...", flush=True)
    if os.path.exists(mamba_dir):
        # Remove an incomplete checkout so git can clone into a clean path.
        shutil.rmtree(mamba_dir)
    subprocess.check_call(["git", "clone", "https://github.com/usingcolor/MambaEye.git", mamba_dir])
try:
    import mamba_ssm  # noqa: F401  (availability probe only)
    import causal_conv1d  # noqa: F401
except ImportError:
    print("Installing mamba_ssm and causal_conv1d in backend...", flush=True)
    env = os.environ.copy()
    # Skip compiling the CUDA extensions at install time.
    env["MAMBA_SKIP_CUDA_BUILD"] = "TRUE"
    env["CAUSAL_CONV1D_SKIP_CUDA_BUILD"] = "TRUE"
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "causal-conv1d==1.5.0.post8", "mamba-ssm==2.2.4", "--no-build-isolation"],
        env=env,
    )
    # Packages installed after interpreter start-up may be hidden by the import
    # system's cached directory listings; refresh them before importing later.
    importlib.invalidate_caches()
# Make the cloned "mambaeye" package importable (once).
if mamba_dir not in sys.path:
    sys.path.append(mamba_dir)
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw
import torchvision.transforms as T
from torchvision.models import ResNet50_Weights
from huggingface_hub import hf_hub_download
import spaces
from mambaeye.model import MambaEye
from mambaeye.scan import generate_scan_positions
from mambaeye.positional_encoding import sinusoidal_position_encoding_2d
from mamba_ssm.utils.generation import InferenceParams
# Side length, in pixels, of each square patch extracted from the canvas.
PATCH_SIZE = 16
# Human-readable ImageNet-1k class names, indexed by the model's class output.
CATEGORIES = ResNet50_Weights.IMAGENET1K_V1.meta["categories"]
# Constructor arguments for MambaEye; must agree with the checkpoint loaded below.
MODEL_CONFIG = dict(
    num_classes=1000,
    input_dim=1280,
    dim=256,
    depth=48,
    d_state=64,
    d_conv=4,
    expand=2,
    residual_in_fp32=True,
)
# Hugging Face Hub location of the fine-tuned base checkpoint.
MODEL_REPO, MODEL_FILENAME = "usingcolor/MambaEye-base", "mambaeye_base_ft.pt"
# --- EAGER CPU PRE-LOADING ---
# The original design note says ZeroGPU worker processes fork from this main
# process. Downloading and loading the checkpoint here, before the UI launches,
# keeps the weights in CPU RAM so each GPU call only has to move them over.
print(f"Eagerly pre-downloading {MODEL_FILENAME} from {MODEL_REPO} into static CPU RAM...", flush=True)
try:
    CHECKPOINT_PATH = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
    _GLOBAL_CPU_MODEL = MambaEye(**MODEL_CONFIG)
    _GLOBAL_CPU_MODEL.load_state_dict(
        torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=True)
    )
    # Inference-only demo: switch off training-mode behaviour.
    _GLOBAL_CPU_MODEL.eval()
    print("Model perfectly cached conceptually in System RAM! Completely zero-latency disk I/O remaining.")
except Exception as exc:
    # The app is useless without the model: report, then propagate.
    print(f"Failed cleanly pre-loading model context: {exc}")
    raise
# --- CUSTOM CSS ---
# A CSS cursor override provides the precision crosshair on the interactive
# canvas; per the original note, custom HTML overlays did not work on Gradio's
# image component. The .big-accordion rules enlarge the accordion that carries
# that class. (Defined once: a stale duplicate assignment was removed.)
CSS_STYLE = """
.gradio-image-hook, .gradio-image-hook * {
cursor: crosshair !important;
}
.big-accordion {
border: 2px solid #e5e7eb !important;
}
.big-accordion button, .big-accordion .label-wrap, .big-accordion summary {
font-size: 1.3em !important;
padding: 12px 18px !important;
font-weight: 600 !important;
}
"""


def get_model():
    """Return ``(model, device)`` with the shared model moved onto *device*.

    The model is built once at import time on the CPU (``_GLOBAL_CPU_MODEL``).
    This picks CUDA when it is available (expected inside ``@spaces.GPU``
    calls), otherwise the CPU, and moves the module there in place.

    Returns:
        tuple: the shared model instance and the ``torch.device`` it now lives on.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _GLOBAL_CPU_MODEL.to(device)
    return _GLOBAL_CPU_MODEL, device
# --- INFERENCE-STATE HELPERS ---
def _move_to_device(obj, device):
    """Recursively move ``.to``-capable leaves inside tuples/lists/dicts to *device*.

    Tuples and lists are rebuilt; dicts are updated in place. Leaves without a
    ``.to`` method (ints, None, ...) are returned unchanged.
    """
    if isinstance(obj, tuple):
        return tuple(_move_to_device(x, device) for x in obj)
    if isinstance(obj, list):
        return [_move_to_device(x, device) for x in obj]
    if isinstance(obj, dict):
        for key in list(obj):
            obj[key] = _move_to_device(obj[key], device)
        return obj
    if isinstance(obj, torch.Tensor) or hasattr(obj, "to"):
        return obj.to(device)
    return obj


def transfer_inference_params(params, device):
    """Move the cached recurrent state held by *params* to *device*, in place.

    Used to park the per-session state on the CPU between ``@spaces.GPU`` calls
    and to bring it back onto the GPU before the next step.

    Args:
        params: an ``InferenceParams``-like object, or None. Values in its
            ``key_value_memory_dict`` may be tensors or (possibly nested)
            tuples/lists/dicts of tensors.
        device: anything accepted by ``Tensor.to`` (usually a ``torch.device``).

    Returns:
        The same *params* object (or None / the object untouched when it has no
        ``key_value_memory_dict``).
    """
    if params is None or getattr(params, "key_value_memory_dict", None) is None:
        return params
    cache = params.key_value_memory_dict
    for key in list(cache):
        cache[key] = _move_to_device(cache[key], device)
    # Optional per-sample length tensor must follow the cache to the same device.
    lengths = getattr(params, "lengths_per_sample", None)
    if isinstance(lengths, torch.Tensor):
        params.lengths_per_sample = lengths.to(device)
    return params
def format_seq_len(seq_len):
    """Render the running patch count as a centred HTML card for the sidebar."""
    box_style = (
        "text-align: center; border: 1px solid #e5e7eb; border-radius: 8px; "
        "padding: 10px; margin-bottom: 10px; background-color: #f9fafb;"
    )
    caption_style = "font-size: 1.1em; color: #6b7280;"
    value_style = "font-size: 3em; font-weight: bold; color: #3b82f6;"
    return (
        f"<div style='{box_style}'>"
        f"<span style='{caption_style}'>Total Sequenced Patches</span><br>"
        f"<span style='{value_style}'>{seq_len}</span>"
        "</div>"
    )
def _compute_move_embedding(patch_location: torch.Tensor, cur_location: torch.Tensor = None) -> torch.Tensor:
    """Sinusoidally embed the displacement from *cur_location* to *patch_location*.

    For the first patch of a sequence (*cur_location* is None) the displacement
    is an all-zero ``(N, 2)`` float32 tensor. The second argument (256) passed
    to ``sinusoidal_position_encoding_2d`` presumably sets the per-axis
    embedding size -- TODO confirm against mambaeye.positional_encoding.
    """
    if cur_location is not None:
        displacement = (patch_location - cur_location).float()
    else:
        displacement = torch.zeros(
            (patch_location.shape[0], 2),
            dtype=torch.float32,
            device=patch_location.device,
        )
    return sinusoidal_position_encoding_2d(displacement, 256)
def format_predictions(probs_np, top_k=5, categories=None):
    """Map a probability vector to ``{display name: probability}`` for ``gr.Label``.

    Args:
        probs_np: 1-D array-like of class probabilities, indexed like *categories*.
        top_k: how many of the most probable classes to keep (default 5).
        categories: list of class names; defaults to the module-level ImageNet
            ``CATEGORIES``. Only the text before the first comma is shown,
            title-cased.

    Returns:
        dict ordered from most to least probable. Distinct classes whose display
        names coincide (torchvision's ImageNet list repeats e.g. "crane") get
        their class index appended, so the lower-probability entry no longer
        overwrites the higher one.
    """
    if categories is None:
        categories = CATEGORIES
    if top_k <= 0:
        # Guard: a slice of [-0:] would otherwise select every class.
        return {}
    probs = np.asarray(probs_np)
    top_idx = np.argsort(probs)[-top_k:][::-1]
    result = {}
    for idx in top_idx:
        name = categories[idx].split(",")[0].title()
        if name in result:
            name = f"{name} (#{idx})"
        result[name] = float(probs[idx])
    return result
def preprocess_image(image_arr):
    """Centre an image on a black square canvas.

    Args:
        image_arr: image as a numpy array accepted by ``PIL.Image.fromarray``.

    Returns:
        tuple ``(canvas, row_offset, col_offset, height, width)``: a
        ``(3, S, S)`` float32 tensor with ``S = max(height, width)``, the
        row/column padding placed before the image, and the original size.
    """
    rgb = Image.fromarray(image_arr).convert("RGB")
    width, height = rgb.size
    pixels = T.ToTensor()(rgb)
    _, rows, cols = pixels.shape
    side = max(width, height)
    row_pad = (side - rows) // 2
    col_pad = (side - cols) // 2
    canvas = torch.zeros(3, side, side, dtype=torch.float32)
    canvas[:, row_pad:row_pad + rows, col_pad:col_pad + cols] = pixels
    return canvas, row_pad, col_pad, height, width
def extract_patch(canvas_tensor, px, py, patch_size=None):
    """Cut one square patch from a ``(C, H, W)`` canvas and flatten it.

    Args:
        canvas_tensor: tensor indexed as ``[channel, row, column]``.
        px: top row of the patch; clamped into the canvas.
        py: left column of the patch; clamped into the canvas.
        patch_size: patch side length; defaults to the module-level PATCH_SIZE.

    Returns:
        1-D tensor of length ``C * patch_size * patch_size``. If the canvas is
        smaller than the patch, the missing area is zero-filled so the length
        is always the same (the model stacks patches and needs equal sizes).
    """
    if patch_size is None:
        patch_size = PATCH_SIZE
    channels, rows, cols = canvas_tensor.shape
    px = max(0, min(px, rows - patch_size))
    py = max(0, min(py, cols - patch_size))
    patch = canvas_tensor[:, px:px + patch_size, py:py + patch_size]
    if patch.shape[1] != patch_size or patch.shape[2] != patch_size:
        # Canvas smaller than one patch: pad with zeros (same as canvas padding).
        padded = canvas_tensor.new_zeros((channels, patch_size, patch_size))
        padded[:, :patch.shape[1], :patch.shape[2]] = patch
        patch = padded
    return patch.flatten()
def draw_patches_on_image(image_arr, positions, x_offset, y_offset, h, w):
    """Show revealed patches in colour on top of a washed-out greyscale copy.

    Args:
        image_arr: the original image (array-like accepted by ``np.array``).
        positions: iterable of ``(px, py)`` canvas coordinates (px = row,
            py = column) of each revealed patch's top-left corner.
        x_offset: row padding added by ``preprocess_image``.
        y_offset: column padding added by ``preprocess_image``.
        h, w: image height/width; unused, kept for call compatibility.

    Returns:
        tuple ``(rendered image as a numpy array, positions unchanged)``.
    """
    img = np.array(image_arr)
    orig_pil = Image.fromarray(img)
    # Ambient background: greyscale, contrast scaled to 40% and lifted by 160.
    grey = np.array(orig_pil.convert("L").convert("RGB")).astype(float)
    composite = Image.fromarray((grey * 0.4 + 160).clip(0, 255).astype(np.uint8))
    for px, py in positions:
        # Canvas -> image coordinates; PIL boxes are (left, upper, right, lower).
        left = py - y_offset
        top = px - x_offset
        box = (int(left), int(top), int(left + PATCH_SIZE), int(top + PATCH_SIZE))
        # Copy the original colour pixels into the highlighted region.
        composite.paste(orig_pil.crop(box), box)
    return np.array(composite), positions
def init_state_for_image(image):
    """Build a fresh per-session state dict for *image* with no patches revealed.

    The canvas tensor is kept on the CPU; ``inference_params`` and
    ``cur_location`` start as None and are filled in by the inference handlers.
    """
    canvas, row_off, col_off, height, width = preprocess_image(image)
    state = dict(
        inference_params=None,
        cur_location=None,
        canvas_tensor=canvas.cpu(),
        x_offset=row_off,
        y_offset=col_off,
        h=height,
        w=width,
        original_image=image,
        drawn_positions=[],
        sequence_length=0,
    )
    return state
@spaces.GPU
def run_auto_scan(image, scan_pattern, sequence_length):
    """Classify *image* from an automatically generated sequence of patches.

    A scan path is produced by ``generate_scan_positions`` using a fixed-seed
    RNG. All patches on the path are fed to the model in one forward pass, and
    the resulting inference state is stored so later clicks continue the same
    sequence.

    Args:
        image: the original uploaded image (numpy array) or None.
        scan_pattern: pattern name forwarded to ``generate_scan_positions``.
        sequence_length: requested number of patches (from the slider).

    Returns:
        5-tuple matching the ``auto_btn`` outputs: (display image, top-5 label
        dict, session state, status markdown, sequence-length HTML).
    """
    if image is None:
        # One value per output component (5); returning fewer makes Gradio raise.
        return None, {"Upload Image": 1.0}, None, "Upload Image", format_seq_len(0)
    import random

    model, device = get_model()
    state = init_state_for_image(image)
    # Canvas-space extent of the real (unpadded) image; px indexes rows, py columns.
    x_end = max(state['x_offset'] + 1, state['x_offset'] + state['h'])
    y_end = max(state['y_offset'] + 1, state['y_offset'] + state['w'])
    # The scan generator is driven with x = columns and y = rows ...
    positions_xy = generate_scan_positions(
        x_start=state['y_offset'], x_stop=y_end,
        y_start=state['x_offset'], y_stop=x_end,
        patch_size=PATCH_SIZE, sequence_length=int(sequence_length),
        scan_pattern=scan_pattern, rng=random.Random(42),
    )
    # ... so swap each pair back to (px=row, py=col), the convention used elsewhere.
    positions = [(py, px) for px, py in positions_xy]
    if not positions:
        # Nothing to feed the model; show the blank canvas instead of crashing in torch.stack.
        img_display, _ = draw_patches_on_image(
            state['original_image'], [],
            state['x_offset'], state['y_offset'], state['h'], state['w']
        )
        return img_display, {"No patches generated": 1.0}, None, "Auto Scan produced no patches. Try a larger sequence length.", format_seq_len(0)

    patches, moves = [], []
    cur_location = None
    for px, py in positions:
        loc_tensor = torch.tensor([[px, py]], dtype=torch.long, device=device)
        # Each patch is paired with an embedding of the move from the previous patch.
        moves.append(_compute_move_embedding(loc_tensor, cur_location).squeeze(0))
        patches.append(extract_patch(state['canvas_tensor'], px, py).to(device))
        cur_location = loc_tensor
    img_seq = torch.stack(patches, dim=0).unsqueeze(0)  # (1, L, patch features)
    move_seq = torch.stack(moves, dim=0).unsqueeze(0)   # (1, L, move features)

    inference_params = InferenceParams(max_seqlen=4000, max_batch_size=1)
    with torch.no_grad():
        out = model(img_seq, move_seq, inference_params=inference_params)
    # Prediction after the last patch of the sequence.
    final_probs = F.softmax(out[0, -1], dim=-1).cpu().numpy()
    # Record that the whole sequence has been consumed so clicks continue after it.
    inference_params.seqlen_offset += img_seq.shape[1]

    num_patches = len(positions)
    state['cur_location'] = cur_location.cpu()
    state['drawn_positions'] = positions
    state['sequence_length'] = num_patches
    # Park the session's inference state on the CPU between GPU calls.
    state['inference_params'] = transfer_inference_params(inference_params, torch.device('cpu'))
    img_display, _ = draw_patches_on_image(
        state['original_image'], state['drawn_positions'],
        state['x_offset'], state['y_offset'], state['h'], state['w']
    )
    status = f"Auto Scan Complete. Extracted {num_patches} patches. Click to add more!"
    return img_display, format_predictions(final_probs), state, status, format_seq_len(num_patches)
@spaces.GPU
def process_click_inference(x_orig, y_orig, original_image, state):
    """Reveal one patch centred on a click and advance the model by one step.

    Args:
        x_orig: click column in the original image (``gr.SelectData.index[0]``).
        y_orig: click row in the original image (``gr.SelectData.index[1]``).
        original_image: the unmodified uploaded image (numpy array) or None.
        state: session dict from ``init_state_for_image`` / a previous call, or
            None; a state without ``inference_params`` starts a new sequence.

    Returns:
        5-tuple matching the select outputs: (display image, top-5 label dict,
        session state, status markdown, sequence-length HTML).
    """
    if original_image is None:
        # One value per output component (5); returning fewer makes Gradio raise.
        return None, {"Upload Image": 1.0}, state, "Upload Image", format_seq_len(0)
    model, device = get_model()
    if state is None or state.get('inference_params') is None:
        # First click after an upload or clear: start a fresh sequence.
        state = init_state_for_image(original_image)
        state['inference_params'] = InferenceParams(max_seqlen=4000, max_batch_size=1)
    state['inference_params'] = transfer_inference_params(state['inference_params'], device)

    orig_h, orig_w = state['original_image'].shape[:2]
    canvas_size = max(orig_h, orig_w)
    # Image-space click -> padded-canvas (row, column).
    canvas_row = int(y_orig) + state['x_offset']
    canvas_col = int(x_orig) + state['y_offset']
    # Centre the patch on the click, clamped so it stays inside the canvas.
    px = max(0, min(int(canvas_row - PATCH_SIZE / 2), canvas_size - PATCH_SIZE))
    py = max(0, min(int(canvas_col - PATCH_SIZE / 2), canvas_size - PATCH_SIZE))

    cur_loc = state['cur_location'].to(device) if state['cur_location'] is not None else None
    loc_tensor = torch.tensor([[px, py]], dtype=torch.long, device=device)
    move_emb = _compute_move_embedding(loc_tensor, cur_loc)
    patch = extract_patch(state['canvas_tensor'], px, py).to(device)
    img_seq = patch.unsqueeze(0).unsqueeze(0)  # (1, 1, patch features)
    move_seq = move_emb.unsqueeze(0)  # (1, 1, D) -- assumes move_emb is (1, D)
    with torch.no_grad():
        out = model(img_seq, move_seq, inference_params=state['inference_params'])
    final_probs = F.softmax(out[0, -1], dim=-1).cpu().numpy()
    # One more token consumed by the cached state.
    state['inference_params'].seqlen_offset += 1

    state['cur_location'] = loc_tensor.cpu()
    state['drawn_positions'].append((px, py))
    state['sequence_length'] += 1
    # Park the session's inference state on the CPU between GPU calls.
    state['inference_params'] = transfer_inference_params(state['inference_params'], torch.device('cpu'))
    img_display, _ = draw_patches_on_image(
        state['original_image'], state['drawn_positions'],
        state['x_offset'], state['y_offset'], state['h'], state['w']
    )
    status_msg = f"🔍 Revealed patch #{state['sequence_length']}! The model is analyzing... Keep clicking to give it more clues!"
    return img_display, format_predictions(final_probs), state, status_msg, format_seq_len(state['sequence_length'])
def on_click(evt: gr.SelectData, original_image, state):
    """Forward a canvas click (Gradio select event) to single-patch inference.

    ``evt.index`` holds the clicked pixel as ``(column, row)``.
    """
    col, row = evt.index
    return process_click_inference(col, row, original_image, state)
def on_upload(image):
    """Reset the session for a newly uploaded (or example) image.

    Returns:
        6-tuple for (input_image, original_image_state, model_output_label,
        state, status_text, seq_len_display). The canvas shows a washed-out
        greyscale copy, and ``state`` is None so the first click starts a fresh
        sequence.
    """
    if image is None:
        # The counter is an HTML component: send the formatted card, not a bare 0.
        return None, None, {"Waiting...": 1.0}, None, "Upload Image", format_seq_len(0)
    # Pre-render the greyed-out background immediately on upload.
    grey = Image.fromarray(image).convert("L").convert("RGB")
    grey_np = (np.array(grey).astype(float) * 0.4 + 160).clip(0, 255).astype(np.uint8)
    return grey_np, image, {"Click an interesting object in the photo": 1.0}, None, "✨ Image loaded! The model is currently blind. **Click anywhere on the grey canvas** to reveal the first patch and let the model guess!", format_seq_len(0)
def on_clear(original_image):
    """Discard all revealed patches and show the blank greyscale canvas again.

    Returns:
        5-tuple for (input_image, model_output_label, state, status_text,
        seq_len_display). The new state comes from ``init_state_for_image``, so
        the next click starts a fresh sequence.
    """
    if original_image is None:
        # The counter is an HTML component: send the formatted card, not a bare 0.
        return None, {"Cleared": 1.0}, None, "Cleared", format_seq_len(0)
    grey = Image.fromarray(original_image).convert("L").convert("RGB")
    grey_np = (np.array(grey).astype(float) * 0.4 + 160).clip(0, 255).astype(np.uint8)
    return grey_np, {"Cleared": 1.0}, init_state_for_image(original_image), "🧹 Selections cleared! The canvas is blank. Where will you click next?", format_seq_len(0)
with gr.Blocks(title="MambaEye Interactive Demo") as demo:
    gr.Markdown(
        "# MambaEye Interactive Inference Demo\n"
        "**🔗 [Project Page](https://usingcolor.github.io/MambaEye) • 💻 [GitHub Repository](https://github.com/usingcolor/MambaEye)**\n\n"
        "This interface incorporates the full **MambaEye-base-ft** model natively.\n\n"
        "**Note**: The first inference or Auto Scan may take **1~2 minutes** to compile CUDA kernels and build hardware cache. Subsequent patch clicks will be dramatically faster!"
    )
    # Per-session inference state and the untouched uploaded image.
    state = gr.State(None)
    original_image_state = gr.State(None)
    # Created unrendered so they can be referenced by gr.Examples (left column)
    # yet placed in the right-hand column via .render().
    seq_len_display = gr.HTML(value=format_seq_len(0), render=False)
    model_output_label = gr.Label(label="MambaEye Output Predictions", num_top_classes=5, render=False)
    status_text = gr.Markdown("Status: Waiting for image upload...", render=False)
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("### 🎯 Challenge: See how few clicks the model needs to guess your image!\nClick directly on the most informative parts of the grey image to reveal patches to the model.")
            input_image = gr.Image(type="numpy", label="👆 Interactive Canvas: Click here to extract patches!", interactive=True, elem_classes="gradio-image-hook")
            clear_btn = gr.Button("🗑️ Clear Selections & Start Over", variant="secondary")
            with gr.Accordion("🤖 Advanced: Auto-Scan Features", open=False, elem_classes="big-accordion"):
                gr.Markdown("### ✨ Let the model automatically scan a sequence of patches!")
                with gr.Row():
                    scan_pattern = gr.Dropdown(
                        choices=["random", "spiral", "diagonal", "golden", "horizontal_raster", "horizontal_zigzag", "column_major", "column_snake"],
                        value="random",
                        label="Scan Pattern"
                    )
                    seq_length = gr.Slider(minimum=1, maximum=4096, step=1, value=256, label="Auto Sequence Length")
                auto_btn = gr.Button("Auto Generate Path & Infer", variant="primary")
            gr.Examples(
                examples=[
                    "assets/dog.jpg",
                    "assets/leo.jpg",
                    "assets/green_mamba.jpg",
                ],
                inputs=input_image,
                outputs=[input_image, original_image_state, model_output_label, state, status_text, seq_len_display],
                fn=on_upload,
                run_on_click=True,
                cache_examples=False,
                label="Try an Example Image"
            )
        with gr.Column(scale=1):
            seq_len_display.render()
            model_output_label.render()
            status_text.render()
    # Upload: reset the session and show the greyed-out canvas.
    input_image.upload(
        fn=on_upload,
        inputs=[input_image],
        outputs=[input_image, original_image_state, model_output_label, state, status_text, seq_len_display]
    )
    # Auto scan works from the untouched original, not the displayed composite.
    auto_btn.click(
        fn=run_auto_scan,
        inputs=[original_image_state, scan_pattern, seq_length],
        outputs=[input_image, model_output_label, state, status_text, seq_len_display]
    )
    # Clicking the canvas reveals one more patch.
    input_image.select(
        fn=on_click,
        inputs=[original_image_state, state],
        outputs=[input_image, model_output_label, state, status_text, seq_len_display]
    )
    clear_btn.click(
        fn=on_clear,
        inputs=[original_image_state],
        outputs=[input_image, model_output_label, state, status_text, seq_len_display]
    )
if __name__ == "__main__":
    # Theme and custom CSS are supplied at launch time; server-side rendering is off.
    launch_kwargs = dict(theme=gr.themes.Soft(), ssr_mode=False, css=CSS_STYLE)
    demo.launch(**launch_kwargs)