| from typing import List |
|
|
| import gradio as gr |
| import numpy as np |
| import torch |
| from transformers import CLIPProcessor, CLIPModel |
|
|
# Newline-delimited file of candidate class names (one label per line),
# loaded once at startup and used as the zero-shot label set.
IMAGENET_CLASSES_FILE = "imagenet-classes.txt"
# Sample images surfaced in the Gradio Examples widget; these files are
# expected to sit next to this script.
EXAMPLES = ["dog.jpeg", "car.png"]


# Markdown header rendered at the top of the demo page.
MARKDOWN = """
# Zero-Shot Image Classification with MetaCLIP

This is the demo for a zero-shot image classification model based on
[MetaCLIP](https://github.com/facebookresearch/MetaCLIP), described in the paper
[Demystifying CLIP Data](https://arxiv.org/abs/2309.16671) that formalizes CLIP data
curation as a simple algorithm.
"""
|
|
|
|
def load_text_lines(file_path: str) -> List[str]:
    """Read a text file and return its lines with trailing whitespace removed.

    Args:
        file_path: Path to the text file to read.

    Returns:
        List of lines in file order, each stripped of its trailing
        newline/whitespace.
    """
    # Explicit UTF-8 avoids the platform-dependent default encoding
    # (e.g. cp1252 on Windows), which could corrupt non-ASCII labels.
    with open(file_path, 'r', encoding='utf-8') as file:
        # Iterate the file object directly instead of materializing
        # readlines() first.
        return [line.rstrip() for line in file]
|
|
|
|
# Prefer GPU when available; inputs are moved to the same device in
# classify_image so model and tensors always match.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# MetaCLIP checkpoint (ViT-B/32, "400m" variant per the model id) and its
# matching preprocessor, fetched from the HuggingFace Hub.
model = CLIPModel.from_pretrained("facebook/metaclip-b32-400m").to(device)
processor = CLIPProcessor.from_pretrained("facebook/metaclip-b32-400m")
# Candidate labels for zero-shot classification, one class name per line.
imagenet_classes = load_text_lines(IMAGENET_CLASSES_FILE)
|
|
|
|
def classify_image(input_image) -> str:
    """Zero-shot classify an image against the ImageNet class names.

    Args:
        input_image: RGB PIL image supplied by the Gradio image widget.

    Returns:
        The class name whose text embedding best matches the image.
    """
    inputs = processor(
        text=imagenet_classes,
        images=input_image,
        return_tensors="pt",
        padding=True).to(device)
    # Inference only: disable autograd so no gradient graph is built,
    # saving memory and compute per request.
    with torch.no_grad():
        outputs = model(**inputs)
    probs = outputs.logits_per_image.softmax(dim=1)
    # Take the argmax directly on the tensor — no need to round-trip
    # through NumPy via detach().cpu().numpy().
    class_index = int(probs.argmax(dim=1).item())
    return imagenet_classes[class_index]
|
|
|
|
# Build the Gradio UI. NOTE: component creation order inside the context
# managers determines the page layout, so the nesting below is significant.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        # Input image is forced to RGB and delivered as a PIL image,
        # matching what classify_image passes to the CLIP processor.
        image = gr.Image(image_mode='RGB', type='pil')
        output_text = gr.Textbox(label="Output")
    submit_button = gr.Button("Submit")


    # Wire the button: run classify_image on the uploaded image and show
    # the predicted class name in the textbox.
    submit_button.click(classify_image, inputs=[image], outputs=output_text)


    # Clickable example images; cache_examples precomputes their outputs
    # at startup so clicking an example is instant.
    gr.Examples(
        examples=EXAMPLES,
        fn=classify_image,
        inputs=[image],
        outputs=[output_text],
        cache_examples=True,
        run_on_click=True
    )


# Start the web server (blocking call; debug=False suppresses verbose logs).
demo.launch(debug=False)
|
|