# Spaces: Running on Zero
| import sys | |
| import os | |
| import subprocess | |
| import time | |
# Location of the vendored MambaEye checkout, next to this file.
mamba_dir = os.path.join(os.path.dirname(__file__), "MambaEye")

# (Re)clone unless both the checkout directory and its "mambaeye" package folder exist.
if not (os.path.exists(mamba_dir) and os.path.exists(os.path.join(mamba_dir, "mambaeye"))):
    print("Cloning MambaEye repository from GitHub...", flush=True)
    if os.path.exists(mamba_dir):
        # A directory without the package folder is treated as a broken checkout: wipe it first.
        import shutil

        shutil.rmtree(mamba_dir)
    subprocess.check_call(
        ["git", "clone", "https://github.com/usingcolor/MambaEye.git", mamba_dir]
    )
# Install the Mamba kernels on demand when they cannot be imported.
try:
    import mamba_ssm
    import causal_conv1d
except ImportError:
    print("Installing mamba_ssm and causal_conv1d in backend...", flush=True)
    # Copy of the current environment with both packages' skip-CUDA-build flags set.
    env = dict(
        os.environ,
        MAMBA_SKIP_CUDA_BUILD="TRUE",
        CAUSAL_CONV1D_SKIP_CUDA_BUILD="TRUE",
    )
    subprocess.check_call(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "causal-conv1d==1.5.0.post8",
            "mamba-ssm==2.2.4",
            "--no-build-isolation",
        ],
        env=env,
    )
# Make the cloned checkout importable (mamba_dir is the same "<this dir>/MambaEye" path built above).
sys.path.append(mamba_dir)
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image, ImageDraw | |
| import torchvision.transforms as T | |
| from torchvision.models import ResNet50_Weights | |
| from huggingface_hub import hf_hub_download | |
| import spaces | |
| from mambaeye.model import MambaEye | |
| from mambaeye.scan import generate_scan_positions | |
| from mambaeye.positional_encoding import sinusoidal_position_encoding_2d | |
| from mamba_ssm.utils.generation import InferenceParams | |
# Side length, in pixels, of the square patches cut from the canvas.
PATCH_SIZE = 16

# Class names used to label predictions, read from torchvision's ResNet-50 ImageNet weight metadata.
CATEGORIES = ResNet50_Weights.IMAGENET1K_V1.meta["categories"]

# Keyword arguments passed to the MambaEye constructor.
MODEL_CONFIG = dict(
    num_classes=1000,
    input_dim=1280,
    dim=256,
    depth=48,
    d_state=64,
    d_conv=4,
    expand=2,
    residual_in_fp32=True,
)

# Hugging Face Hub repository and file name of the checkpoint to download.
MODEL_REPO = "usingcolor/MambaEye-base"
MODEL_FILENAME = "mambaeye_base_ft.pt"
# --- Eager CPU pre-loading ---
# The checkpoint is downloaded and the model is built on CPU at import time, before the UI is
# created, so get_model() only has to move an already-loaded module to the target device.
# NOTE(review): the original comment attributes this to ZeroGPU workers forking from the main
# process - confirm against the Spaces ZeroGPU documentation.
print(
    f"Eagerly pre-downloading {MODEL_FILENAME} from {MODEL_REPO} into static CPU RAM...",
    flush=True,
)
try:
    CHECKPOINT_PATH = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
    _GLOBAL_CPU_MODEL = MambaEye(**MODEL_CONFIG)
    _GLOBAL_CPU_MODEL.load_state_dict(
        # weights_only=True restricts unpickling to tensor data.
        torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=True)
    )
    _GLOBAL_CPU_MODEL.eval()
    print("Model perfectly cached conceptually in System RAM! Completely zero-latency disk I/O remaining.")
except Exception as e:
    # Log for the Space console, then abort start-up: the app cannot work without the model.
    print(f"Failed cleanly pre-loading model context: {e}")
    raise
# (CSS_STYLE is defined once, after get_model(); a duplicate definition that was unconditionally
# overwritten by that later assignment has been removed from this spot.)
def get_model():
    """Return the shared model and the device it now lives on.

    The module-level ``_GLOBAL_CPU_MODEL`` is moved in place to CUDA when
    ``torch.cuda.is_available()`` is true, otherwise to CPU.

    Returns:
        tuple: ``(model, device)``, where ``model`` is the shared global instance.
    """
    if torch.cuda.is_available():
        target = torch.device("cuda")
    else:
        target = torch.device("cpu")
    _GLOBAL_CPU_MODEL.to(target)
    return _GLOBAL_CPU_MODEL, target
| # --- FALLBACK CSS INJECTION --- | |
| # We use a CSS override to display a precision crosshair since custom dynamic HTML div overlays | |
| # are deeply rejected by Gradio's internal Canvas shadow properties. | |
| CSS_STYLE = """ | |
| .gradio-image-hook, .gradio-image-hook * { | |
| cursor: crosshair !important; | |
| } | |
| .big-accordion { | |
| border: 2px solid #e5e7eb !important; | |
| } | |
| .big-accordion button, .big-accordion .label-wrap, .big-accordion summary { | |
| font-size: 1.3em !important; | |
| padding: 12px 18px !important; | |
| font-weight: 600 !important; | |
| } | |
| """ | |
# --- INFERENCE STATE AND FORMATTING HELPERS ---
def transfer_inference_params(params, device):
    """Move the tensors cached in ``params.key_value_memory_dict`` to ``device``.

    Each top-level value is handled by type: a tensor is moved directly; a tuple or list is
    rebuilt with its tensor members moved; a dict has every member exposing ``.to`` moved in
    place. Containers are not searched any deeper than that one level.

    Args:
        params: Object with a ``key_value_memory_dict`` attribute, or ``None``.
        device: Target device passed to ``.to``.

    Returns:
        The same ``params`` object (returned untouched when it is ``None`` or has no cache dict).
    """
    cache = None if params is None else getattr(params, "key_value_memory_dict", None)
    if cache is None:
        return params

    def relocate(item):
        # Non-tensor members of a tuple/list are passed through unchanged.
        return item.to(device) if isinstance(item, torch.Tensor) else item

    # Only existing keys are reassigned, so the dict size never changes while iterating.
    for key in cache:
        entry = cache[key]
        if isinstance(entry, torch.Tensor):
            cache[key] = entry.to(device)
        elif isinstance(entry, tuple):
            cache[key] = tuple(map(relocate, entry))
        elif isinstance(entry, list):
            cache[key] = list(map(relocate, entry))
        elif isinstance(entry, dict):
            for inner_key in entry:
                inner_value = entry[inner_key]
                if hasattr(inner_value, "to"):
                    entry[inner_key] = inner_value.to(device)
    return params
def format_seq_len(seq_len):
    """Return the HTML snippet for the "Total Sequenced Patches" counter card.

    Args:
        seq_len: Value shown in large type; interpolated with default f-string formatting.

    Returns:
        str: A single ``<div>`` containing a caption and the number.
    """
    card_style = (
        "text-align: center; border: 1px solid #e5e7eb; border-radius: 8px; "
        "padding: 10px; margin-bottom: 10px; background-color: #f9fafb;"
    )
    caption_style = "font-size: 1.1em; color: #6b7280;"
    number_style = "font-size: 3em; font-weight: bold; color: #3b82f6;"
    return (
        f"<div style='{card_style}'>"
        f"<span style='{caption_style}'>Total Sequenced Patches</span><br>"
        f"<span style='{number_style}'>{seq_len}</span>"
        "</div>"
    )
def _compute_move_embedding(patch_location: torch.Tensor, cur_location: torch.Tensor = None) -> torch.Tensor:
    """Encode the displacement from the previous patch location to the new one.

    Args:
        patch_location: Location tensor of the new patch; its first dimension sizes the
            all-zero displacement used when there is no previous location.
        cur_location: Previous location, or ``None`` for the first patch.

    Returns:
        The result of ``sinusoidal_position_encoding_2d`` applied to the float displacement
        with the second argument fixed at 256.
    """
    if cur_location is not None:
        displacement = (patch_location - cur_location).float()
    else:
        # First patch: there is no previous location, so the move is encoded as (0, 0).
        displacement = torch.zeros(
            (patch_location.shape[0], 2),
            dtype=torch.float32,
            device=patch_location.device,
        )
    return sinusoidal_position_encoding_2d(displacement, 256)
def format_predictions(probs_np, categories=None):
    """Build a ``{display name: probability}`` mapping for the five most likely classes.

    Display names are the text before the first comma of each category, title-cased. When two
    of the selected classes share a display name, the first one seen (the more probable, since
    classes are visited in descending order) is kept; previously the less probable one silently
    overwrote it.

    Args:
        probs_np: 1-D NumPy array of class probabilities.
        categories: Optional sequence of class names indexed by class id. Defaults to the
            module-level ``CATEGORIES``.

    Returns:
        dict[str, float]: At most five entries, in descending order of probability.
    """
    labels = CATEGORIES if categories is None else categories
    # argsort is ascending: take the last five indices and reverse them for best-first order.
    best_first = np.argsort(probs_np)[-5:][::-1]
    result = {}
    for idx in best_first:
        class_name = labels[idx].split(",")[0].title()
        # setdefault keeps the higher probability when display names collide.
        result.setdefault(class_name, float(probs_np[idx]))
    return result
def preprocess_image(image_arr):
    """Centre an image on a square, zero-filled canvas tensor.

    The array is converted to an RGB PIL image, then to a tensor via ``ToTensor``, and pasted
    into the middle of a ``3 x S x S`` float32 canvas where ``S = max(width, height)``.

    Args:
        image_arr: Image array accepted by ``PIL.Image.fromarray``.

    Returns:
        tuple: ``(canvas, x_offset, y_offset, height, width)``. Note that ``x_offset`` is the
        offset along the tensor's row axis and ``y_offset`` the offset along its column axis.
    """
    rgb = Image.fromarray(image_arr).convert("RGB")
    width, height = rgb.size
    pixels = T.ToTensor()(rgb)
    _, rows, cols = pixels.shape

    side = max(width, height)
    x_offset = (side - rows) // 2
    y_offset = (side - cols) // 2

    canvas = torch.zeros(3, side, side, dtype=torch.float32)
    canvas[:, x_offset:x_offset + rows, y_offset:y_offset + cols] = pixels
    return canvas, x_offset, y_offset, height, width
def extract_patch(canvas_tensor, px, py, patch_size=None):
    """Cut one square patch out of the canvas and return it flattened.

    ``px`` indexes the tensor's second dimension and ``py`` its third. Both are clamped so the
    patch stays inside the canvas. If the canvas is smaller than the patch, the missing area
    is zero-filled so the result always has ``channels * patch_size ** 2`` elements
    (previously such a canvas produced a shorter vector).

    Args:
        canvas_tensor: Tensor of shape ``(channels, S, S)``; only ``shape[1]`` is consulted
            for the size, matching the square canvases built by ``preprocess_image``.
        px: Requested start index along dimension 1.
        py: Requested start index along dimension 2.
        patch_size: Patch side length; defaults to the module-level ``PATCH_SIZE``.

    Returns:
        1-D tensor holding the flattened patch.
    """
    size = PATCH_SIZE if patch_size is None else patch_size
    canvas_size = canvas_tensor.shape[1]
    px = max(0, min(px, canvas_size - size))
    py = max(0, min(py, canvas_size - size))
    patch = canvas_tensor[:, px:px + size, py:py + size]
    if patch.shape[1] != size or patch.shape[2] != size:
        # Canvas smaller than one patch: embed what exists in the top-left of a zero patch.
        padded = canvas_tensor.new_zeros((canvas_tensor.shape[0], size, size))
        padded[:, :patch.shape[1], :patch.shape[2]] = patch
        patch = padded
    return patch.flatten()
def draw_patches_on_image(image_arr, positions, x_offset, y_offset, h, w):
    """Render the image washed-out and grey, with every visited patch shown in original colour.

    Args:
        image_arr: Original image array.
        positions: Iterable of ``(px, py)`` canvas coordinates, where ``px`` is the row-axis
            start and ``py`` the column-axis start of a patch.
        x_offset: Row-axis offset of the image inside the canvas.
        y_offset: Column-axis offset of the image inside the canvas.
        h: Unused; kept so existing callers keep working.
        w: Unused; kept so existing callers keep working.

    Returns:
        tuple: ``(display image as a NumPy array, positions)``.
    """
    img = np.array(image_arr)
    orig_pil = Image.fromarray(img)

    # Greyscale, then compress contrast and lift brightness to get the pale background.
    grey_base = orig_pil.convert("L").convert("RGB")
    grey_base_np = (np.array(grey_base).astype(float) * 0.4 + 160).clip(0, 255).astype(np.uint8)
    temp_img = Image.fromarray(grey_base_np)

    for px, py in positions:
        # Canvas coordinates -> image coordinates.
        orig_y = py - y_offset
        orig_x = px - x_offset
        # Box passed to crop/paste is ordered column-first: (orig_y, orig_x, ...).
        box = (int(orig_y), int(orig_x), int(orig_y + PATCH_SIZE), int(orig_x + PATCH_SIZE))
        # Paste original colour into the highlighted region.
        temp_img.paste(orig_pil.crop(box), box)
    return np.array(temp_img), positions
def init_state_for_image(image):
    """Create a fresh session-state dict for ``image`` with no patches visited yet.

    Args:
        image: Image array handed to ``preprocess_image`` and stored as ``original_image``.

    Returns:
        dict: Canvas tensor (on CPU), canvas offsets, image height/width, the original image,
        an empty ``drawn_positions`` list, ``sequence_length`` 0, and ``None`` placeholders for
        ``inference_params`` and ``cur_location``.
    """
    canvas, row_offset, col_offset, height, width = preprocess_image(image)
    return dict(
        inference_params=None,
        cur_location=None,
        canvas_tensor=canvas.cpu(),
        x_offset=row_offset,
        y_offset=col_offset,
        h=height,
        w=width,
        original_image=image,
        drawn_positions=[],
        sequence_length=0,
    )
def run_auto_scan(image, scan_pattern, sequence_length):
    """Generate a scan path over the image, run the model on it, and return the UI updates.

    Args:
        image: Original image array, or ``None`` when nothing has been uploaded.
        scan_pattern: Pattern name forwarded to ``generate_scan_positions``.
        sequence_length: Number of patches to visit (cast to ``int``).

    Returns:
        tuple: ``(display image, top-5 predictions, session state, status text, counter HTML)``
        - always five values, one per output component wired to this handler.
    """
    if image is None:
        # This branch used to return four values although five outputs are wired.
        return None, {"Upload Image": 1.0}, None, "Upload Image", format_seq_len(0)

    import random

    sequence_length = int(sequence_length)
    model, device = get_model()
    state = init_state_for_image(image)

    x_end = max(state['x_offset'] + 1, state['x_offset'] + state['h'])
    y_end = max(state['y_offset'] + 1, state['y_offset'] + state['w'])

    # Fixed seed: the same image and settings always give the same path.
    rng = random.Random(42)
    positions_xy = generate_scan_positions(
        x_start=state['y_offset'], x_stop=y_end,
        y_start=state['x_offset'], y_stop=x_end,
        patch_size=PATCH_SIZE, sequence_length=sequence_length,
        scan_pattern=scan_pattern, rng=rng
    )
    # The scan generator is called with x = columns and y = rows; swap each pair back to
    # (px = row, py = col), the convention used by extract_patch and the stored state.
    positions = [(py, px) for px, py in positions_xy]

    inference_params = InferenceParams(max_seqlen=4000, max_batch_size=1)
    patches_list = []
    moves_list = []
    cur_location = None
    for px, py in positions:
        loc_tensor = torch.tensor([[px, py]], dtype=torch.long, device=device)
        move_emb = _compute_move_embedding(loc_tensor, cur_location)
        cur_location = loc_tensor
        # Patches stay on CPU here and are moved to the device in one transfer below.
        patches_list.append(extract_patch(state['canvas_tensor'], px, py))
        moves_list.append(move_emb.squeeze(0))

    # Add a leading batch dimension of 1 to both sequences.
    img_seq = torch.stack(patches_list, dim=0).unsqueeze(0).to(device)
    move_seq = torch.stack(moves_list, dim=0).unsqueeze(0)

    with torch.no_grad():
        out = model(img_seq, move_seq, inference_params=inference_params)
        # Only the output at the last sequence position is shown.
        final_probs = F.softmax(out[0, -1], dim=-1).cpu().numpy()
    inference_params.seqlen_offset += img_seq.shape[1]

    # Everything kept in the session state is stored on CPU.
    state['cur_location'] = cur_location.cpu()
    state['drawn_positions'] = positions
    state['sequence_length'] = sequence_length
    state['inference_params'] = transfer_inference_params(inference_params, torch.device('cpu'))

    img_display, _ = draw_patches_on_image(
        state['original_image'], state['drawn_positions'],
        state['x_offset'], state['y_offset'], state['h'], state['w']
    )
    return img_display, format_predictions(final_probs), state, f"Auto Scan Complete. Extracted {sequence_length} patches. Click to add more!", format_seq_len(sequence_length)
def process_click_inference(x_orig, y_orig, original_image, state):
    """Feed the patch centred on a click to the model and return the UI updates.

    Args:
        x_orig: Horizontal click coordinate in image pixels.
        y_orig: Vertical click coordinate in image pixels.
        original_image: Original image array, or ``None`` when nothing has been uploaded.
        state: Session-state dict from earlier calls, or ``None``.

    Returns:
        tuple: ``(display image, top-5 predictions, session state, status text, counter HTML)``
        - always five values, one per output component wired to the click handler.
    """
    if original_image is None:
        # This branch used to return four values although five outputs are wired.
        shown = state.get('sequence_length', 0) if state else 0
        return None, {"Upload Image": 1.0}, state, "Upload Image", format_seq_len(shown)

    model, device = get_model()

    # No usable recurrent cache yet (first click, or right after a clear): start over.
    if state is None or state.get('inference_params') is None:
        state = init_state_for_image(original_image)
        state['inference_params'] = InferenceParams(max_seqlen=4000, max_batch_size=1)
    state['inference_params'] = transfer_inference_params(state['inference_params'], device)

    orig_h, orig_w = state['original_image'].shape[:2]
    canvas_size = max(orig_h, orig_w)

    # Image coordinates -> canvas coordinates (x_offset is the row offset, y_offset the column offset).
    canvas_y = int(x_orig) + state['y_offset']
    canvas_x = int(y_orig) + state['x_offset']

    # Centre the patch on the click, clamped so it stays inside the canvas.
    px = max(0, min(int(canvas_x - PATCH_SIZE / 2), canvas_size - PATCH_SIZE))
    py = max(0, min(int(canvas_y - PATCH_SIZE / 2), canvas_size - PATCH_SIZE))

    cur_loc = state['cur_location'].to(device) if state['cur_location'] is not None else None
    loc_tensor = torch.tensor([[px, py]], dtype=torch.long, device=device)
    move_emb = _compute_move_embedding(loc_tensor, cur_loc)
    patch = extract_patch(state['canvas_tensor'], px, py).to(device)

    # A one-step sequence: add batch and sequence dimensions.
    img_seq = patch.unsqueeze(0).unsqueeze(0)
    move_seq = move_emb.unsqueeze(0)

    with torch.no_grad():
        out = model(img_seq, move_seq, inference_params=state['inference_params'])
        final_probs = F.softmax(out[0, -1], dim=-1).cpu().numpy()
    state['inference_params'].seqlen_offset += 1

    state['cur_location'] = loc_tensor.cpu()
    state['drawn_positions'].append((px, py))
    state['sequence_length'] += 1
    # Park the cache on CPU between requests.
    state['inference_params'] = transfer_inference_params(state['inference_params'], torch.device('cpu'))

    img_display, _ = draw_patches_on_image(
        state['original_image'], state['drawn_positions'],
        state['x_offset'], state['y_offset'], state['h'], state['w']
    )
    # NOTE(review): the leading "π" looks like a mis-decoded emoji - confirm against the deployed file.
    status_msg = f"π Revealed patch #{state['sequence_length']}! The model is analyzing... Keep clicking to give it more clues!"
    return img_display, format_predictions(final_probs), state, status_msg, format_seq_len(state['sequence_length'])
def on_click(evt: gr.SelectData, original_image, state):
    """Select-event handler: unpack the click position and run one inference step on it.

    Args:
        evt: Select event whose ``index`` holds the two click coordinates.
        original_image: Original image array held in session state.
        state: Current session-state dict, or ``None``.

    Returns:
        Whatever ``process_click_inference`` returns for this click.
    """
    click_x, click_y = evt.index
    return process_click_inference(click_x, click_y, original_image, state)
def on_upload(image):
    """Upload handler: show the greyed-out image and reset the session.

    Args:
        image: Uploaded image array, or ``None``.

    Returns:
        tuple: ``(canvas image, original image, label dict, session state, status text,
        counter HTML)`` - six values, one per wired output. The session state is always
        ``None`` so the next click starts a fresh sequence.
    """
    if image is None:
        # The counter is an HTML component: send the formatted card rather than a bare 0.
        return None, None, {"Waiting...": 1.0}, None, "Upload Image", format_seq_len(0)
    # Pre-render the grey background immediately on upload
    grey_base = Image.fromarray(image).convert("L").convert("RGB")
    grey_base_np = (np.array(grey_base).astype(float) * 0.4 + 160).clip(0, 255).astype(np.uint8)
    # NOTE(review): the leading "β¨" looks like a mis-decoded emoji - confirm against the deployed file.
    return grey_base_np, image, {"Click an interesting object in the photo": 1.0}, None, "β¨ Image loaded! The model is currently blind. **Click anywhere on the grey canvas** to reveal the first patch and let the model guess!", format_seq_len(0)
def on_clear(original_image):
    """Clear handler: drop all revealed patches and return to the greyed-out image.

    Args:
        original_image: Original image array held in session state, or ``None``.

    Returns:
        tuple: ``(canvas image, label dict, session state, status text, counter HTML)`` -
        five values, one per wired output.
    """
    if original_image is None:
        # The counter is an HTML component: send the formatted card rather than a bare 0.
        return None, {"Cleared": 1.0}, None, "Cleared", format_seq_len(0)
    grey_base = Image.fromarray(original_image).convert("L").convert("RGB")
    grey_base_np = (np.array(grey_base).astype(float) * 0.4 + 160).clip(0, 255).astype(np.uint8)
    # NOTE(review): the leading "π§Ή" looks like a mis-decoded emoji - confirm against the deployed file.
    return grey_base_np, {"Cleared": 1.0}, init_state_for_image(original_image), "π§Ή Selections cleared! The canvas is blank. Where will you click next?", format_seq_len(0)
# Layout and event wiring for the demo page. Components are created in the same order as
# before, because creation order inside the context managers determines the page layout.
# NOTE(review): several labels start with characters such as "π", "β¨" or "β’" that look like
# mis-decoded emoji/bullets - confirm against the deployed file before changing them.
with gr.Blocks(title="MambaEye Interactive Demo") as demo:
    gr.Markdown(
        "# MambaEye Interactive Inference Demo\n"
        "**π [Project Page](https://usingcolor.github.io/MambaEye) β’ π» [GitHub Repository](https://github.com/usingcolor/MambaEye)**\n\n"
        "This interface incorporates the full **MambaEye-base-ft** model natively.\n\n"
        "**Note**: The first inference or Auto Scan may take **1~2 minutes** to compile CUDA kernels and build hardware cache. Subsequent patch clicks will be dramatically faster!"
    )

    # Per-session values: the inference state dict and the untouched uploaded image.
    state = gr.State(None)
    original_image_state = gr.State(None)

    # Created up front with render=False so they can be placed in the right-hand column later.
    seq_len_display = gr.HTML(value=format_seq_len(0), render=False)
    model_output_label = gr.Label(label="MambaEye Output Predictions", num_top_classes=5, render=False)
    status_text = gr.Markdown("Status: Waiting for image upload...", render=False)

    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("### π― Challenge: See how few clicks the model needs to guess your image!\nClick directly on the most informative parts of the grey image to reveal patches to the model.")
            input_image = gr.Image(
                type="numpy",
                label="π Interactive Canvas: Click here to extract patches!",
                interactive=True,
                elem_classes="gradio-image-hook",
            )
            clear_btn = gr.Button("ποΈ Clear Selections & Start Over", variant="secondary")

            with gr.Accordion("π€ Advanced: Auto-Scan Features", open=False, elem_classes="big-accordion"):
                gr.Markdown("### β¨ Let the model automatically scan a sequence of patches!")
                with gr.Row():
                    scan_pattern = gr.Dropdown(
                        choices=[
                            "random",
                            "spiral",
                            "diagonal",
                            "golden",
                            "horizontal_raster",
                            "horizontal_zigzag",
                            "column_major",
                            "column_snake",
                        ],
                        value="random",
                        label="Scan Pattern",
                    )
                    seq_length = gr.Slider(minimum=1, maximum=4096, step=1, value=256, label="Auto Sequence Length")
                auto_btn = gr.Button("Auto Generate Path & Infer", variant="primary")

            # Picking an example runs the same handler as a manual upload.
            gr.Examples(
                examples=["assets/dog.jpg", "assets/leo.jpg", "assets/green_mamba.jpg"],
                inputs=input_image,
                outputs=[input_image, original_image_state, model_output_label, state, status_text, seq_len_display],
                fn=on_upload,
                run_on_click=True,
                cache_examples=False,
                label="Try an Example Image",
            )

        with gr.Column(scale=1):
            seq_len_display.render()
            model_output_label.render()
            status_text.render()

    # The upload handler also refreshes the stored original image (six outputs); every other
    # handler updates the same five components.
    upload_outputs = [input_image, original_image_state, model_output_label, state, status_text, seq_len_display]
    inference_outputs = [input_image, model_output_label, state, status_text, seq_len_display]

    input_image.upload(fn=on_upload, inputs=[input_image], outputs=upload_outputs)
    auto_btn.click(
        fn=run_auto_scan,
        inputs=[original_image_state, scan_pattern, seq_length],
        outputs=inference_outputs,
    )
    input_image.select(fn=on_click, inputs=[original_image_state, state], outputs=inference_outputs)
    clear_btn.click(fn=on_clear, inputs=[original_image_state], outputs=inference_outputs)
# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    launch_options = dict(
        theme=gr.themes.Soft(),
        ssr_mode=False,
        css=CSS_STYLE,
    )
    demo.launch(**launch_options)