"""DermNet-Skin23 classifier — single-model EVA-02-L Gradio demo.
Loads `best_ema_v2.pt` from iamcode6/dermnet-skin23-eva02 on startup and serves
predictions via Gradio. Runs on CPU (free HF Space tier).
"""
import os
import gradio as gr
import torch
import timm
from huggingface_hub import hf_hub_download
from PIL import Image
from timm.data import create_transform
# The free Space tier provides two vCPUs, so PyTorch defaults to two threads;
# set the TORCH_NUM_THREADS environment variable to override the count.
_num_threads = int(os.environ.get("TORCH_NUM_THREADS", 2))
torch.set_num_threads(_num_threads)
# ==============================================================================
# Load model from HF Hub at startup
# ==============================================================================
# Hub coordinates of the checkpoint served by this demo.
CKPT_REPO = "iamcode6/dermnet-skin23-eva02"
CKPT_FILE = "best_ema_v2.pt"

print(f"Downloading {CKPT_REPO}/{CKPT_FILE}...")
ckpt_path = hf_hub_download(repo_id=CKPT_REPO, filename=CKPT_FILE)

# NOTE(security): weights_only=False makes torch.load unpickle arbitrary
# objects, which can execute code embedded in the file. That is acceptable
# only while CKPT_REPO is trusted; consider weights_only=True if the
# checkpoint holds nothing but tensors and plain containers.
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

# Model metadata is stored alongside the weights in the checkpoint.
MODEL_NAME, IMG_SIZE, class_to_idx = (
    ckpt["model_name"],
    ckpt["img_size"],
    ckpt["class_to_idx"],
)
NUM_CLASSES = len(class_to_idx)

# Inverse mapping (class index -> class name) used to label predictions.
idx_to_class = {index: name for name, index in class_to_idx.items()}
# Build the architecture named in the checkpoint (no pretrained download) and
# load the fine-tuned weights into it.
print(f"Building {MODEL_NAME} @ {IMG_SIZE}, {NUM_CLASSES} classes...")
model = timm.create_model(
    MODEL_NAME, pretrained=False, num_classes=NUM_CLASSES, img_size=IMG_SIZE
)

# load_state_dict copies the tensors into the model's own parameters, so the
# checkpoint's copy is redundant afterwards. Popping it (instead of indexing)
# lets that memory be freed right away rather than being pinned by the
# module-level `ckpt` for the lifetime of the process; the remaining metadata
# in `ckpt` is left untouched.
model.load_state_dict(ckpt.pop("model_state_dict"))
model.eval()  # inference mode for dropout / normalization layers

device = torch.device("cpu")
model = model.to(device)

# Validation metrics are optional in the checkpoint; missing ones print as 0.
print(f"Model ready (val acc reported in checkpoint: "
      f"{ckpt.get('val_acc', 0)*100:.2f}%, F1: {ckpt.get('val_macro_f1', 0):.4f})")
# Evaluation-time preprocessing (is_training=False, so no augmentation):
# bicubic resize with a 0.95 center-crop fraction, then normalization with the
# standard ImageNet mean / std.
# NOTE(review): presumably mirrors the validation transform used in training —
# confirm against the training script.
val_tf = create_transform(
    input_size=IMG_SIZE,
    is_training=False,
    interpolation="bicubic",
    crop_pct=0.95,
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
# ==============================================================================
# Warm up the model (first forward pass is JIT-slower; pre-pay it at startup)
# ==============================================================================
# Run one throwaway forward pass at startup so the first real request does
# not pay the one-time first-call overhead.
print("Warming up model...")
with torch.no_grad():
    # Allocate the dummy batch on `device` so the warm-up exercises the same
    # device the model lives on (and keeps working if `device` ever changes).
    dummy = torch.zeros(1, 3, IMG_SIZE, IMG_SIZE, device=device)
    _ = model(dummy)
print("Warmup done — ready to serve.")
# ==============================================================================
# Inference — single forward pass (no TTA on free CPU; would 2x latency)
# ==============================================================================
@torch.no_grad()
def predict(image):
    """Classify one image and return a probability for every class.

    Args:
        image: A PIL image, an array accepted by ``Image.fromarray``, or
            ``None`` when nothing has been uploaded.

    Returns:
        A dict mapping each class name to its softmax probability (a Python
        float). An empty dict is returned when ``image`` is ``None``.
    """
    if image is None:
        return {}

    # Accept raw arrays as well as PIL images.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    # Force 3-channel RGB (drops alpha, expands grayscale) before the
    # evaluation transform, then add the batch dimension.
    x = val_tf(image.convert("RGB")).unsqueeze(0).to(device)

    probs = torch.softmax(model(x), dim=-1).squeeze(0)

    # One bulk tensor -> list conversion instead of one scalar extraction
    # per class.
    return {idx_to_class[i]: p for i, p in enumerate(probs.tolist())}
# ==============================================================================
# UI
# ==============================================================================
# Markdown header for the page; passed to gr.Markdown in the UI definition.
# NOTE(review): every figure below (accuracy, F1, latency, 448×448 input size)
# is hard-coded text, not read from the loaded checkpoint — keep it in sync
# with the model actually being served.
# NOTE(review): the text promises "calibrated confidence", but predict()
# returns plain softmax probabilities — confirm calibration was done upstream
# or reword.
DESCRIPTION = """# DermNet-Skin23 Classifier
23-class clinical skin disease classifier built on **EVA-02-L** (~304M params, ViT-L/14) fine-tuned on a consolidated DermNet + Skin40 dataset (17,557 train images).
**Single-model accuracy**: **81.48%** acc / **0.7969** macro F1 on a 3,856-image val split.
**Full 5-model ensemble reaches 82.86% / 0.8113** — see the ensemble repo linked below.
Upload a clinical or dermoscopy photo, click **Classify**, and the model returns the top-5 most likely categories with calibrated confidence. **Inference takes ~30-45 seconds** on the free CPU tier — large vision transformer at 448×448 resolution.
---
⚠️ **NOT a diagnostic tool.** For research and educational purposes only. Always consult a qualified dermatologist for medical concerns. The 23 categories are coarse merges (e.g., all Tinea variants under one bucket), and the model has known weaknesses on rare conditions like Cellulitis (F1 0.69) and over-broad categories like Acne and Rosacea (precision 0.66).
"""
# Markdown block linking to the related model repositories and noting the
# training hardware; passed to gr.Markdown in the UI definition.
LINKS = """
**Project artifacts on Hugging Face**:
- [Single-model checkpoints (this Space's model)](https://huggingface.co/iamcode6/dermnet-skin23-eva02)
- [ConvNeXt ensemble partner + ensemble script](https://huggingface.co/iamcode6/dermnet-skin23-convnext)
- [DINOv2-Giant 3rd backbone](https://huggingface.co/iamcode6/dermnet-skin23-dinov2g)
**Trained on**: AMD Instinct MI300X (192 GB HBM3) via DigitalOcean, ROCm 7.0, PyTorch with HIP.
"""
# Page layout: description, then a two-column row (image upload + button on
# the left, prediction label on the right), then the links block.
with gr.Blocks(title="DermNet-Skin23 Classifier", theme=gr.themes.Soft()) as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload skin image", height=380)
            predict_btn = gr.Button(value="Classify", variant="primary", size="lg")
        with gr.Column(scale=1):
            output_label = gr.Label(num_top_classes=5, label="Top-5 predictions")

    gr.Markdown(LINKS)

    # Inference is wired to the button only (not to image changes), so
    # uploading several images in quick succession does not queue a slow
    # forward pass for each one.
    predict_btn.click(fn=predict, inputs=input_image, outputs=output_label)
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    # ssr_mode=False launches the app with server-side rendering turned off.
    demo.launch(ssr_mode=False)