import sys
import os
import subprocess
import time

# Fetch the MambaEye sources on first start-up, or re-fetch them when a previous
# checkout is missing its `mambaeye` package directory.
mamba_dir = os.path.join(os.path.dirname(__file__), "MambaEye")
if not os.path.exists(mamba_dir) or not os.path.exists(os.path.join(mamba_dir, "mambaeye")):
    print("Cloning MambaEye repository from GitHub...", flush=True)
    if os.path.exists(mamba_dir):
        import shutil

        # `git clone` refuses a non-empty target, so drop the incomplete checkout first.
        shutil.rmtree(mamba_dir)
    subprocess.check_call(["git", "clone", "https://github.com/usingcolor/MambaEye.git", mamba_dir])

# Install the Mamba packages on demand when they are not already importable.
try:
    import mamba_ssm
    import causal_conv1d
except ImportError:
    import importlib

    print("Installing mamba_ssm and causal_conv1d in backend...", flush=True)
    env = os.environ.copy()
    # NOTE(review): these flags ask the packages' build scripts to skip compiling
    # the CUDA extensions at install time -- confirm against their setup.py.
    env["MAMBA_SKIP_CUDA_BUILD"] = "TRUE"
    env["CAUSAL_CONV1D_SKIP_CUDA_BUILD"] = "TRUE"
    subprocess.check_call(
        [
            sys.executable, "-m", "pip", "install",
            "causal-conv1d==1.5.0.post8", "mamba-ssm==2.2.4", "--no-build-isolation",
        ],
        env=env,
    )
    # Modules installed after the interpreter started may be invisible to the
    # import system's cached path finders; refresh them so the `mamba_ssm`
    # imports further down can see the freshly installed packages.
    importlib.invalidate_caches()

# Make the cloned `mambaeye` package importable.
sys.path.append(mamba_dir)

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw
import torchvision.transforms as T
from torchvision.models import ResNet50_Weights
from huggingface_hub import hf_hub_download
import spaces

from mambaeye.model import MambaEye
from mambaeye.scan import generate_scan_positions
from mambaeye.positional_encoding import sinusoidal_position_encoding_2d
from mamba_ssm.utils.generation import InferenceParams

# Side length (pixels) of the square patches the model consumes.
PATCH_SIZE = 16
# ImageNet-1k class names, used to label the model's 1000 logits.
CATEGORIES = ResNet50_Weights.IMAGENET1K_V1.meta["categories"]

# Constructor arguments for MambaEye; they must match the checkpoint below.
MODEL_CONFIG = {
    "num_classes": 1000,
    "input_dim": 1280,
    "dim": 256,
    "depth": 48,
    "d_state": 64,
    "d_conv": 4,
    "expand": 2,
    "residual_in_fp32": True,
}
MODEL_REPO = "usingcolor/MambaEye-base"
MODEL_FILENAME = "mambaeye_base_ft.pt"

# --- EAGER CPU RAM PRE-LOADING ---
# The checkpoint is downloaded and loaded into a CPU model once, at import time,
# before the UI launches, so request handlers never touch the disk or network.
# NOTE(review): this assumes ZeroGPU runs @spaces.GPU handlers in workers forked
# from this process, inheriting the already-loaded model -- confirm.
print(f"Eagerly pre-downloading {MODEL_FILENAME} from {MODEL_REPO} into static CPU RAM...", flush=True)
try:
    CHECKPOINT_PATH = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
    _GLOBAL_CPU_MODEL = MambaEye(**MODEL_CONFIG)
    # weights_only=True restricts torch.load's unpickler to tensors and plain containers.
    _GLOBAL_CPU_MODEL.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=True))
    _GLOBAL_CPU_MODEL.eval()
    print("Model perfectly cached conceptually in System RAM! Completely zero-latency disk I/O remaining.")
except Exception as e:
    # The app cannot run without the model: report on stderr, then re-raise so
    # start-up fails loudly instead of serving a broken UI.
    print(f"Failed cleanly pre-loading model context: {e}", file=sys.stderr, flush=True)
    raise


def get_model():
    """Return ``(model, device)`` with the shared pre-loaded model on ``device``.

    ``device`` is CUDA when available, otherwise CPU. ``nn.Module.to`` moves the
    module in place, so every caller receives the same ``_GLOBAL_CPU_MODEL``
    object; once it already lives on ``device`` the call is effectively a no-op.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _GLOBAL_CPU_MODEL.to(device)
    return _GLOBAL_CPU_MODEL, device


# --- FALLBACK CSS INJECTION ---
# A CSS override (CSS_STYLE) gives the interactive image a crosshair cursor to aid
# precise patch selection.
CSS_STYLE = """
.gradio-image-hook, .gradio-image-hook * {
    cursor: crosshair !important;
}
.big-accordion {
    border: 2px solid #e5e7eb !important;
}
.big-accordion button, .big-accordion .label-wrap, .big-accordion summary {
    font-size: 1.3em !important;
    padding: 12px 18px !important;
    font-weight: 600 !important;
}
"""

# `max_seqlen` handed to every InferenceParams created by this app.
# NOTE(review): the auto-scan slider allows up to 4096 patches (plus any extra
# clicks); confirm whether MambaEye/mamba_ssm actually enforce this limit.
_INFERENCE_MAX_SEQLEN = 4000


def _grey_background(image_arr):
    """Return a washed-out greyscale copy of ``image_arr`` as a uint8 RGB array.

    Luminance is scaled by 0.4 and shifted by +160 (then clipped to 0..255), so
    unrevealed areas look faded while revealed patches stand out in colour.
    """
    grey = Image.fromarray(np.array(image_arr)).convert("L").convert("RGB")
    return (np.array(grey).astype(float) * 0.4 + 160).clip(0, 255).astype(np.uint8)


def transfer_inference_params(params, device):
    """Move the tensors cached in ``params.key_value_memory_dict`` to ``device``.

    The cache dict is updated in place and ``params`` itself is returned. Cache
    values that are tensors, tuples/lists containing tensors, or dicts whose values
    have a ``.to`` method are moved; anything else is left untouched. ``None`` or
    an object without a cache dict is returned unchanged. The handlers below use
    this to park session state on the CPU between requests.
    """
    cache = getattr(params, "key_value_memory_dict", None)
    if cache is None:
        return params
    for key, value in cache.items():
        if isinstance(value, torch.Tensor):
            cache[key] = value.to(device)
        elif isinstance(value, tuple):
            cache[key] = tuple(x.to(device) if isinstance(x, torch.Tensor) else x for x in value)
        elif isinstance(value, list):
            cache[key] = [x.to(device) if isinstance(x, torch.Tensor) else x for x in value]
        elif isinstance(value, dict):
            for inner_key, inner_value in value.items():
                if hasattr(inner_value, "to"):
                    value[inner_key] = inner_value.to(device)
    return params


def format_seq_len(seq_len):
    """Return the HTML snippet for the "Total Sequenced Patches" counter."""
    # Block-level <div>s put the label and the count on separate lines.
    return f"<div><div>Total Sequenced Patches</div><div>{seq_len}</div></div>"


def _compute_move_embedding(patch_location: torch.Tensor, cur_location: torch.Tensor = None) -> torch.Tensor:
    """Encode the move from ``cur_location`` to ``patch_location``.

    Callers in this file pass ``(1, 2)`` long tensors of (row, col) canvas
    coordinates. With no previous location (first patch) a zero displacement is
    encoded. The displacement is passed to ``sinusoidal_position_encoding_2d``
    with 256 as its second argument.
    """
    if cur_location is None:
        move_embedding = torch.zeros((patch_location.shape[0], 2), dtype=torch.float32, device=patch_location.device)
        return sinusoidal_position_encoding_2d(move_embedding, 256)
    return sinusoidal_position_encoding_2d((patch_location - cur_location).float(), 256)


def format_predictions(probs_np):
    """Map the five most probable ImageNet classes to probabilities for ``gr.Label``.

    Display names are the first comma-separated part of the category, title-cased.
    ImageNet names are not unique (e.g. "crane" appears twice); on a collision the
    higher-probability entry is kept.
    """
    top5_idx = np.argsort(probs_np)[-5:][::-1]
    result = {}
    for idx in top5_idx:
        class_name = CATEGORIES[idx].split(",")[0].title()
        # Indices are visited in descending probability, so setdefault keeps the best one.
        result.setdefault(class_name, float(probs_np[idx]))
    return result


def preprocess_image(image_arr):
    """Centre an image on a zero-padded square canvas.

    Returns ``(canvas, x_offset, y_offset, height, width)``:
    ``canvas`` is a ``(3, S, S)`` float32 tensor with ``S = max(height, width)``,
    filled by ``T.ToTensor()``. ``x_offset`` is the row (height-axis) offset and
    ``y_offset`` the column (width-axis) offset of the image inside the canvas.
    """
    img = Image.fromarray(image_arr).convert("RGB")
    width, height = img.size
    img_tensor = T.ToTensor()(img)

    canvas_size = max(width, height)
    canvas = torch.zeros(3, canvas_size, canvas_size, dtype=torch.float32)
    x_offset = (canvas_size - img_tensor.shape[1]) // 2
    y_offset = (canvas_size - img_tensor.shape[2]) // 2
    canvas[:, x_offset : x_offset + img_tensor.shape[1], y_offset : y_offset + img_tensor.shape[2]] = img_tensor
    return canvas, x_offset, y_offset, height, width


def extract_patch(canvas_tensor, px, py):
    """Return the flattened ``PATCH_SIZE`` x ``PATCH_SIZE`` patch at row ``px``, column ``py``.

    ``(px, py)`` is the patch's top-left corner; it is clamped so the patch stays
    inside the square canvas. The result has ``C * PATCH_SIZE**2`` elements.
    """
    canvas_size = canvas_tensor.shape[1]
    px = max(0, min(px, canvas_size - PATCH_SIZE))
    py = max(0, min(py, canvas_size - PATCH_SIZE))
    return canvas_tensor[:, px : px + PATCH_SIZE, py : py + PATCH_SIZE].flatten()


def draw_patches_on_image(image_arr, positions, x_offset, y_offset, h, w):
    """Render ``image_arr`` greyed out except for the revealed patches.

    ``positions`` are (row, col) top-left corners in padded-canvas coordinates;
    ``x_offset``/``y_offset`` (from ``preprocess_image``) map them back onto the
    image. ``h`` and ``w`` are accepted for interface compatibility but unused.
    Returns ``(rendered_uint8_rgb_array, positions)``.
    """
    img = np.array(image_arr)
    rendered = Image.fromarray(_grey_background(img))
    original = Image.fromarray(img)
    for px, py in positions:
        top = px - x_offset
        left = py - y_offset
        # PIL boxes are (left, upper, right, lower); paste the full-colour pixels back.
        box = (int(left), int(top), int(left + PATCH_SIZE), int(top + PATCH_SIZE))
        rendered.paste(original.crop(box), box)
    return np.array(rendered), positions


def init_state_for_image(image):
    """Build a fresh per-session state dict for ``image`` with no patches revealed.

    ``inference_params`` starts as ``None``; the handlers create it on first use.
    """
    canvas_tensor, x_offset, y_offset, h, w = preprocess_image(image)
    return {
        'inference_params': None,
        'cur_location': None,
        'canvas_tensor': canvas_tensor.cpu(),
        'x_offset': x_offset,
        'y_offset': y_offset,
        'h': h,
        'w': w,
        'original_image': image,
        'drawn_positions': [],
        'sequence_length': 0,
    }


@spaces.GPU
def run_auto_scan(image, scan_pattern, sequence_length):
    """Reveal a generated scan path and classify the image in a single forward pass.

    Returns the five values wired to
    ``[input_image, model_output_label, state, status_text, seq_len_display]``.
    The returned state keeps the cache (parked on CPU) so later clicks extend it.
    """
    if image is None:
        return None, {"Upload Image": 1.0}, None, "Upload Image", format_seq_len(0)

    import random

    # Coerce defensively in case the slider delivers a float.
    sequence_length = int(sequence_length)
    model, device = get_model()
    state = init_state_for_image(image)

    row_stop = max(state['x_offset'] + 1, state['x_offset'] + state['h'])
    col_stop = max(state['y_offset'] + 1, state['y_offset'] + state['w'])
    positions_xy = generate_scan_positions(
        x_start=state['y_offset'],
        x_stop=col_stop,
        y_start=state['x_offset'],
        y_stop=row_stop,
        patch_size=PATCH_SIZE,
        sequence_length=sequence_length,
        scan_pattern=scan_pattern,
        rng=random.Random(42),  # fixed seed: the same pattern yields the same path
    )
    # The generator works in (x=col, y=row); swap to the (row, col) order used here.
    positions = [(py, px) for px, py in positions_xy]

    inference_params = InferenceParams(max_seqlen=_INFERENCE_MAX_SEQLEN, max_batch_size=1)
    patches, moves = [], []
    cur_location = None
    for px, py in positions:
        loc_tensor = torch.tensor([[px, py]], dtype=torch.long, device=device)
        moves.append(_compute_move_embedding(loc_tensor, cur_location).squeeze(0))
        patches.append(extract_patch(state['canvas_tensor'], px, py).to(device))
        cur_location = loc_tensor

    img_seq = torch.stack(patches, dim=0).unsqueeze(0)  # (1, L, 768)
    move_seq = torch.stack(moves, dim=0).unsqueeze(0)  # (1, L, move-embedding width)
    with torch.no_grad():
        out = model(img_seq, move_seq, inference_params=inference_params)
        final_probs = F.softmax(out[0, -1], dim=-1).cpu().numpy()
    inference_params.seqlen_offset += img_seq.shape[1]

    # Report what was actually scanned, which is what the counter and overlay show.
    n_patches = len(positions)
    state['cur_location'] = cur_location.cpu()
    state['drawn_positions'] = positions
    state['sequence_length'] = n_patches
    state['inference_params'] = transfer_inference_params(inference_params, torch.device('cpu'))

    img_display, _ = draw_patches_on_image(
        state['original_image'], state['drawn_positions'],
        state['x_offset'], state['y_offset'], state['h'], state['w'],
    )
    # Two trailing spaces before "\n" form a Markdown hard line break.
    status_msg = f"Auto Scan Complete. Extracted {n_patches} patches.  \nClick to add more!"
    return img_display, format_predictions(final_probs), state, status_msg, format_seq_len(n_patches)


@spaces.GPU
def process_click_inference(x_orig, y_orig, original_image, state):
    """Reveal the patch centred on a click and update the prediction.

    ``x_orig``/``y_orig`` are the click's column/row in the original image. A single
    patch is fed to the model together with the session's cached
    ``InferenceParams``. Returns the five values wired to
    ``[input_image, model_output_label, state, status_text, seq_len_display]``.
    """
    if original_image is None:
        return None, {"Upload Image": 1.0}, state, "Upload Image", format_seq_len(0)

    model, device = get_model()
    if state is None or state.get('inference_params') is None:
        state = init_state_for_image(original_image)
        state['inference_params'] = InferenceParams(max_seqlen=_INFERENCE_MAX_SEQLEN, max_batch_size=1)
    state['inference_params'] = transfer_inference_params(state['inference_params'], device)

    orig_h, orig_w = state['original_image'].shape[:2]
    canvas_size = max(orig_h, orig_w)
    # Click position in padded-canvas coordinates (row, col).
    click_row = int(y_orig) + state['x_offset']
    click_col = int(x_orig) + state['y_offset']
    # Centre the patch on the click, clamped so it stays inside the canvas.
    px = max(0, min(int(click_row - PATCH_SIZE / 2), canvas_size - PATCH_SIZE))
    py = max(0, min(int(click_col - PATCH_SIZE / 2), canvas_size - PATCH_SIZE))

    cur_loc = state['cur_location'].to(device) if state['cur_location'] is not None else None
    loc_tensor = torch.tensor([[px, py]], dtype=torch.long, device=device)
    move_emb = _compute_move_embedding(loc_tensor, cur_loc)
    patch = extract_patch(state['canvas_tensor'], px, py).to(device)

    img_seq = patch.unsqueeze(0).unsqueeze(0)
    move_seq = move_emb.unsqueeze(0)
    with torch.no_grad():
        out = model(img_seq, move_seq, inference_params=state['inference_params'])
        final_probs = F.softmax(out[0, -1], dim=-1).cpu().numpy()
    state['inference_params'].seqlen_offset += 1

    state['cur_location'] = loc_tensor.cpu()
    state['drawn_positions'].append((px, py))
    state['sequence_length'] += 1
    state['inference_params'] = transfer_inference_params(state['inference_params'], torch.device('cpu'))

    img_display, _ = draw_patches_on_image(
        state['original_image'], state['drawn_positions'],
        state['x_offset'], state['y_offset'], state['h'], state['w'],
    )
    status_msg = (
        f"🔍 Revealed patch #{state['sequence_length']}! "
        "The model is analyzing... Keep clicking to give it more clues!"
    )
    return img_display, format_predictions(final_probs), state, status_msg, format_seq_len(state['sequence_length'])


def on_click(evt: gr.SelectData, original_image, state):
    """Forward a click on the image (``evt.index`` is ``(x, y)``) to the inference handler."""
    x_orig, y_orig = evt.index
    return process_click_inference(x_orig, y_orig, original_image, state)


def on_upload(image):
    """Show the greyed-out canvas for a new image and reset the session state."""
    if image is None:
        return None, None, {"Waiting...": 1.0}, None, "Upload Image", format_seq_len(0)
    return (
        _grey_background(image),
        image,
        {"Click an interesting object in the photo": 1.0},
        None,
        "✨ Image loaded! The model is currently blind. **Click anywhere on the grey canvas** to reveal the first patch and let the model guess!",
        format_seq_len(0),
    )


def on_clear(original_image):
    """Hide all revealed patches and start a fresh session for the current image."""
    if original_image is None:
        return None, {"Cleared": 1.0}, None, "Cleared", format_seq_len(0)
    return (
        _grey_background(original_image),
        {"Cleared": 1.0},
        init_state_for_image(original_image),
        "🧹 Selections cleared! The canvas is blank. Where will you click next?",
        format_seq_len(0),
    )


with gr.Blocks(title="MambaEye Interactive Demo") as demo:
    gr.Markdown(
        "# MambaEye Interactive Inference Demo\n"
        "**🔗 [Project Page](https://usingcolor.github.io/MambaEye) • 💻 [GitHub Repository](https://github.com/usingcolor/MambaEye)**\n\n"
        "This interface incorporates the full **MambaEye-base-ft** model natively.\n\n"
        "**Note**: The first inference or Auto Scan may take **1~2 minutes** to compile CUDA kernels and build hardware cache. Subsequent patch clicks will be dramatically faster!"
    )

    state = gr.State(None)
    original_image_state = gr.State(None)

    # Created unrendered so they can be placed in the right-hand column below.
    seq_len_display = gr.HTML(value=format_seq_len(0), render=False)
    model_output_label = gr.Label(label="MambaEye Output Predictions", num_top_classes=5, render=False)
    status_text = gr.Markdown("Status: Waiting for image upload...", render=False)

    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown(
                "### 🎯 Challenge: See how few clicks the model needs to guess your image!\n"
                "Click directly on the most informative parts of the grey image to reveal patches to the model."
            )
            input_image = gr.Image(
                type="numpy",
                label="👆 Interactive Canvas: Click here to extract patches!",
                interactive=True,
                elem_classes="gradio-image-hook",
            )
            clear_btn = gr.Button("🗑️ Clear Selections & Start Over", variant="secondary")

            with gr.Accordion("🤖 Advanced: Auto-Scan Features", open=False, elem_classes="big-accordion"):
                gr.Markdown("### ✨ Let the model automatically scan a sequence of patches!")
                with gr.Row():
                    scan_pattern = gr.Dropdown(
                        choices=[
                            "random", "spiral", "diagonal", "golden",
                            "horizontal_raster", "horizontal_zigzag", "column_major", "column_snake",
                        ],
                        value="random",
                        label="Scan Pattern",
                    )
                    seq_length = gr.Slider(minimum=1, maximum=4096, step=1, value=256, label="Auto Sequence Length")
                auto_btn = gr.Button("Auto Generate Path & Infer", variant="primary")

            gr.Examples(
                examples=[
                    "assets/dog.jpg",
                    "assets/leo.jpg",
                    "assets/green_mamba.jpg",
                ],
                inputs=input_image,
                outputs=[input_image, original_image_state, model_output_label, state, status_text, seq_len_display],
                fn=on_upload,
                run_on_click=True,
                cache_examples=False,
                label="Try an Example Image",
            )

        with gr.Column(scale=1):
            seq_len_display.render()
            model_output_label.render()
            status_text.render()

    input_image.upload(
        fn=on_upload,
        inputs=[input_image],
        outputs=[input_image, original_image_state, model_output_label, state, status_text, seq_len_display],
    )
    auto_btn.click(
        fn=run_auto_scan,
        inputs=[original_image_state, scan_pattern, seq_length],
        outputs=[input_image, model_output_label, state, status_text, seq_len_display],
    )
    input_image.select(
        fn=on_click,
        inputs=[original_image_state, state],
        outputs=[input_image, model_output_label, state, status_text, seq_len_display],
    )
    clear_btn.click(
        fn=on_clear,
        inputs=[original_image_state],
        outputs=[input_image, model_output_label, state, status_text, seq_len_display],
    )


if __name__ == "__main__":
    demo.launch(theme=gr.themes.Soft(), ssr_mode=False, css=CSS_STYLE)