import importlib
import subprocess
import sys

# Import torch, installing it on the fly when the environment lacks it.
# The guarded import must be the first import of torch: an unconditional
# `import torch` ahead of it would raise before the fallback could ever run.
try:
    import torch
except ModuleNotFoundError:
    # NOTE(review): the leading emoji was mis-decoded in the source ("π¨");
    # presumably 🚨 -- confirm.
    print("🚨 Torch not found! Installing...")
    # Install into the interpreter running this script. check=True raises
    # CalledProcessError when pip fails, so we never continue without torch.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "torch", "torchvision", "torchaudio"],
        check=True,
    )
    # The failed import above may have cached a stale view of site-packages;
    # refresh the finders so the freshly installed package is visible.
    importlib.invalidate_caches()
    import torch

# NOTE(review): this literal was split across two lines by a mis-decoded
# emoji ("β" followed by a line break); presumably ✅ -- confirm.
print("✅ PyTorch version:", torch.__version__)
|
|
|
|
| |
# Clothing category names. The position in this list is the class index:
# predict_clothing() looks up the model's argmax result here, so the order
# must match the order the checkpoint was trained with.
class_labels = [
    "T-Shirt", "Shirt", "Knitwear", "Chiffon", "Sweater", "Hoodie",
    "Windbreaker", "Jacket", "Downcoat", "Suit", "Shawl", "Dress",
    "Vest", "Underwear",
]
|
|
| |
def create_model_selfsup(net='resnet50', num_class=14, checkpoint_path='/content/ckpt_clothing_resnet50.pth'):
    """Build a ``SupCEResNet`` classifier and load its weights from a checkpoint.

    Args:
        net: Backbone name, passed through to ``SupCEResNet``.
        num_class: Number of output classes, passed as ``num_classes``.
        checkpoint_path: Path of the checkpoint file. Its ``'model'`` entry
            must hold the state dict.

    Returns:
        The model, moved to CUDA when available (CPU otherwise) and switched
        to eval mode.
    """
    # Resolve the target device once; it is used for loading and placement.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # NOTE(review): the leading "π" looks like a mis-decoded emoji; the
    # original glyph cannot be recovered from here, so the text is untouched.
    print(f"π Loading model from: {checkpoint_path}")

    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
    # objects. Only point this at checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Drop every "module." segment from the parameter names. replace() is
    # used on purpose (not a prefix strip) because the segment may also sit
    # in the middle of a key.
    state_dict = {k.replace('module.', ''): v for k, v in checkpoint['model'].items()}

    model = SupCEResNet(net, num_classes=num_class, pool=True)

    # strict=False tolerates key mismatches, which would otherwise leave
    # layers at their random initialisation without any hint. Report them.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            "Warning: checkpoint does not fully match the model - "
            f"missing keys: {list(load_result.missing_keys)}, "
            f"unexpected keys: {list(load_result.unexpected_keys)}"
        )

    model = model.to(device)
    model.eval()

    # NOTE(review): this literal was split across two lines by a mis-decoded
    # emoji ("β" followed by a line break); presumably ✅ -- confirm.
    print("✅ Model loaded successfully!")
    return model
|
|
| |
# Load the classifier once at import time with the default checkpoint so
# every prediction reuses the same model instance.
model = create_model_selfsup()
|
|
| |
def preprocess_image(image):
    """Convert a PIL image into a normalized single-image batch tensor.

    The image is resized so its shorter side is 256, center-cropped to
    224x224, converted to a tensor, normalized per channel and given a
    leading batch dimension.

    Args:
        image: A PIL image in any mode.

    Returns:
        A tensor of shape (1, 3, 224, 224) on CUDA when available, else CPU.
    """
    # Normalize below uses 3-channel statistics, so grayscale, RGBA or
    # palette images must be brought to RGB first or it would raise.
    if image.mode != "RGB":
        image = image.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # These are the commonly used ImageNet channel mean/std values.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # unsqueeze(0) adds the batch dimension the model expects.
    return transform(image).unsqueeze(0).to(device)
|
|
| |
def predict_clothing(image):
    """Classify an uploaded image and return its clothing category name.

    Args:
        image: Image as a NumPy array (the interface is configured with
            ``type="numpy"``), or ``None`` when nothing was uploaded.

    Returns:
        The matching entry of ``class_labels``, or a short prompt message
        when ``image`` is ``None``.
    """
    # Submitting the form without an image hands us None; answer with a
    # readable message instead of crashing inside Image.fromarray.
    if image is None:
        return "Please upload an image."

    pil_image = Image.fromarray(image)
    batch = preprocess_image(pil_image)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        logits = model(batch)
        predicted_class = torch.argmax(logits, dim=1).item()

    return class_labels[predicted_class]
|
|
| |
# Build the single-image web UI around predict_clothing and start serving it.
gr.Interface(
    fn=predict_clothing,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Textbox(label="Predicted Clothing Type"),
    title="Clothing1M Classification",
    description="Upload an image to classify clothing into one of 14 categories.",
).launch()
|
|