# app.py — Florence-2 image-captioning Gradio demo
# (uploaded by James040; commit e558c0a, verified)
import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
# Run all inference on CPU (no GPU assumed in this deployment).
device = "cpu"
def load_vlm(model_name):
    """Fetch a Florence-2 checkpoint and its processor from the Hugging Face hub.

    Best-effort loader: on any failure the error is printed and
    (None, None) is returned so the app can still start with the
    remaining model.

    Args:
        model_name: Repository suffix under the ``microsoft/`` org,
            e.g. ``'Florence-2-base'``.

    Returns:
        A ``(model, processor)`` pair, or ``(None, None)`` if loading failed.
    """
    repo_id = f'microsoft/{model_name}'
    try:
        print(f"Loading {model_name}...")
        vlm = AutoModelForCausalLM.from_pretrained(
            repo_id,
            trust_remote_code=True,
        )
        vlm = vlm.to(device).eval()
        proc = AutoProcessor.from_pretrained(
            repo_id,
            trust_remote_code=True,
        )
        return vlm, proc
    except Exception as err:
        print(f"Error loading {model_name}: {err}")
        return None, None
# Load both models once at startup; a failed load leaves (None, None)
# placeholders, which describe_image reports to the user at request time.
model_base, proc_base = load_vlm('Florence-2-base')
model_large, proc_large = load_vlm('Florence-2-large')
def describe_image(uploaded_image, model_choice):
    """Generate a detailed caption for an image with the chosen Florence-2 model.

    Args:
        uploaded_image: PIL image (or numpy array) from the Gradio input;
            ``None`` when nothing was uploaded.
        model_choice: One of "Florence-2-base" or "Florence-2-large".

    Returns:
        The generated caption string, or a human-readable error message
        when no image was provided or the selected model failed to load.
    """
    if uploaded_image is None:
        return "Please upload an image."
    # Select model based on UI choice
    if model_choice == "Florence-2-base":
        model, processor = model_base, proc_base
    else:
        model, processor = model_large, proc_large
    if model is None:
        return f"{model_choice} failed to load."
    if not isinstance(uploaded_image, Image.Image):
        uploaded_image = Image.fromarray(uploaded_image)
    # Fix: normalize RGBA/grayscale/palette uploads — the Florence-2 image
    # processor expects 3-channel RGB input.
    if uploaded_image.mode != "RGB":
        uploaded_image = uploaded_image.convert("RGB")
    # Core generation logic
    inputs = processor(text="<MORE_DETAILED_CAPTION>", images=uploaded_image, return_tensors="pt").to(device)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            num_beams=3,
            do_sample=False,
        )
    # Keep special tokens: post_process_generation parses them to extract
    # the task-specific output.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    result = processor.post_process_generation(
        generated_text,
        task="<MORE_DETAILED_CAPTION>",
        image_size=(uploaded_image.width, uploaded_image.height)
    )
    return result["<MORE_DETAILED_CAPTION>"]
# Simplified Gradio Layout
css = ".submit-btn { background-color: #4682B4 !important; color: white !important; }"
with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
    gr.Markdown("# **Florence-2 Models Image Captions**")
    gr.Markdown("> Select the model to use. **Base** is faster; **Large** is more accurate.")
    with gr.Row():
        with gr.Column():
            # Left column: image input, model selector, and the action button.
            image_input = gr.Image(label="Upload Image", type="pil")
            model_choice = gr.Radio(
                choices=["Florence-2-base", "Florence-2-large"],
                label="Model Choice",
                value="Florence-2-base"
            )
            generate_btn = gr.Button("Generate Caption", elem_classes="submit-btn")
        with gr.Column():
            # Right column: caption output (editable so users can copy/tweak).
            output = gr.Textbox(label="Generated Caption", lines=6, interactive=True)
    # Wire the button click to the captioning function.
    generate_btn.click(
        fn=describe_image,
        inputs=[image_input, model_choice],
        outputs=output
    )
if __name__ == "__main__":
    # ssr_mode=False disables Gradio's server-side rendering mode.
    demo.launch(ssr_mode=False)