"""Gradio demo that classifies an uploaded image with a pretrained BEiT model."""

import json

import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

# Load the model and its matching feature extractor.
model_name = "microsoft/beit-base-patch16-224"
model = AutoModelForImageClassification.from_pretrained(model_name)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)

# Class-index -> label mapping.  It is taken from the checkpoint's own config
# so the indices are guaranteed to line up with the classifier's logits; an
# externally downloaded ImageNet-21k map does not correspond to this head and
# would also make start-up depend on an extra network request.
# Keys are stored as strings so lookups by ``str(index)`` keep working.
imagenet_classes = {
    str(idx): label for idx, label in model.config.id2label.items()
}


def classify_image(image):
    """Classify *image* and return a human-readable prediction string.

    Args:
        image: A ``PIL.Image.Image`` or a NumPy array holding the picture,
            or ``None`` when no image was supplied.

    Returns:
        ``"Predicted class: <name> (ID: <index>)"`` for the top-scoring
        class, or a short notice when no image was supplied.
    """
    # The UI can submit without an image; fail softly instead of raising
    # inside the feature extractor.
    if image is None:
        return "No image provided."

    # Normalise PIL input to a 3-channel array so grayscale / RGBA / palette
    # images are accepted as well.
    if isinstance(image, Image.Image):
        image = np.array(image.convert("RGB"))

    # Feature extraction (resize / normalise / batch into tensors).
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()

    # Resolve the class name, falling back to "Unknown" for unmapped indices.
    class_name = imagenet_classes.get(str(predicted_class_idx), "Unknown")
    return f"Predicted class: {class_name} (ID: {predicted_class_idx})"


# Build the Gradio interface.
demo = gr.Interface(
    fn=classify_image,
    inputs="image",
    outputs="text",
    title="Image Classification Demo",
)

# Only start the web server when run as a script, so importing this module
# (e.g. to reuse ``classify_image``) does not block on ``launch()``.
if __name__ == "__main__":
    demo.launch()