| """ |
| Gradio App for Bird Species Classification |
| Deployed on Hugging Face Spaces |
| """ |
|
|
| import gradio as gr |
| import torch |
| import torch.nn as nn |
| from torchvision import transforms |
| from torchvision.models import convnext_base |
| from PIL import Image |
| import json |
|
|
| |
# Class names, device selection, the model factory and the trained model are
# each set up exactly once, further below (after the Hugging Face Hub import).
|
|
| |
| import gradio as gr |
| import torch |
| import torch.nn as nn |
| from torchvision import transforms |
| from torchvision.models import convnext_base |
| from PIL import Image |
| import json |
| from huggingface_hub import hf_hub_download |
|
|
| |
# Species display names indexed by the model's output class id. The JSON is
# presumably an object keyed by stringified class indices ("0".."199") —
# confirm against class_names.json. An explicit UTF-8 encoding makes decoding
# independent of the host locale.
with open('class_names.json', 'r', encoding='utf-8') as f:
    class_names = json.load(f)
|
|
| |
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
|
|
| |
def create_model(num_classes=200):
    """Build the ConvNeXt-Base network with the custom classification head.

    The backbone is created without pretrained weights (``weights=None``), and
    its stock classifier is swapped for the same head used in training. That
    way a training checkpoint's state dict can be loaded into it.

    Args:
        num_classes: Number of logits produced by the final linear layer.

    Returns:
        nn.Module: The assembled model; no checkpoint weights are loaded here.
    """
    backbone = convnext_base(weights=None)

    # Width of the features feeding the stock head's final Linear layer.
    in_features = backbone.classifier[2].in_features

    # Keep this exact order: nn.Sequential state_dict keys are positional
    # (classifier.1 = LayerNorm, classifier.3 / classifier.6 = Linear), so any
    # reordering would stop the training checkpoint from loading.
    head_layers = [
        nn.Flatten(1),
        nn.LayerNorm((in_features,)),
        nn.Dropout(0.6),
        nn.Linear(in_features, 512),
        nn.GELU(),
        nn.Dropout(0.5),
        nn.Linear(512, num_classes),
    ]
    backbone.classifier = nn.Sequential(*head_layers)
    return backbone
|
|
| |
# Where the trained weights live on the Hugging Face Model Hub.
HF_REPO_ID = "AshProg/bird-classifier-convnext"
HF_WEIGHTS_FILENAME = "final_model.pth"

print("Downloading model from Hugging Face Model Hub...")
# hf_hub_download hands back a local filesystem path to the fetched file.
model_path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_WEIGHTS_FILENAME)
|
|
| |
# Rebuild the architecture, then load the trained weights into it.
model = create_model(num_classes=200)

# SECURITY NOTE(review): torch.load unpickles the file. On PyTorch versions
# where `weights_only` defaults to False, a malicious checkpoint can execute
# arbitrary code, so only load from a trusted repo. Consider passing
# weights_only=True once the checkpoint contents are confirmed compatible.
checkpoint = torch.load(model_path, map_location=device)

if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    # Training-style checkpoint: weights are nested under 'model_state_dict',
    # optionally alongside the recorded validation accuracy.
    model.load_state_dict(checkpoint['model_state_dict'])
    val_acc = checkpoint.get('val_acc')
    if val_acc is not None:
        print(f"Model loaded! Validation accuracy: {val_acc:.2f}%")
    else:
        # Still confirm success when no accuracy was recorded.
        print("Model loaded!")
else:
    # Otherwise the file is assumed to be a bare state dict.
    model.load_state_dict(checkpoint)
    print("Model loaded!")

# Move the weights to the chosen device and switch to evaluation mode
# (disables the Dropout layers in the classification head).
model = model.to(device)
model.eval()
|
|
| |
# Standard ImageNet per-channel statistics, in RGB order.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Per-image preprocessing, in three steps:
# 1. Resize to exactly 224x224 (aspect ratio is not preserved).
# 2. Convert to a float tensor in [0, 1].
# 3. Normalise each channel.
# Presumably mirrors the evaluation transform used during training — confirm.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
|
|
def _species_name(class_id):
    """Return the display name for a class index, or ``"Class <id>"`` if unknown."""
    if isinstance(class_names, dict):
        return class_names.get(str(class_id), f"Class {class_id}")
    # Also tolerate a plain JSON list of names ordered by class index.
    if 0 <= class_id < len(class_names):
        return class_names[class_id]
    return f"Class {class_id}"


def predict(image):
    """
    Make prediction on uploaded image

    Args:
        image: PIL Image, or None (e.g. when the form is submitted without
            an upload).

    Returns:
        dict: Up to five species names mapped to their softmax probability,
        most likely first. Empty when no image was provided.
    """
    if image is None:
        return {}

    # Uploads can be RGBA (PNG with alpha), palette or greyscale, but the
    # Normalize step in `transform` expects exactly three channels.
    image = image.convert("RGB")

    # Add a leading batch dimension before running the model.
    img_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(img_tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)

    # topk fails if asked for more entries than there are classes.
    k = min(5, probabilities.shape[1])
    top_prob, top_idx = torch.topk(probabilities, k)

    # topk returns results sorted by descending probability; dicts keep
    # insertion order, so the result stays ranked.
    results = {}
    for class_id, prob in zip(top_idx[0].tolist(), top_prob[0].tolist()):
        results[_species_name(class_id)] = float(prob)
    return results
|
|
| |
# Heading passed to gr.Interface as `title`.
title = "🐦 Bird Species Classification"
# Markdown passed to gr.Interface as `description`.
# NOTE(review): the accuracy figures below are hard-coded literals, not read
# from the checkpoint; keep them in sync with the deployed weights.
description = """
Upload an image of a bird and the model will predict the species!

**Model Details:**
- Architecture: ConvNeXt-Base (87M parameters)
- Dataset: CUB-200-2011 (200 bird species)
- Test Accuracy: 83.64%
- Average Per-Class Accuracy: 83.29%

Upload a clear image of a bird to get started!
"""
|
|
# Extra Markdown content passed to gr.Interface as `article`.
# NOTE(review): "83.64%" also appears in `description`; update both together.
article = """
### About This Model

This bird classifier was trained on the CUB-200-2011 dataset containing 200 North American bird species.

**Key Features:**
- ✅ 200 bird species classification
- ✅ State-of-the-art ConvNeXt architecture
- ✅ 83.64% test accuracy
- ✅ Real-time inference

"""
|
|
# Example inputs for the interface. Currently empty, so the Interface receives
# None instead. Presumably each entry would be a one-element list such as
# ["path/to/bird.jpg"] — confirm against the Gradio docs.
examples: list = list()
|
|
| |
# One PIL image in, a top-5 confidence label out, wired to predict().
image_input = gr.Image(type="pil", label="Upload Bird Image")
top5_output = gr.Label(num_top_classes=5, label="Top 5 Predictions")

iface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=top5_output,
    title=title,
    description=description,
    article=article,
    # An empty examples list is passed as None rather than [].
    examples=examples or None,
    theme=gr.themes.Soft(),
    # NOTE(review): newer Gradio releases deprecate `allow_flagging` in
    # favour of `flagging_mode`; check against the pinned Gradio version.
    allow_flagging="never",
)
|
|
| |
def main():
    """Launch the Gradio interface built above."""
    iface.launch()


if __name__ == "__main__":
    main()
|
|