# CropVLM / scripts / gradio_demo.py
# Last change by boudiafA: "Update links and checkpoint path"
# Commit fe65c5f (verified)
from __future__ import annotations
import argparse
import sys
from threading import Lock
from pathlib import Path
from PIL import Image
# Put the directory two levels above this file at the front of sys.path before
# the project imports below run (presumably the repository root, so that
# `cropvlm` resolves when the script is executed directly -- confirm layout).
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from cropvlm import CROP_CLASSES, load_cropvlm
from cropvlm.model import parse_class_names
# Default content of the "Class names" textbox: one CROP_CLASSES entry per line.
DEFAULT_CLASSES_TEXT = "\n".join(CROP_CLASSES)
def build_demo(checkpoint: str, device: str | None, prompt_template: str, top_k: int) -> gr.Blocks:
    """Build the Gradio Blocks app for the CropVLM zero-shot demo.

    Args:
        checkpoint: Checkpoint path, forwarded to ``load_cropvlm``.
        device: Device argument, forwarded unchanged to ``load_cropvlm``.
        prompt_template: Prompt template, forwarded to ``load_cropvlm``.
        top_k: Initial value of the "Top-k" slider. Values below 1 are raised
            to 1; values above 10 extend the slider's maximum.

    Returns:
        The assembled ``gr.Blocks`` app. It is not launched here.
    """
    # Imported here rather than at module level, as in the original script.
    import gradio as gr

    classifier = load_cropvlm(
        checkpoint=checkpoint,
        class_names=CROP_CLASSES,
        device=device,
        prompt_template=prompt_template,
    )
    # A single classifier instance serves every request; the lock keeps the
    # "switch class list, then predict" sequence atomic per request.
    classifier_lock = Lock()
    current_classes = tuple(CROP_CLASSES)

    # Keep the slider's initial value inside its own range even when the
    # caller passes an unusual top_k.
    initial_top_k = max(1, top_k)
    slider_max = max(10, initial_top_k)

    def classify(image: Image.Image, classes_text: str, top_k_value: int):
        """Classify ``image`` against the class names listed in ``classes_text``.

        Returns a ``(label -> probability dict, formatted score text)`` pair.
        Both are empty when no image is given (``image`` may be ``None``) or
        when no class names could be parsed.
        """
        nonlocal current_classes

        # The second output feeds a Textbox, so the "nothing to show" value
        # must be an empty string (a list would be rendered as "[]").
        if image is None:
            return {}, ""
        requested_classes = tuple(parse_class_names(classes_text))
        if not requested_classes:
            return {}, ""

        # Never ask for more results than there are classes. Defensive: the
        # behaviour of predict_with_scores for an oversized top_k is not
        # visible from this file.
        k = max(1, min(int(top_k_value), len(requested_classes)))

        with classifier_lock:
            # Only re-configure the classifier when the class list changed.
            if requested_classes != current_classes:
                classifier.set_classes(requested_classes)
                current_classes = requested_classes
            predictions = classifier.predict_with_scores(image, top_k=k)

        label_scores = {label: probability for label, probability, _ in predictions}
        score_lines = "\n".join(
            f"{rank}. {label}: probability={probability:.6f}, cosine={cosine:.6f}"
            for rank, (label, probability, cosine) in enumerate(predictions, start=1)
        )
        return label_scores, score_lines

    # Offer only the bundled example images that actually exist on disk.
    examples_dir = Path(__file__).resolve().parents[1] / "examples"
    example_paths = [
        str(examples_dir / name)
        for name in ["cacao.png", "olive.png", "cauliflower.png", "sugarcane.png", "sunflower.png"]
        if (examples_dir / name).exists()
    ]

    with gr.Blocks(title="CropVLM Zero-Shot Demo") as demo:
        gr.Markdown("# CropVLM Zero-Shot Image Classification")
        with gr.Row():
            with gr.Column():
                image = gr.Image(type="pil", label="Image")
                classes = gr.Textbox(
                    value=DEFAULT_CLASSES_TEXT,
                    lines=12,
                    label="Class names",
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=slider_max,
                    value=initial_top_k,
                    step=1,
                    label="Top-k",
                )
                button = gr.Button("Classify", variant="primary")
            with gr.Column():
                # Size the label for the largest k the slider allows, so that
                # raising the slider at runtime is not truncated to the
                # start-up top_k; classify() already limits the entries.
                label = gr.Label(num_top_classes=slider_max, label="Predictions")
                score_text = gr.Textbox(
                    label="Scores",
                    lines=8,
                    interactive=False,
                )
        outputs = [label, score_text]
        button.click(classify, inputs=[image, classes, top_k_slider], outputs=outputs)
        # Editing the class list invalidates whatever is currently displayed.
        classes.change(lambda: ({}, ""), outputs=outputs)
        if example_paths:
            gr.Examples(
                examples=[[path, DEFAULT_CLASSES_TEXT, initial_top_k] for path in example_paths],
                inputs=[image, classes, top_k_slider],
                outputs=outputs,
                fn=classify,
                cache_examples=False,
            )
    return demo
def main():
    """Parse the command-line options, build the demo and start its server."""
    cli = argparse.ArgumentParser()
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ("--checkpoint", {"default": "models/CropVLM.pth"}),
        ("--device", {"default": None}),
        ("--prompt-template", {"default": "{}"}),
        ("--top-k", {"type": int, "default": 5}),
        ("--server-name", {"default": "127.0.0.1"}),
        ("--server-port", {"type": int, "default": 7860}),
    )
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)
    opts = cli.parse_args()

    app = build_demo(
        checkpoint=opts.checkpoint,
        device=opts.device,
        prompt_template=opts.prompt_template,
        top_k=opts.top_k,
    )
    app.launch(server_name=opts.server_name, server_port=opts.server_port)


# Run the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()