| import gradio as gr |
| import torch |
| import torch.nn as nn |
| import torchvision.transforms as transforms |
| from torchvision import models |
| import numpy as np |
| from PIL import Image |
| import os |
|
|
| |
# Labels the classifier can report. Order matters: classify_aircraft pairs
# the i-th name in this list with the i-th value of the model output.
CLASS_NAMES = [
    "707-320",
    "737-400",
    "767-300",
    "DC-9-30",
    "DH-82",
    "Falcon_2000",
    "Il-76",
    "MD-11",
    "Metroliner",
    "PA-28",
]
|
|
class AircraftClassifier(nn.Module):
    """ResNet-18 based aircraft classifier.

    The torchvision ResNet-18 is used as the backbone and its final fully
    connected layer is replaced by a new linear layer with ``num_classes``
    outputs.

    Args:
        num_classes: Number of output classes of the replacement head.
        pretrained: When True (the default, matching the previous behaviour)
            the backbone is initialised from torchvision's ImageNet weights.
            Pass False to skip that download, e.g. when a full checkpoint
            is loaded right afterwards.
    """

    def __init__(self, num_classes=10, pretrained=True):
        super().__init__()

        # The ``pretrained`` keyword of torchvision model builders is
        # deprecated in favour of ``weights``; use the new API when this
        # torchvision provides it and fall back to the old one otherwise.
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is not None:
            weights = weights_enum.IMAGENET1K_V1 if pretrained else None
            self.backbone = models.resnet18(weights=weights)
        else:
            self.backbone = models.resnet18(pretrained=pretrained)

        # Swap the stock classification head for one sized to our label set.
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Linear(num_features, num_classes)

    def forward(self, x):
        """Return the backbone's output (one score per class) for batch ``x``."""
        return self.backbone(x)
|
|
| |
def get_transforms():
    """Build the preprocessing pipeline applied to each input image.

    Returns:
        A ``transforms.Compose`` that resizes the image to 224x224, converts
        it to a tensor and normalises every channel.
    """
    # Per-channel statistics; presumably the ImageNet values the pretrained
    # backbone was trained with.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)
|
|
| |
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AircraftClassifier(num_classes=len(CLASS_NAMES))

# Restore the trained weights if a checkpoint is present. A missing or
# unreadable checkpoint is reported but deliberately not fatal, so the UI
# can still start.
model_path = 'models/aircraft_classifier.pth'
if os.path.exists(model_path):
    try:
        # NOTE(security): torch.load unpickles the file, so model_path must
        # only ever point at a checkpoint you trust. On PyTorch versions
        # that support it, consider passing weights_only=True.
        model.load_state_dict(torch.load(model_path, map_location=device))
        print(f"✅ Loaded trained model from {model_path}")
    except Exception as e:
        print(f"⚠️ Could not load trained model: {e}")
        print("Using random weights - please train the model first!")
else:
    print(f"⚠️ Model file not found at {model_path}")
    print("Using random weights - please train the model first!")

# Inference only: move to the selected device and switch to eval mode.
model = model.to(device)
model.eval()

# Built once at import time and reused for every request.
transform = get_transforms()
|
|
def classify_aircraft(image):
    """
    Classify an aircraft image.

    Args:
        image: PIL Image or numpy array. ``None`` (which Gradio passes when
            the image component is cleared) is accepted as "no image".

    Returns:
        dict: Maps every name in CLASS_NAMES to its softmax probability as
        a float. If there is no image, or classification fails for any
        reason, every class maps to 0.0.
    """
    # No image is not an error: skip inference and the error log entirely.
    if image is None:
        return {class_name: 0.0 for class_name in CLASS_NAMES}

    try:
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # The normalisation in the transform uses three channels.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # unsqueeze(0) adds the batch dimension in front of the image tensor.
        input_tensor = transform(image).unsqueeze(0).to(device)

        with torch.no_grad():
            outputs = model(input_tensor)
            probabilities = torch.softmax(outputs, dim=1)

        # Single-image batch: take row 0.
        probs = probabilities.cpu().numpy()[0]

        return {
            class_name: float(probs[i])
            for i, class_name in enumerate(CLASS_NAMES)
        }

    except Exception as e:
        # Best effort: keep the UI responsive and report the failure on
        # stdout instead of propagating it to Gradio.
        print(f"Error in classification: {e}")
        return {class_name: 0.0 for class_name in CLASS_NAMES}
|
|
def get_top_predictions(image):
    """
    Get the top 3 predictions with confidence scores.

    Args:
        image: PIL Image or numpy array, or ``None`` when no image is set.

    Returns:
        str: A header line followed by one numbered line per prediction
        (class name and confidence as a percentage with two decimals).
        An empty string when ``image`` is ``None``, and an error message if
        anything goes wrong while producing the text.
    """
    # Nothing to report when the image was cleared; avoids showing three
    # meaningless 0.00% rows.
    if image is None:
        return ""

    try:
        results = classify_aircraft(image)

        # Highest confidence first.
        ranked = sorted(results.items(), key=lambda item: item[1], reverse=True)

        lines = [
            f"{rank}. **{class_name}**: {confidence * 100:.2f}%\n"
            for rank, (class_name, confidence) in enumerate(ranked[:3], start=1)
        ]
        return "🎯 **Top Predictions:**\n\n" + "".join(lines)

    except Exception as e:
        return f"❌ Error during classification: {str(e)}"
|
|
| |
def create_interface():
    """Create and configure the Gradio interface.

    Lays out an image input next to a label widget and a text box, adds a
    static model-information panel, and runs both prediction functions
    whenever the input image changes.

    Returns:
        The configured ``gr.Blocks`` instance (not yet launched).
    """
    # Page-level styling for the container, the title and the description.
    css = """
    .gradio-container {
        max-width: 900px !important;
        margin: auto !important;
    }
    .title {
        text-align: center;
        font-size: 2.5em;
        font-weight: bold;
        margin-bottom: 0.5em;
    }
    .description {
        text-align: center;
        font-size: 1.2em;
        color: #666;
        margin-bottom: 2em;
    }
    """

    # Derive the displayed counts and names from CLASS_NAMES so the page
    # cannot drift from the label list the classifier actually uses.
    num_classes = len(CLASS_NAMES)
    class_list = ", ".join(CLASS_NAMES)

    with gr.Blocks(css=css, title="Aircraft Classifier") as iface:

        gr.HTML(f"""
        <div class="title">🛩️ Aircraft Classifier</div>
        <div class="description">
            Fine-grained aircraft classification using deep learning<br>
            Upload an image to classify it into one of {num_classes} aircraft types
        </div>
        """)

        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(
                    type="pil",
                    label="Upload Aircraft Image",
                    height=400
                )

            with gr.Column(scale=1):
                # Shows the confidence of every class.
                classification_output = gr.Label(
                    label="🎯 Classification Results",
                    num_top_classes=num_classes
                )

                # Read-only text summary of the best three classes.
                top_predictions = gr.Textbox(
                    label="📊 Detailed Results",
                    lines=6,
                    interactive=False
                )

        gr.HTML(f"""
        <div style="margin-top: 2em; padding: 1em; background-color: #f8f9fa; border-radius: 8px;">
            <h3>🧠 Model Information</h3>
            <ul>
                <li><b>Architecture:</b> ResNet-18 with transfer learning</li>
                <li><b>Dataset:</b> FGVC-Aircraft ({num_classes} classes)</li>
                <li><b>Accuracy:</b> 87.17% on test set</li>
                <li><b>Classes:</b> {class_list}</li>
            </ul>
        </div>
        """)

        # NOTE: the two handlers are independent, so the model runs twice
        # per image change (get_top_predictions calls classify_aircraft
        # itself).
        input_image.change(
            fn=classify_aircraft,
            inputs=[input_image],
            outputs=[classification_output]
        )

        input_image.change(
            fn=get_top_predictions,
            inputs=[input_image],
            outputs=[top_predictions]
        )

    return iface
|
|
| |
# Script entry point: print a short banner, build the UI and serve it.
if __name__ == "__main__":
    print("🚀 Starting Aircraft Classifier Gradio Interface...")
    print(f"📱 Device: {device}")
    print(f"🎯 Classes: {len(CLASS_NAMES)}")

    iface = create_interface()
    # NOTE: server_name="0.0.0.0" listens on all network interfaces and
    # share=True additionally requests a public Gradio share link; tighten
    # both if the app should only be reachable locally.
    iface.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )