| from __future__ import annotations |
|
|
| import argparse |
| import sys |
| from threading import Lock |
| from pathlib import Path |
|
|
| from PIL import Image |
|
|
# Put the directory one level above this file's folder at the front of
# sys.path (presumably the repository root -- confirm) so the `cropvlm`
# imports that follow resolve when this file is run directly as a script.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
|
| from cropvlm import CROP_CLASSES, load_cropvlm |
| from cropvlm.model import parse_class_names |
|
|
|
|
# Default content of the "Class names" textbox: one CROP_CLASSES entry per line.
DEFAULT_CLASSES_TEXT = "\n".join(CROP_CLASSES)
|
|
|
|
def build_demo(checkpoint: str, device: str | None, prompt_template: str, top_k: int) -> gr.Blocks:
    """Build the Gradio UI for zero-shot image classification.

    The classifier is loaded once through ``load_cropvlm`` and shared by all
    requests; access to it is serialized with a lock because the set of class
    names it scores against is mutable state that each request may replace.

    Args:
        checkpoint: Checkpoint location forwarded to ``load_cropvlm``.
        device: Device forwarded to ``load_cropvlm`` (``None`` is passed
            through unchanged).
        prompt_template: Prompt template forwarded to ``load_cropvlm``.
        top_k: Initial value of the "Top-k" slider. Values below 1 are raised
            to 1; values above 10 extend the slider's maximum.

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """
    # Imported here rather than at module level, so importing this module
    # does not itself require gradio.
    import gradio as gr

    classifier = load_cropvlm(
        checkpoint=checkpoint,
        class_names=CROP_CLASSES,
        device=device,
        prompt_template=prompt_template,
    )
    classifier_lock = Lock()
    # Class names the classifier is currently configured with; compared
    # against each request so set_classes() only runs when the list changed.
    current_classes = tuple(CROP_CLASSES)

    # Keep the slider's initial value inside its own bounds even if the
    # caller passes an out-of-range top_k.
    initial_top_k = max(1, int(top_k))
    max_top_k = max(10, initial_top_k)

    def classify(image: Image.Image, classes_text: str, top_k_value: int):
        """Classify ``image`` against the class names listed in ``classes_text``.

        Returns a ``(label_scores, score_text)`` pair for the Label and
        Textbox outputs: a ``{label: probability}`` dict and a ranked,
        human-readable listing. Both are empty when there is no image or no
        class name.
        """
        nonlocal current_classes

        # The second output is a Textbox, so its "empty" value is an empty
        # string (same value the classes.change handler uses to clear it).
        if image is None:
            return {}, ""
        requested_classes = tuple(parse_class_names(classes_text))
        if not requested_classes:
            return {}, ""

        # There cannot be more predictions than classes. Whether
        # predict_with_scores tolerates a larger top_k is not visible here,
        # so cap it defensively.
        k = max(1, min(int(top_k_value), len(requested_classes)))

        # Swapping classes and predicting must happen atomically, otherwise a
        # concurrent request could change the classes between the two calls.
        with classifier_lock:
            if requested_classes != current_classes:
                classifier.set_classes(requested_classes)
                current_classes = requested_classes
            predictions = classifier.predict_with_scores(image, top_k=k)

        label_scores = {label: probability for label, probability, _ in predictions}
        score_text = "\n".join(
            f"{rank}. {label}: probability={probability:.6f}, cosine={cosine:.6f}"
            for rank, (label, probability, cosine) in enumerate(predictions, start=1)
        )
        return label_scores, score_text

    # Only offer the example images that are actually present on disk.
    examples_dir = Path(__file__).resolve().parents[1] / "examples"
    example_paths = [
        str(examples_dir / name)
        for name in ["cacao.png", "olive.png", "cauliflower.png", "sugarcane.png", "sunflower.png"]
        if (examples_dir / name).exists()
    ]

    with gr.Blocks(title="CropVLM Zero-Shot Demo") as demo:
        gr.Markdown("# CropVLM Zero-Shot Image Classification")
        with gr.Row():
            with gr.Column():
                image = gr.Image(type="pil", label="Image")
                classes = gr.Textbox(
                    value=DEFAULT_CLASSES_TEXT,
                    lines=12,
                    label="Class names",
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=max_top_k,
                    value=initial_top_k,
                    step=1,
                    label="Top-k",
                )
                button = gr.Button("Classify", variant="primary")
            with gr.Column():
                # classify() already limits the result to the slider value,
                # so the label's own cap is set to the slider maximum; a cap
                # fixed at the initial top_k would hide rows whenever the
                # user raises the slider.
                label = gr.Label(num_top_classes=max_top_k, label="Predictions")
                score_text = gr.Textbox(
                    label="Scores",
                    lines=8,
                    interactive=False,
                )

        outputs = [label, score_text]
        button.click(classify, inputs=[image, classes, top_k_slider], outputs=outputs)
        # Results computed for the previous class list are stale once the
        # list is edited, so clear them.
        classes.change(lambda: ({}, ""), outputs=outputs)

        if example_paths:
            gr.Examples(
                examples=[[path, DEFAULT_CLASSES_TEXT, initial_top_k] for path in example_paths],
                inputs=[image, classes, top_k_slider],
                outputs=outputs,
                fn=classify,
                cache_examples=False,
            )

    return demo
|
|
|
|
def main():
    """Parse the command-line options, build the demo and start its server."""
    # (flag, add_argument keyword arguments) for every supported option.
    option_specs = (
        ("--checkpoint", {"default": "models/CropVLM.pth"}),
        ("--device", {"default": None}),
        ("--prompt-template", {"default": "{}"}),
        ("--top-k", {"type": int, "default": 5}),
        ("--server-name", {"default": "127.0.0.1"}),
        ("--server-port", {"type": int, "default": 7860}),
    )
    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    options = cli.parse_args()

    # Model/UI options go to build_demo; network options go to launch().
    app = build_demo(
        checkpoint=options.checkpoint,
        device=options.device,
        prompt_template=options.prompt_template,
        top_k=options.top_k,
    )
    host, port = options.server_name, options.server_port
    app.launch(server_name=host, server_port=port)
|
|
|
|
# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|