| import gradio as gr |
| import torch |
| import torch.nn.functional as F |
| import torchvision.transforms as transforms |
| from PIL import Image |
| import os |
| from ResNet_for_CC import CC_model |
|
|
| |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
| |
| model_path = "CC_net.pt" |
| model = CC_model(num_classes1=14) |
|
|
| |
| state_dict = torch.load(model_path, map_location=device) |
| model.load_state_dict(state_dict, strict=False) |
| model.to(device) |
| model.eval() |
|
|
| |
# Human-readable names for the 14 output classes, indexed by the model's
# predicted class index (order must match the training label order).
class_labels = [
    "T-Shirt", "Shirt", "Knitwear", "Chiffon", "Sweater", "Hoodie",
    "Windbreaker", "Jacket", "Downcoat", "Suit", "Shawl", "Dress",
    "Vest", "Underwear"
]

# Bundled sample images offered in the UI, keyed by class label.
# NOTE(review): "Vest" maps to dress.jpg — confirm this pairing is intended.
default_images = {
    "Shawl": "shawlOG.webp",
    "Jacket": "jacket.jpg",
    "Sweater": "sweater.webp",
    "Vest": "dress.jpg"
}

# Gradio galleries take (image_path, caption) pairs.
default_images_gallery = [(path, label) for label, path in default_images.items()]
|
|
| |
def preprocess_image(image):
    """Convert a PIL image into a normalized model-ready batch tensor.

    Resizes to 256, center-crops to 224, converts to a tensor, and applies
    the standard ImageNet normalization. Returns a tensor of shape
    (1, 3, 224, 224) on the active device.
    """
    pipeline = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    tensor = pipeline(image)
    # Add the batch dimension and move to the same device as the model.
    return tensor.unsqueeze(0).to(device)
|
|
| |
def classify_image(selected_default, uploaded_image):
    """Classify a clothing image and return a human-readable result string.

    Args:
        selected_default: Key into ``default_images`` used when no image
            is uploaded.
        uploaded_image: Optional numpy array from the Gradio upload widget;
            takes precedence over the dropdown selection when present.

    Returns:
        A "Predicted Class: ... (Confidence: ...%)" string on success, or an
        error message string (errors are reported as text so the UI never
        crashes).
    """
    try:
        # An uploaded image wins over the dropdown selection.
        if uploaded_image is not None:
            image = Image.fromarray(uploaded_image)
        else:
            image = Image.open(default_images[selected_default])

        # Force 3-channel RGB: uploads may be RGBA or grayscale and the
        # .webp defaults may carry an alpha channel, either of which would
        # break the 3-channel Normalize inside preprocess_image.
        image = image.convert("RGB")

        image = preprocess_image(image)

        with torch.no_grad():
            output = model(image)
            # CC_model may return a tuple; the classification logits are
            # taken from its second element.
            if isinstance(output, tuple):
                output = output[1]

        probabilities = F.softmax(output, dim=1)
        predicted_class = torch.argmax(probabilities, dim=1).item()

        if 0 <= predicted_class < len(class_labels):
            predicted_label = class_labels[predicted_class]
            confidence = probabilities[0][predicted_class].item() * 100
            return f"Predicted Class: {predicted_label} (Confidence: {confidence:.2f}%)"
        else:
            return "[ERROR] Model returned an invalid class index."

    except Exception as e:
        # Broad catch is deliberate: surface any failure in the UI textbox
        # instead of crashing the Gradio app.
        return f"Error in classification: {e}"
|
|
| |
# Build the Gradio UI: a gallery of sample images, a dropdown to pick one,
# an upload widget, and a button that runs classify_image.
with gr.Blocks() as interface:
    gr.Markdown("# Clothing1M Image Classifier")
    gr.Markdown("Upload a clothing image or select from the predefined images below.")

    # Read-only preview of the bundled sample images.
    gallery = gr.Gallery(
        value=default_images_gallery,
        label="Default Images",
        elem_id="default_gallery"
    )

    # Dropdown used as the fallback input when nothing is uploaded.
    default_selector = gr.Dropdown(
        choices=list(default_images.keys()),
        label="Select a Default Image",
        value="Shawl"
    )

    # type="numpy" hands classify_image a numpy array (or None).
    image_upload = gr.Image(type="numpy", label="Or Upload Your Own Image")

    # Prediction (or error message) text output.
    output_text = gr.Textbox(label="Classification Result")

    classify_button = gr.Button("Classify Image")

    # Wire the button: (dropdown value, upload array) -> result string.
    classify_button.click(
        fn=classify_image,
        inputs=[default_selector, image_upload],
        outputs=output_text
    )
|
|
| |
# Launch the web app only when run as a script (not on import).
if __name__ == "__main__":
    print("[INFO] Launching Gradio interface...")
    interface.launch()
|
|