"""Gradio app that captions uploaded images with Microsoft Florence-2 models on CPU."""

import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM

# Use CPU as requested
device = "cpu"

# Florence-2 chooses what to do from a special task token in the text prompt.
# An empty prompt gives the model no task; this token requests a detailed caption.
# The same token is passed to post_process_generation and keys its result dict.
TASK_PROMPT = "<MORE_DETAILED_CAPTION>"


def load_vlm(model_name):
    """Load a Florence-2 model and its processor from the ``microsoft/`` hub namespace.

    Args:
        model_name: Repository name under ``microsoft/`` (e.g. ``"Florence-2-base"``).

    Returns:
        A ``(model, processor)`` tuple with the model moved to ``device`` and put in
        eval mode, or ``(None, None)`` if loading raised. Failure is deliberately
        non-fatal so the UI can still start and report which model is unavailable.
    """
    repo_id = f"microsoft/{model_name}"
    try:
        print(f"Loading {model_name}...")
        # trust_remote_code: Florence-2 ships its modeling/processing code in the repo.
        model = AutoModelForCausalLM.from_pretrained(
            repo_id, trust_remote_code=True
        ).to(device).eval()
        processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
        return model, processor
    except Exception as e:
        print(f"Error loading {model_name}: {e}")
        return None, None


# Load both models once at import time so individual requests don't pay that cost.
model_base, proc_base = load_vlm("Florence-2-base")
model_large, proc_large = load_vlm("Florence-2-large")


def _as_rgb(image):
    """Return *image* as an RGB PIL image (accepts a PIL image or an array)."""
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Uploads can be RGBA, palette ("P") or grayscale; normalize to 3 channels
    # before handing the image to the processor.
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image


def describe_image(uploaded_image, model_choice):
    """Generate a caption for an uploaded image with the selected Florence-2 model.

    Args:
        uploaded_image: The image from the Gradio component (PIL image, or an
            array that ``Image.fromarray`` accepts), or ``None`` if nothing was uploaded.
        model_choice: ``"Florence-2-base"`` selects the base model; any other value
            selects the large model.

    Returns:
        The caption text, or a user-facing message if no image was given or the
        selected model failed to load.
    """
    if uploaded_image is None:
        return "Please upload an image."

    # Select model based on UI choice
    if model_choice == "Florence-2-base":
        model, processor = model_base, proc_base
    else:
        model, processor = model_large, proc_large

    if model is None:
        return f"{model_choice} failed to load."

    image = _as_rgb(uploaded_image)

    # Core generation logic
    inputs = processor(text=TASK_PROMPT, images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            num_beams=3,
            do_sample=False,
        )

    # Decode without skipping special tokens: Florence-2's post-processor is
    # designed to receive the raw decoded text.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    result = processor.post_process_generation(
        generated_text,
        task=TASK_PROMPT,
        image_size=(image.width, image.height),
    )
    # The post-processed result is a dict keyed by the task token.
    return result[TASK_PROMPT]


# Simplified Gradio Layout
css = ".submit-btn { background-color: #4682B4 !important; color: white !important; }"

with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
    gr.Markdown("# **Florence-2 Models Image Captions**")
    gr.Markdown(
        "> Select the model to use. **Base** is faster; **Large** is more accurate."
    )

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload Image", type="pil")
            model_choice = gr.Radio(
                choices=["Florence-2-base", "Florence-2-large"],
                label="Model Choice",
                value="Florence-2-base",
            )
            generate_btn = gr.Button("Generate Caption", elem_classes="submit-btn")

        with gr.Column():
            output = gr.Textbox(label="Generated Caption", lines=6, interactive=True)

    generate_btn.click(
        fn=describe_image,
        inputs=[image_input, model_choice],
        outputs=output,
    )

if __name__ == "__main__":
    demo.launch(ssr_mode=False)