File size: 5,133 Bytes
b8c80e7
 
 
 
 
5e8eb58
 
b8c80e7
 
 
 
 
 
 
5e8eb58
 
 
b8c80e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e8eb58
 
 
 
 
 
 
 
 
 
 
b8c80e7
 
 
 
 
 
 
 
 
5e8eb58
 
 
b8c80e7
 
 
 
 
 
 
 
 
 
 
 
 
5e8eb58
b8c80e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf0941e
b8c80e7
 
 
 
 
 
 
 
 
 
 
5e8eb58
 
b8c80e7
 
 
3715931
1aa8dc3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""DermNet-Skin23 classifier — single-model EVA-02-L Gradio demo.

Loads `best_ema_v2.pt` from iamcode6/dermnet-skin23-eva02 on startup and serves
predictions via Gradio. Runs on CPU (free HF Space tier).
"""
import os

import gradio as gr
import torch
import timm
from huggingface_hub import hf_hub_download
from PIL import Image
from timm.data import create_transform

# Size torch's intra-op thread pool. The default of 2 matches the two vCPUs of the
# free Space tier; set TORCH_NUM_THREADS in the environment to override it.
_n_threads = int(os.environ.get("TORCH_NUM_THREADS", 2))
torch.set_num_threads(_n_threads)

# ==============================================================================
# Load model from HF Hub at startup
# ==============================================================================
CKPT_REPO = "iamcode6/dermnet-skin23-eva02"
CKPT_FILE = "best_ema_v2.pt"

print(f"Downloading {CKPT_REPO}/{CKPT_FILE}...")
ckpt_path = hf_hub_download(repo_id=CKPT_REPO, filename=CKPT_FILE)
# SECURITY NOTE: weights_only=False unpickles arbitrary Python objects, which can
# execute code on load. This is acceptable only because CKPT_REPO is the project's
# own repo. Never point it at an untrusted checkpoint.
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

# Architecture, input size and label mapping are stored alongside the weights.
MODEL_NAME = ckpt["model_name"]
IMG_SIZE = ckpt["img_size"]
class_to_idx = ckpt["class_to_idx"]
idx_to_class = {v: k for k, v in class_to_idx.items()}
NUM_CLASSES = len(class_to_idx)

print(f"Building {MODEL_NAME} @ {IMG_SIZE}, {NUM_CLASSES} classes...")
model = timm.create_model(
    MODEL_NAME, pretrained=False, num_classes=NUM_CLASSES, img_size=IMG_SIZE
)
# pop() instead of indexing: load_state_dict copies the tensors into the model's own
# parameters. Removing the state dict from `ckpt` lets the checkpoint's copy of the
# weights be freed once loading returns. Otherwise a second full set of weights would
# stay alive for the lifetime of the process. The metadata keys used below remain in `ckpt`.
model.load_state_dict(ckpt.pop("model_state_dict"))
model.eval()

device = torch.device("cpu")
model = model.to(device)
print(f"Model ready (val acc reported in checkpoint: "
      f"{ckpt.get('val_acc', 0)*100:.2f}%, F1: {ckpt.get('val_macro_f1', 0):.4f})")

# Deterministic eval-time preprocessing applied to every uploaded image.
# NOTE(review): crop_pct / interpolation / mean / std are assumed to match the
# settings used when fine-tuning this checkpoint — confirm against the training config.
val_tf = create_transform(
    input_size=IMG_SIZE, is_training=False, crop_pct=0.95,
    interpolation="bicubic",
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],
)


# ==============================================================================
# Warm up the model (first forward pass is JIT-slower; pre-pay it at startup)
# ==============================================================================
print("Warming up model...")
with torch.no_grad():
    # Build the dummy batch on the model's device, the same way predict() moves its
    # input, so the warmup keeps working if `device` is ever changed from CPU.
    dummy = torch.zeros(1, 3, IMG_SIZE, IMG_SIZE, device=device)
    _ = model(dummy)
print("Warmup done — ready to serve.")


# ==============================================================================
# Inference — single forward pass (no TTA on free CPU; would 2x latency)
# ==============================================================================
@torch.no_grad()
def predict(image):
    """Classify a single skin image.

    Args:
        image: A PIL image (the Gradio input is built with ``type="pil"``), an
            array accepted by ``Image.fromarray``, or ``None`` when the input
            is empty.

    Returns:
        A dict mapping every class name to its softmax probability as a Python
        float, or an empty dict when no image was given. All NUM_CLASSES entries
        are returned, and the ``gr.Label`` output decides how many to display.
    """
    if image is None:
        return {}

    pil_img = image if isinstance(image, Image.Image) else Image.fromarray(image)
    # Add a leading batch dimension of 1 before the forward pass.
    batch = val_tf(pil_img.convert("RGB"))[None].to(device)
    scores = model(batch).softmax(dim=-1)[0]
    return {idx_to_class[k]: float(scores[k]) for k in range(NUM_CLASSES)}


# ==============================================================================
# UI
# ==============================================================================
DESCRIPTION = """# DermNet-Skin23 Classifier

23-class clinical skin disease classifier built on **EVA-02-L** (~304M params, ViT-L/14) fine-tuned on a consolidated DermNet + Skin40 dataset (17,557 train images).

**Single-model accuracy**: **81.48%** acc / **0.7969** macro F1 on a 3,856-image val split.
**Full 5-model ensemble reaches 82.86% / 0.8113** — see the ensemble repo linked below.

Upload a clinical or dermoscopy photo, click **Classify**, and the model returns the top-5 most likely categories with calibrated confidence. **Inference takes ~30-45 seconds** on the free CPU tier — large vision transformer at 448×448 resolution.

---

⚠️ **NOT a diagnostic tool.** For research and educational purposes only. Always consult a qualified dermatologist for medical concerns. The 23 categories are coarse merges (e.g., all Tinea variants under one bucket), and the model has known weaknesses on rare conditions like Cellulitis (F1 0.69) and over-broad categories like Acne and Rosacea (precision 0.66).
"""

LINKS = """
**Project artifacts on Hugging Face**:
- [Single-model checkpoints (this Space's model)](https://huggingface.co/iamcode6/dermnet-skin23-eva02)
- [ConvNeXt ensemble partner + ensemble script](https://huggingface.co/iamcode6/dermnet-skin23-convnext)
- [DINOv2-Giant 3rd backbone](https://huggingface.co/iamcode6/dermnet-skin23-dinov2g)

**Trained on**: AMD Instinct MI300X (192 GB HBM3) via DigitalOcean, ROCm 7.0, PyTorch with HIP.
"""

# Page layout, in creation order (Blocks renders components in the order they are
# created inside the context):
#   1. the description;
#   2. a two-column row with the image input and Classify button on the left and
#      the prediction label on the right;
#   3. the links footer.
with gr.Blocks(title="DermNet-Skin23 Classifier", theme=gr.themes.Soft()) as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column(scale=1):
            # type="pil" delivers a PIL image to predict(), or None when the input is empty.
            input_image = gr.Image(type="pil", label="Upload skin image", height=380)
            predict_btn = gr.Button("Classify", variant="primary", size="lg")
        with gr.Column(scale=1):
            # predict() returns a score for every class; the Label displays only the top 5.
            output_label = gr.Label(num_top_classes=5, label="Top-5 predictions")

    gr.Markdown(LINKS)

    # Only run on explicit button click — avoids racing the Gradio queue
    # if a user uploads multiple images quickly
    predict_btn.click(predict, inputs=input_image, outputs=output_label)

if __name__ == "__main__":
    demo.launch(ssr_mode=False)